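"""Gradio app for question answering over uploaded PDF documents.

PDFs are chunked with PyPDFLoader, embedded with a sentence-transformers model,
stored in a local FAISS index, and queried through a Hugging Face chat model.
"""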
import os
import json
import logging
import shutil
import gradio as gr
from typing import List
from tempfile import NamedTemporaryFile
from huggingface_hub import InferenceClient
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.docstore.document import Document
# Setup logging
logging.basicConfig(level=logging.INFO)
# Constants
DOCUMENTS_FILE = "uploaded_documents.json"
DEFAULT_MODEL = "@cf/meta/llama-2-7b-chat"
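# NOTE: "@cf/..." looks like a Cloudflare Workers AI model id; huggingface_hub's
# InferenceClient normally expects a Hugging Face repo id (e.g. "meta-llama/Llama-2-7b-chat-hf"),
# so chat_completion calls may fail against the Hugging Face Inference API with this value.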
HF_TOKEN = os.getenv("HF_API_TOKEN") # Make sure to set this environment variable
EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"
def get_embeddings():
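    """Return a CPU HuggingFaceEmbeddings instance with normalized embeddings."""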
return HuggingFaceEmbeddings(
model_name=EMBED_MODEL,
model_kwargs={'device': 'cpu'},
encode_kwargs={'normalize_embeddings': True}
)
def load_documents():
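    """Load the manifest of uploaded documents from disk, or return an empty list."""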
if os.path.exists(DOCUMENTS_FILE):
with open(DOCUMENTS_FILE, "r") as f:
return json.load(f)
return []
def save_documents(documents):
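    """Persist the manifest of uploaded documents to disk as JSON."""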
with open(DOCUMENTS_FILE, "w") as f:
json.dump(documents, f)
def load_document(file: NamedTemporaryFile) -> List[Document]:
"""Loads and splits the document into pages using PyPDF."""
loader = PyPDFLoader(file.name)
return loader.load_and_split()
def process_uploaded_files(files):
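    """Embed uploaded PDF files into the local FAISS index and record them in the manifest.

    Returns a status message and the updated list of document names for the UI.
    """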
if not files:
return "Please upload at least one file.", []
files_list = [files] if not isinstance(files, list) else files
embed = get_embeddings()
uploaded_documents = load_documents()
total_chunks = 0
all_data = []
for file in files_list:
try:
data = load_document(file)
if not data:
continue
all_data.extend(data)
total_chunks += len(data)
if not any(doc["name"] == file.name for doc in uploaded_documents):
uploaded_documents.append({"name": file.name, "selected": True})
except Exception as e:
logging.error(f"Error processing file {file.name}: {str(e)}")
if not all_data:
return "No valid data could be extracted from the uploaded files.", []
try:
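        # Append to an existing on-disk index if one was saved earlier, otherwise build a new one.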
if os.path.exists("faiss_database"):
database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
database.add_documents(all_data)
else:
database = FAISS.from_documents(all_data, embed)
database.save_local("faiss_database")
save_documents(uploaded_documents)
return f"Vector store updated successfully. Processed {total_chunks} chunks.", [doc["name"] for doc in uploaded_documents]
except Exception as e:
return f"Error updating vector store: {str(e)}", []
def delete_documents(selected_docs):
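    """Remove the selected documents' chunks from the FAISS index and update the manifest.

    The index is rebuilt from the chunks that remain; if nothing remains it is deleted.
    """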
if not selected_docs:
return "No documents selected for deletion.", []
uploaded_documents = load_documents()
embed = get_embeddings()
if os.path.exists("faiss_database"):
database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
docs_to_keep = []
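        # Walk the in-memory docstore and keep every chunk whose source file was not selected.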
for doc in database.docstore._dict.values():
if doc.metadata.get("source") not in selected_docs:
docs_to_keep.append(doc)
if not docs_to_keep:
shutil.rmtree("faiss_database")
else:
new_database = FAISS.from_documents(docs_to_keep, embed)
new_database.save_local("faiss_database")
uploaded_documents = [doc for doc in uploaded_documents if doc["name"] not in selected_docs]
save_documents(uploaded_documents)
remaining_docs = [doc["name"] for doc in uploaded_documents]
return f"Deleted documents: {', '.join(selected_docs)}", remaining_docs
return "No documents to delete.", []
def get_response(query, temperature=0.2):
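    """Answer a question using context retrieved from the selected documents.

    Selection is read from the saved manifest (every upload defaults to selected).
    Retrieval runs against a temporary FAISS index built over those documents, and
    the answer is generated via the Hugging Face InferenceClient chat API.
    """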
if not query.strip():
return "Please enter a question."
uploaded_documents = load_documents()
selected_docs = [doc["name"] for doc in uploaded_documents if doc["selected"]]
if not selected_docs:
return "Please select at least one document to search through."
embed = get_embeddings()
if not os.path.exists("faiss_database"):
return "No documents available. Please upload PDF documents first."
database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
# Filter documents
filtered_docs = []
for doc in database.docstore._dict.values():
if isinstance(doc, Document) and doc.metadata.get("source") in selected_docs:
filtered_docs.append(doc)
if not filtered_docs:
return "No relevant information found in the selected documents."
filtered_db = FAISS.from_documents(filtered_docs, embed)
retriever = filtered_db.as_retriever(search_kwargs={"k": 5})
relevant_docs = retriever.get_relevant_documents(query)
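    # Concatenate the retrieved chunks into one context block for the prompt.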
context_str = "\n".join([doc.page_content for doc in relevant_docs])
messages = [
{"role": "system", "content": "You are a helpful assistant that provides accurate answers based on the given context."},
{"role": "user", "content": f"Context:\n{context_str}\n\nQuestion: {query}\n\nProvide a comprehensive answer based only on the given context."}
]
client = InferenceClient(DEFAULT_MODEL, token=HF_TOKEN)
try:
response = client.chat_completion(
messages=messages,
max_tokens=1000,
temperature=temperature,
top_p=0.9,
)
return response.choices[0].message.content
except Exception as e:
return f"Error generating response: {str(e)}"
def create_interface():
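    """Build the Gradio Blocks UI: upload/delete documents and ask questions about them."""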
with gr.Blocks(title="PDF Question Answering System") as app:
gr.Markdown("# PDF Question Answering System")
with gr.Row():
with gr.Column():
files = gr.File(
label="Upload PDF Documents",
file_types=[".pdf"],
file_count="multiple"
)
upload_button = gr.Button("Upload and Process")
with gr.Column():
doc_status = gr.Textbox(label="Status", interactive=False)
                doc_list = gr.CheckboxGroup(
label="Available Documents",
choices=[],
interactive=True
)
delete_button = gr.Button("Delete Selected Documents")
with gr.Row():
with gr.Column():
question = gr.Textbox(
label="Ask a question about the documents",
placeholder="Enter your question here..."
)
temperature = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.2,
step=0.1,
label="Temperature (Higher values make the output more random)"
)
submit_button = gr.Button("Submit Question")
with gr.Column():
answer = gr.Textbox(
label="Answer",
interactive=False,
lines=10
)
# Event handlers
upload_button.click(
fn=process_uploaded_files,
inputs=[files],
outputs=[doc_status, doc_list]
)
delete_button.click(
fn=delete_documents,
inputs=[doc_list],
outputs=[doc_status, doc_list]
)
submit_button.click(
fn=get_response,
inputs=[question, temperature],
outputs=[answer]
)
        # Also submit when the user presses Enter in the question textbox
question.submit(
fn=get_response,
inputs=[question, temperature],
outputs=[answer]
)
return app
if __name__ == "__main__":
app = create_interface()
app.launch(
server_name="0.0.0.0", # Makes the app accessible from other machines
server_port=7860, # Specify port
share=True # Creates a public URL
)