AI-RESEARCHER-2024 commited on
Commit
ab5da21
·
verified ·
1 Parent(s): da92822

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -50
app.py CHANGED
@@ -1,50 +1,92 @@
1
- import gradio as gr
2
- import tensorflow as tf
3
- import numpy as np
4
-
5
- # Load the trained model
6
- model = tf.keras.models.load_model('model.h5')
7
- print("Model loaded successfully!")
8
-
9
- def preprocess_image(image):
10
- """Process the input image to match MNIST format"""
11
- # Convert to grayscale
12
- image = image.convert('L')
13
- # Resize to 28x28
14
- image = image.resize((28, 28))
15
- # Convert to numpy array and normalize
16
- image_array = np.array(image)
17
- image_array = image_array / 255.0
18
- # Reshape to match model input
19
- image_array = np.expand_dims(image_array, axis=0)
20
- return image_array
21
-
22
- def predict_digit(image):
23
- if image is None:
24
- return None
25
-
26
- # Preprocess the image
27
- processed_image = preprocess_image(image)
28
-
29
- # Make prediction
30
- predictions = model.predict(processed_image)
31
- pred_scores = tf.nn.softmax(predictions[0]).numpy()
32
- pred_class = np.argmax(pred_scores)
33
-
34
- # Create result string
35
- result = f"Prediction: {pred_class}"
36
-
37
- return result
38
-
39
- # Create Gradio interface
40
- demo = gr.Interface(
41
- fn=predict_digit,
42
- inputs=gr.Image(type="pil"),
43
- outputs=gr.Textbox(label="Result"),
44
- title="MNIST Digit Recognizer",
45
- description="Upload a digit from 0-9 and the model will predict which digit it is.",
46
- examples=None,
47
- )
48
-
49
- if __name__ == "__main__":
50
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PyPDF2 import PdfReader
3
+ from io import BytesIO
4
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
5
+ from langchain_google_genai import GoogleGenerativeAIEmbeddings
6
+ from langchain_community.vectorstores import FAISS
7
+ from langchain_google_genai import ChatGoogleGenerativeAI
8
+ from langchain.chains.question_answering import load_qa_chain
9
+ from langchain.prompts import PromptTemplate
10
+
11
+ # Hardcoded Gemini API keys
12
+ API_KEYS = [
13
+ "AIzaSyBYbyC4qCJoKxKR-r0oIn4SVqj4CfSdx4s",
14
+ "AIzaSyBfvYURYVijTurxvFUyV3vZkCojpOAAnFk"
15
+ ]
16
+ current_key_index = -1
17
+
18
+ def switch_api_key():
19
+ global current_key_index
20
+ current_key_index = (current_key_index + 1) % len(API_KEYS)
21
+ return API_KEYS[current_key_index]
22
+
23
+ def get_pdf_text(pdf_docs):
24
+ text = ""
25
+ for pdf in pdf_docs:
26
+ pdf_reader = PdfReader(BytesIO(pdf.read()))
27
+ for page in pdf_reader.pages:
28
+ page_text = page.extract_text()
29
+ if page_text:
30
+ text += page_text
31
+ return text
32
+
33
+ def get_text_chunks(text):
34
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=10000, chunk_overlap=1000)
35
+ return text_splitter.split_text(text)
36
+
37
+ def get_vector_store(text_chunks, api_key):
38
+ embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=api_key)
39
+ vector_store = FAISS.from_texts(text_chunks, embedding=embeddings)
40
+ # Use /tmp for Hugging Face Spaces
41
+ vector_store.save_local("/tmp/faiss_index")
42
+
43
+ def get_conversational_chain(api_key):
44
+ prompt_template = """
45
+ You are a helpful assistant that only answers based on the context provided from the PDF documents.
46
+ Do not use any external knowledge or assumptions. If the answer is not found in the context below, reply with "I don't know."
47
+ Context:
48
+ {context}
49
+ Question:
50
+ {question}
51
+ Answer:
52
+ """
53
+ model = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0, google_api_key=api_key)
54
+ prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
55
+ chain = load_qa_chain(model, chain_type="stuff", prompt=prompt)
56
+ return chain
57
+
58
+ def user_input(user_question, api_key):
59
+ embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=api_key)
60
+ # Use /tmp for Hugging Face Spaces
61
+ new_db = FAISS.load_local("/tmp/faiss_index", embeddings, allow_dangerous_deserialization=True)
62
+ docs = new_db.similarity_search(user_question)
63
+ chain = get_conversational_chain(api_key)
64
+ response = chain({"input_documents": docs, "question": user_question}, return_only_outputs=True)
65
+ st.write("Reply: ", response["output_text"])
66
+
67
+ def main():
68
+ st.set_page_config(page_title="Chat PDF")
69
+ st.header("CSC 121: Computers and Scientific Thinking (Chatbot)")
70
+ st.subheader("Ask a question ONLY from the CSC 121 textbook of Dr. Reed", divider=True)
71
+
72
+ user_api_key = st.text_input("Enter your API key (optional)", type="password")
73
+ user_question = st.text_input("Ask a question")
74
+
75
+ if user_question:
76
+ api_key = user_api_key if user_api_key else switch_api_key()
77
+ user_input(user_question, api_key)
78
+
79
+ pdf_docs = st.file_uploader("Upload PDF files", accept_multiple_files=True, type=['pdf'])
80
+ if st.button("Submit & Process"):
81
+ if pdf_docs:
82
+ with st.spinner("Processing..."):
83
+ api_key = user_api_key if user_api_key else switch_api_key()
84
+ raw_text = get_pdf_text(pdf_docs)
85
+ text_chunks = get_text_chunks(raw_text)
86
+ get_vector_store(text_chunks, api_key)
87
+ st.success("Done")
88
+ else:
89
+ st.error("Please upload at least one PDF file.")
90
+
91
+ if __name__ == "__main__":
92
+ main()