# Spaces: Runtime error
import os | |
import json | |
import logging | |
import shutil | |
import gradio as gr | |
from typing import List | |
from tempfile import NamedTemporaryFile | |
from huggingface_hub import InferenceClient | |
from langchain_community.document_loaders import PyPDFLoader | |
from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings | |
from langchain_community.vectorstores import FAISS | |
from langchain.docstore.document import Document | |
# Setup logging: INFO and above on the root logger.
logging.basicConfig(level=logging.INFO)

# Constants
# JSON file holding the registry of uploaded documents.
DOCUMENTS_FILE = "uploaded_documents.json"
# Chat model identifier handed to InferenceClient.
# NOTE(review): the "@cf/" prefix looks like a Cloudflare Workers AI model id
# rather than a Hugging Face repo id - confirm the inference client can serve it.
DEFAULT_MODEL = "@cf/meta/llama-2-7b-chat"
# API token, read from the HF_API_TOKEN environment variable (None when unset).
HF_TOKEN = os.getenv("HF_API_TOKEN")  # Make sure to set this environment variable
# Sentence-embedding model used to build and query the vector store.
EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"
def get_embeddings():
    """Create the embedding model used for the FAISS vector store.

    Returns:
        A ``HuggingFaceEmbeddings`` instance for ``EMBED_MODEL``, configured
        with ``device='cpu'`` and ``normalize_embeddings=True``.
    """
    model_options = {'device': 'cpu'}
    encode_options = {'normalize_embeddings': True}
    return HuggingFaceEmbeddings(
        model_name=EMBED_MODEL,
        model_kwargs=model_options,
        encode_kwargs=encode_options,
    )
def load_documents():
    """Read the uploaded-document registry from ``DOCUMENTS_FILE``.

    Returns:
        The list stored in the JSON file, or an empty list when the file is
        missing, is not valid JSON, or does not contain a JSON array.
    """
    # EAFP: open directly instead of exists()+open(), which can race with a
    # concurrent delete.
    try:
        with open(DOCUMENTS_FILE, "r", encoding="utf-8") as f:
            documents = json.load(f)
    except FileNotFoundError:
        return []
    except json.JSONDecodeError as e:
        # A truncated or corrupt registry must not take down every handler.
        logging.error("Could not parse %s: %s", DOCUMENTS_FILE, e)
        return []
    if not isinstance(documents, list):
        # Callers iterate the result and index each entry by key.
        logging.error("%s does not contain a JSON list; ignoring it", DOCUMENTS_FILE)
        return []
    return documents
def save_documents(documents):
    """Persist the uploaded-document registry to ``DOCUMENTS_FILE`` as JSON.

    The data is serialised to a sibling temporary file first and then moved
    over the target, so a failure while serialising leaves the existing
    registry untouched.

    Args:
        documents: JSON-serialisable registry (a list of dicts in this file).

    Raises:
        TypeError: If ``documents`` contains values ``json`` cannot serialise.
        OSError: If the file cannot be written or replaced.
    """
    tmp_path = f"{DOCUMENTS_FILE}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(documents, f)
        os.replace(tmp_path, DOCUMENTS_FILE)
    finally:
        # Only present if serialisation or the replace failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
def load_document(file) -> List[Document]:
    """Load a PDF with ``PyPDFLoader`` and return its split documents.

    Args:
        file: Either a filesystem path (``str`` / ``os.PathLike``) or an
            upload object exposing the path as ``.name``.

    Returns:
        The ``Document`` list produced by ``PyPDFLoader.load_and_split()``.
    """
    # Path-like inputs are used as-is; file objects carry their path in .name.
    path = file if isinstance(file, (str, os.PathLike)) else file.name
    loader = PyPDFLoader(os.fspath(path))
    return loader.load_and_split()
