File size: 8,453 Bytes
517a6a8
73a7410
82cd6c2
73a7410
 
 
 
 
 
4c57776
73a7410
 
 
 
 
 
 
 
 
 
4c57776
73a7410
 
4c57776
 
 
 
 
73a7410
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c57776
73a7410
 
0b862cc
4c57776
73a7410
 
 
 
 
4c57776
0b862cc
73a7410
 
0b862cc
 
73a7410
 
0b862cc
73a7410
 
0b862cc
 
73a7410
 
 
 
 
 
 
 
 
 
 
 
0b862cc
73a7410
4c57776
0b862cc
73a7410
 
517a6a8
73a7410
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c57776
 
73a7410
 
517a6a8
73a7410
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82cd6c2
73a7410
 
 
 
 
 
 
 
 
 
 
 
 
517a6a8
73a7410
 
 
 
82cd6c2
4c57776
82cd6c2
73a7410
 
 
 
 
 
 
 
 
 
 
 
 
4c57776
 
 
 
73a7410
 
 
 
 
4c57776
73a7410
 
82cd6c2
73a7410
4c57776
 
 
 
 
50a6cc0
73a7410
 
4c57776
73a7410
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c57776
 
 
 
 
 
 
73a7410
 
23a8177
 
73a7410
4c57776
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
import os
import json
import logging
import shutil
import gradio as gr
from typing import List
from tempfile import NamedTemporaryFile
from huggingface_hub import InferenceClient
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.docstore.document import Document

# Setup logging
logging.basicConfig(level=logging.INFO)

# Constants
DOCUMENTS_FILE = "uploaded_documents.json"  # JSON registry of uploaded PDFs and their selection state
DEFAULT_MODEL = "@cf/meta/llama-2-7b-chat"  # model id handed to InferenceClient in get_response()
HF_TOKEN = os.getenv("HF_API_TOKEN")  # Make sure to set this environment variable
EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"  # embedding model used by get_embeddings()

def get_embeddings():
    """Build a CPU-backed HuggingFace embedding model producing normalized vectors."""
    model_kwargs = {'device': 'cpu'}
    encode_kwargs = {'normalize_embeddings': True}
    return HuggingFaceEmbeddings(
        model_name=EMBED_MODEL,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )

def load_documents():
    """Read the uploaded-documents registry; return an empty list when absent."""
    if not os.path.exists(DOCUMENTS_FILE):
        return []
    with open(DOCUMENTS_FILE, "r") as registry:
        return json.load(registry)

def save_documents(documents):
    """Persist the uploaded-documents registry as JSON, overwriting any previous file."""
    with open(DOCUMENTS_FILE, "w") as registry:
        json.dump(documents, registry)

def load_document(file: NamedTemporaryFile) -> List[Document]:
    """Parse the PDF at ``file.name`` into per-page Document chunks via PyPDF."""
    return PyPDFLoader(file.name).load_and_split()

def process_uploaded_files(files):
    """Load uploaded PDFs, fold their pages into the FAISS index, and record them.

    Args:
        files: A single uploaded-file object or a list of them; each must expose
            a ``.name`` path attribute (as gradio's File component provides).

    Returns:
        Tuple of (status message, list of all registered document names).
    """
    if not files:
        return "Please upload at least one file.", []

    # gradio hands over a bare object for a single upload; normalize to a list.
    files_list = [files] if not isinstance(files, list) else files
    embed = get_embeddings()
    uploaded_documents = load_documents()
    total_chunks = 0

    all_data = []
    for file in files_list:
        try:
            data = load_document(file)
            if not data:
                continue

            all_data.extend(data)
            total_chunks += len(data)

            # Register each file once; a re-upload keeps the existing entry.
            if not any(doc["name"] == file.name for doc in uploaded_documents):
                uploaded_documents.append({"name": file.name, "selected": True})

        except Exception:
            # logging.exception captures the traceback; the original
            # logging.error(f"...") discarded it.
            logging.exception("Error processing file %s", file.name)

    if not all_data:
        return "No valid data could be extracted from the uploaded files.", []

    try:
        if os.path.exists("faiss_database"):
            # Extend the existing index rather than rebuilding from scratch.
            database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
            database.add_documents(all_data)
        else:
            database = FAISS.from_documents(all_data, embed)
        database.save_local("faiss_database")

        save_documents(uploaded_documents)
        return f"Vector store updated successfully. Processed {total_chunks} chunks.", [doc["name"] for doc in uploaded_documents]

    except Exception as e:
        # Log with traceback before returning the user-facing error string.
        logging.exception("Error updating vector store")
        return f"Error updating vector store: {str(e)}", []