def process_uploaded_files(files):
    """Parse uploaded PDFs, add their chunks to the FAISS index and register them.

    Args:
        files: A single uploaded file or a list of them, as delivered by the
            upload component. Each item is passed to ``load_document``.

    Returns:
        A ``(status_message, document_names)`` tuple. ``document_names`` is
        the list of registered document names; when nothing was added it is
        the unchanged current registry rather than an empty list, so the
        caller does not blank out documents that still exist.
    """
    uploaded_documents = load_documents()
    known_names = [doc["name"] for doc in uploaded_documents]

    if not files:
        return "Please upload at least one file.", known_names

    files_list = files if isinstance(files, list) else [files]
    registered = set(known_names)
    total_chunks = 0
    all_data = []

    for file in files_list:
        # Resolve the display name up front so the error path cannot fail on
        # an input that has no usable .name attribute.
        name = getattr(file, "name", None) or str(file)
        try:
            data = load_document(file)
        except Exception as e:
            # Best effort: one unreadable PDF must not abort the whole batch.
            logging.error(f"Error processing file {name}: {str(e)}")
            continue
        if not data:
            continue
        all_data.extend(data)
        total_chunks += len(data)
        if name not in registered:
            registered.add(name)
            uploaded_documents.append({"name": name, "selected": True})

    if not all_data:
        return "No valid data could be extracted from the uploaded files.", known_names

    try:
        # Load the embedding model only once there is something to embed, and
        # inside the try so a model-loading failure is reported as a status.
        embed = get_embeddings()
        if os.path.exists("faiss_database"):
            # NOTE: allow_dangerous_deserialization unpickles the stored
            # docstore; only use it on an index directory this app wrote.
            database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
            database.add_documents(all_data)
        else:
            database = FAISS.from_documents(all_data, embed)
        database.save_local("faiss_database")
        save_documents(uploaded_documents)
    except Exception as e:
        logging.error(f"Error updating vector store: {str(e)}")
        return f"Error updating vector store: {str(e)}", known_names

    return (
        f"Vector store updated successfully. Processed {total_chunks} chunks.",
        [doc["name"] for doc in uploaded_documents],
    )
def delete_documents(selected_docs):
    """Remove the selected documents from the FAISS index and the registry.

    Args:
        selected_docs: Names of the documents to delete. They are compared
            against each stored chunk's ``metadata["source"]`` and against
            the ``"name"`` field of the registry entries.

    Returns:
        A ``(status_message, remaining_document_names)`` tuple. When nothing
        is deleted the current registry names are returned unchanged.
    """
    uploaded_documents = load_documents()
    current_names = [doc["name"] for doc in uploaded_documents]

    if not selected_docs:
        return "No documents selected for deletion.", current_names
    if not os.path.exists("faiss_database"):
        return "No documents to delete.", current_names

    selected = set(selected_docs)
    embed = get_embeddings()
    # NOTE: allow_dangerous_deserialization unpickles the stored docstore;
    # only use it on an index directory this app wrote.
    database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)

    stored = database.docstore._dict
    ids_to_remove = [
        doc_id
        for doc_id, doc in stored.items()
        if doc.metadata.get("source") in selected
    ]

    if len(ids_to_remove) == len(stored):
        # Nothing would be left: drop the whole index directory.
        shutil.rmtree("faiss_database")
    elif ids_to_remove:
        # Remove the vectors in place instead of re-embedding every kept chunk.
        database.delete(ids_to_remove)
        database.save_local("faiss_database")

    uploaded_documents = [doc for doc in uploaded_documents if doc["name"] not in selected]
    save_documents(uploaded_documents)
    remaining_docs = [doc["name"] for doc in uploaded_documents]
    return f"Deleted documents: {', '.join(selected_docs)}", remaining_docs
def get_response(query, temperature=0.2):
    """Answer a question from the chunks of the currently selected documents.

    Retrieves the five most similar chunks whose ``metadata["source"]`` is a
    selected document, then asks the chat model to answer from that context.

    Args:
        query: The user's question. ``None`` or blank input is rejected.
        temperature: Sampling temperature forwarded to the chat completion.

    Returns:
        The model's answer, or a human-readable message describing why no
        answer could be produced.
    """
    if not query or not query.strip():
        return "Please enter a question."

    uploaded_documents = load_documents()
    # .get(): tolerate registry entries without a "selected" flag.
    selected_docs = [doc["name"] for doc in uploaded_documents if doc.get("selected")]
    if not selected_docs:
        return "Please select at least one document to search through."

    if not os.path.exists("faiss_database"):
        return "No documents available. Please upload PDF documents first."

    embed = get_embeddings()
    # NOTE: allow_dangerous_deserialization unpickles the stored docstore;
    # only use it on an index directory this app wrote.
    database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)

    stored = database.docstore._dict
    selected = set(selected_docs)
    has_selected_chunks = any(
        isinstance(doc, Document) and doc.metadata.get("source") in selected
        for doc in stored.values()
    )
    if not has_selected_chunks:
        return "No relevant information found in the selected documents."

    # Search the stored index with a metadata filter instead of re-embedding
    # every selected chunk into a throwaway index on each query. fetch_k spans
    # the whole store so the filter still yields the true top 5.
    relevant_docs = database.similarity_search(
        query,
        k=5,
        filter={"source": selected_docs},
        fetch_k=len(stored),
    )
    context_str = "\n".join(doc.page_content for doc in relevant_docs)

    messages = [
        {"role": "system", "content": "You are a helpful assistant that provides accurate answers based on the given context."},
        {"role": "user", "content": f"Context:\n{context_str}\n\nQuestion: {query}\n\nProvide a comprehensive answer based only on the given context."}
    ]

    client = InferenceClient(DEFAULT_MODEL, token=HF_TOKEN)
    try:
        response = client.chat_completion(
            messages=messages,
            max_tokens=1000,
            temperature=temperature,
            top_p=0.9,
        )
        return response.choices[0].message.content
    except Exception as e:
        logging.error(f"Error generating response: {str(e)}")
        return f"Error generating response: {str(e)}"
def create_interface():
    """Build the Gradio Blocks UI for uploading PDFs and asking questions.

    Returns:
        The assembled ``gr.Blocks`` application (not yet launched).
    """

    def _choices_update(names):
        # A bare list returned to a CheckboxGroup sets its *value*, not its
        # choices, so the list would never show any documents. Nothing is
        # pre-checked, so "Delete" only affects what the user ticks.
        return gr.update(choices=names, value=[])

    def on_upload(uploaded_files):
        status, names = process_uploaded_files(uploaded_files)
        return status, _choices_update(names)

    def on_delete(selected_docs):
        status, names = delete_documents(selected_docs)
        return status, _choices_update(names)

    def on_page_load():
        return _choices_update([doc["name"] for doc in load_documents()])

    with gr.Blocks(title="PDF Question Answering System") as app:
        gr.Markdown("# PDF Question Answering System")
        with gr.Row():
            with gr.Column():
                files = gr.File(
                    label="Upload PDF Documents",
                    file_types=[".pdf"],
                    file_count="multiple"
                )
                upload_button = gr.Button("Upload and Process")
            with gr.Column():
                doc_status = gr.Textbox(label="Status", interactive=False)
                doc_list = gr.CheckboxGroup(
                    label="Available Documents",
                    choices=[],
                    interactive=True
                )
                delete_button = gr.Button("Delete Selected Documents")
        with gr.Row():
            with gr.Column():
                question = gr.Textbox(
                    label="Ask a question about the documents",
                    placeholder="Enter your question here..."
                )
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.2,
                    step=0.1,
                    label="Temperature (Higher values make the output more random)"
                )
                submit_button = gr.Button("Submit Question")
            with gr.Column():
                answer = gr.Textbox(
                    label="Answer",
                    interactive=False,
                    lines=10
                )

        # Event handlers
        upload_button.click(
            fn=on_upload,
            inputs=[files],
            outputs=[doc_status, doc_list]
        )
        delete_button.click(
            fn=on_delete,
            inputs=[doc_list],
            outputs=[doc_status, doc_list]
        )
        submit_button.click(
            fn=get_response,
            inputs=[question, temperature],
            outputs=[answer]
        )
        # Pressing Enter in the question box submits it as well
        question.submit(
            fn=get_response,
            inputs=[question, temperature],
            outputs=[answer]
        )
        # Show the persisted registry whenever the page is (re)loaded
        app.load(
            fn=on_page_load,
            inputs=None,
            outputs=[doc_list]
        )
    return app
if __name__ == "__main__":
    # Launch settings for running the app as a script.
    launch_options = {
        "server_name": "0.0.0.0",  # Makes the app accessible from other machines
        "server_port": 7860,  # Specify port
        "share": True,  # Creates a public URL
    }
    app = create_interface()
    app.launch(**launch_options)