def delete_documents(selected_docs):
    """Remove the chosen documents from the FAISS index and the registry.

    Returns a (status message, remaining document names) tuple.
    """
    if not selected_docs:
        return "No documents selected for deletion.", []

    registry = load_documents()
    embed = get_embeddings()

    if not os.path.exists("faiss_database"):
        return "No documents to delete.", []

    database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)

    # Keep only the chunks whose source file was NOT selected for deletion.
    surviving = [
        doc for doc in database.docstore._dict.values()
        if doc.metadata.get("source") not in selected_docs
    ]

    if surviving:
        FAISS.from_documents(surviving, embed).save_local("faiss_database")
    else:
        # Nothing left to index: drop the whole index directory.
        shutil.rmtree("faiss_database")

    registry = [doc for doc in registry if doc["name"] not in selected_docs]
    save_documents(registry)

    remaining = [doc["name"] for doc in registry]
    return f"Deleted documents: {', '.join(selected_docs)}", remaining

def get_response(query, temperature=0.2):
    """Answer ``query`` using context retrieved from the user-selected documents.

    Returns either the model's answer or a user-facing status/error string.
    """
    if not query.strip():
        return "Please enter a question."

    selected = [doc["name"] for doc in load_documents() if doc["selected"]]
    if not selected:
        return "Please select at least one document to search through."

    embed = get_embeddings()
    if not os.path.exists("faiss_database"):
        return "No documents available. Please upload PDF documents first."

    database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)

    # Restrict retrieval to chunks originating from the selected files.
    filtered_docs = [
        doc for doc in database.docstore._dict.values()
        if isinstance(doc, Document) and doc.metadata.get("source") in selected
    ]
    if not filtered_docs:
        return "No relevant information found in the selected documents."

    # Build a temporary index over just the selected chunks and fetch top-5.
    filtered_db = FAISS.from_documents(filtered_docs, embed)
    retriever = filtered_db.as_retriever(search_kwargs={"k": 5})
    relevant_docs = retriever.get_relevant_documents(query)

    context_str = "\n".join(doc.page_content for doc in relevant_docs)

    messages = [
        {"role": "system", "content": "You are a helpful assistant that provides accurate answers based on the given context."},
        {"role": "user", "content": f"Context:\n{context_str}\n\nQuestion: {query}\n\nProvide a comprehensive answer based only on the given context."},
    ]

    client = InferenceClient(DEFAULT_MODEL, token=HF_TOKEN)
    try:
        response = client.chat_completion(
            messages=messages,
            max_tokens=1000,
            temperature=temperature,
            top_p=0.9,
        )
        return response.choices[0].message.content
    except Exception as e:
        return f"Error generating response: {str(e)}"

def create_interface():
    """Build and return the Gradio Blocks UI for upload, deletion, and Q&A.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks(title="PDF Question Answering System") as app:
        gr.Markdown("# PDF Question Answering System")

        with gr.Row():
            with gr.Column():
                files = gr.File(
                    label="Upload PDF Documents",
                    file_types=[".pdf"],
                    file_count="multiple"
                )
                upload_button = gr.Button("Upload and Process")

            with gr.Column():
                doc_status = gr.Textbox(label="Status", interactive=False)
                # Fix: the canonical component is CheckboxGroup; the lowercase
                # alias `Checkboxgroup` was deprecated and removed in Gradio 4.x.
                doc_list = gr.CheckboxGroup(
                    label="Available Documents",
                    choices=[],
                    interactive=True
                )
                delete_button = gr.Button("Delete Selected Documents")

        with gr.Row():
            with gr.Column():
                question = gr.Textbox(
                    label="Ask a question about the documents",
                    placeholder="Enter your question here..."
                )
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.2,
                    step=0.1,
                    label="Temperature (Higher values make the output more random)"
                )
                submit_button = gr.Button("Submit Question")

            with gr.Column():
                answer = gr.Textbox(
                    label="Answer",
                    interactive=False,
                    lines=10
                )

        # Event handlers
        upload_button.click(
            fn=process_uploaded_files,
            inputs=[files],
            outputs=[doc_status, doc_list]
        )

        delete_button.click(
            fn=delete_documents,
            inputs=[doc_list],
            outputs=[doc_status, doc_list]
        )

        submit_button.click(
            fn=get_response,
            inputs=[question, temperature],
            outputs=[answer]
        )

        # Add keyboard shortcut for submitting questions
        question.submit(
            fn=get_response,
            inputs=[question, temperature],
            outputs=[answer]
        )

    return app

# Script entry point: build the UI and serve it.
if __name__ == "__main__":
    app = create_interface()
    app.launch(
        server_name="0.0.0.0",  # Makes the app accessible from other machines
        server_port=7860,       # Specify port
        share=True             # Creates a public URL
    )