Ultronprime committed on
Commit 1b5d6e1 · verified · 1 Parent(s): 6c3e7e8

Update app.py

Files changed (1)
  1. app.py +288 -208
app.py CHANGED
@@ -1,234 +1,314 @@
  import os
- import time
- import json
- from pathlib import Path
- from typing import List
  import spaces

- import gradio as gr
  import torch
- from huggingface_hub import HfApi, hf_hub_download, create_repo, upload_file, CommitOperationAdd, login
- from transformers import pipeline, AutoTokenizer
- from datasets import Dataset
- from sklearn.decomposition import PCA
  import numpy as np
- import plotly.graph_objects as go
- from sklearn.manifold import TSNE
- import traceback

- # --- User Configuration ---
- HF_USERNAME = os.getenv("HF_USERNAME")
- DATASET_ID = f"{HF_USERNAME}/rag-embeddings"  # Dataset repo name
- MODEL_ID = f"{HF_USERNAME}/my-test-model"  # Model repo name
- API_TOKEN = os.getenv("HF_TOKEN")  # Read from environment variable
-
- if not HF_USERNAME:
-     raise ValueError("Please set the HF_USERNAME environment variable with your Hugging Face username.")
- if not API_TOKEN:
-     raise ValueError("Please set the HF_TOKEN environment variable with your Hugging Face API token.")
-
- # --- Helper Functions ---
- def get_text_from_files(file_paths):
-     all_text = []
-     for filepath in file_paths:
-         try:
-             with open(filepath.name, "r", encoding="utf-8") as file:
-                 all_text.append(file.read())
-         except Exception as e:
-             print(f"Error reading file {filepath.name}: {e}. Skipping file.")
-     return all_text
-
- def get_embeddings(texts, model_id="sentence-transformers/all-mpnet-base-v2"):
      try:
-         model = pipeline('feature-extraction', model=model_id, device="cuda")
-         embeddings = model(texts)
      except Exception as e:
-         print(f"Error during embeddings: {e}. Please check your GPU configuration and model.")
-         return None
-     return embeddings

- def get_llm_response(query, context, model_id="HuggingFaceH4/zephyr-7b-beta"):
      try:
-         # The text-generation pipeline handles tokenization and decoding itself,
-         # so the prompt string is passed directly.
-         model = pipeline("text-generation", model=model_id, device="cuda")
-         prompt = f"""
- Answer the following question according to the provided context.
-
- Question: {query}
- Context: {context}
- Answer:
- """
-         output = model(
-             prompt,
-             max_new_tokens=250,
-             do_sample=True,
-             top_p=0.9,
-             temperature=0.2,
          )
-         return output[0]["generated_text"]

      except Exception as e:
-         print(f"Error during text generation: {e}. Please check your settings.")
-         return f"There was an error. Please check settings and if the models are available: {str(e)}"

- def format_output(output):
-     return output.strip()

- def fetch_from_store(query_embeddings, dataset_id):
-     try:
-         file_path = hf_hub_download(repo_id=dataset_id, filename="embeddings.json", repo_type="dataset", token=API_TOKEN)
-     except Exception as e:
-         return f"Couldn't find the embeddings on the Hub! Did you save them before? {str(e)}"
-
-     with open(file_path, 'r') as f:
-         dataset = json.load(f)
-
-     all_similarities = []
-     for text_embedding in dataset["embeddings"]:
-         try:
-             sim = torch.nn.functional.cosine_similarity(torch.tensor(query_embeddings), torch.tensor(text_embedding), dim=0)
-             all_similarities.append(sim.item())
-         except Exception as e:
-             print(f"Error calculating similarity: {e}. Skipping text entry.")
-
-     if not all_similarities:
-         return None
-     most_similar_index = all_similarities.index(max(all_similarities))
-     return dataset["texts"][most_similar_index]

- @spaces.GPU
- def rag_chain(question, files):
-     # If files were uploaded, embed them and push the embeddings to the Hub first.
-     if files is not None:
-         texts = get_text_from_files(files)
-         if texts:
-             embeddings = get_embeddings(texts)
-             if embeddings:
-                 upload_embeddings_to_hub(texts, embeddings, dataset_id=DATASET_ID)
-             else:
-                 return "There was an error generating embeddings for the uploaded files."
-
-     # Generate an embedding for the user input.
-     input_embedding = get_embeddings(texts=[question])
-     # Get the most relevant text:
-     if input_embedding:
-         context = fetch_from_store(input_embedding[0], dataset_id=DATASET_ID)
-         if context:
-             # Get the final output
-             output = get_llm_response(question, context)
-             return format_output(output)
-         else:
-             return "There was an error. Couldn't fetch a matching context. Are there embeddings in the Hub?"
-     else:
-         return "There was an error generating the embeddings. Try again."
-
- # --- Upload embeddings to the Hub (only run one time) ---
- def upload_embeddings_to_hub(texts, embeddings, dataset_id):
-     api = HfApi(token=API_TOKEN)
-     try:
-         create_repo(repo_id=dataset_id, repo_type="dataset", private=False, token=API_TOKEN)
-         print(f"Dataset repo {dataset_id} created successfully!")
      except Exception as e:
-         print(f"Dataset repo {dataset_id} already exists: {e}")
-
-     dataset = {
-         "texts": texts,
-         "embeddings": embeddings,
-     }
-
-     with open("embeddings.json", "w") as outfile:
-         json.dump(dataset, outfile)
-
-     upload_file(
-         path_or_fileobj="embeddings.json",
-         path_in_repo="embeddings.json",
-         repo_id=dataset_id,
-         repo_type="dataset",
-         token=API_TOKEN,
-     )
-     print("Finished embeddings upload")
-
- def reduce_dimension_pca(embeddings, n_components=2):
-     pca = PCA(n_components=n_components)
-     reduced_embeddings = pca.fit_transform(np.array(embeddings))
-     return reduced_embeddings
-
- def reduce_dimension_tsne(embeddings, n_components=2, perplexity=30, n_iter=300):
-     tsne = TSNE(n_components=n_components, perplexity=perplexity, n_iter=n_iter, random_state=42)
-     reduced_embeddings = tsne.fit_transform(np.array(embeddings))
-     return reduced_embeddings
-
- def get_plotly_plot(texts, embeddings, method='PCA'):
-     if method == 'PCA':
-         reduced_embeddings = reduce_dimension_pca(embeddings)
-     elif method == 'TSNE':
-         reduced_embeddings = reduce_dimension_tsne(embeddings)
-
-     fig = go.Figure(data=[go.Scatter(
-         x=reduced_embeddings[:, 0],
-         y=reduced_embeddings[:, 1],
-         mode='markers+text',
-         text=texts,
-         textposition="bottom center",
-         marker=dict(
-             size=10,
-             color=list(range(len(texts))),
-             colorscale='Viridis',
-             showscale=True,
-         ),
-     )])
-
-     fig.update_layout(title=f'Document Embeddings Visualization using {method}')
-     return fig
  @spaces.GPU
- def visualize_data(files, dataset_id=DATASET_ID):
-     if not files:
-         print("No files uploaded to visualize")
-         return None, None
-
      try:
-         file_path = hf_hub_download(repo_id=dataset_id, filename="embeddings.json", repo_type="dataset", token=API_TOKEN)
      except Exception as e:
-         print(f"Couldn't find the embeddings on the Hub! Did you save them before? {str(e)}")
-         return None, None
-
-     with open(file_path, 'r') as f:
-         dataset = json.load(f)
-
-     texts = dataset["texts"]
-     embeddings = dataset["embeddings"]
-
-     fig_pca = get_plotly_plot(texts, embeddings, method='PCA')
-     fig_tsne = get_plotly_plot(texts, embeddings, method='TSNE')
-
-     return fig_pca, fig_tsne
-
- # --- Main Gradio Interface ---
- with gr.Blocks() as demo:
-     with gr.Tab("Chat"):
-         chatbot_input = gr.Textbox(placeholder="Ask me something...")
-         chatbot_output = gr.Textbox()
-         with gr.Row():
-             chatbot_files = gr.File(file_types=['.txt'], file_count="multiple", label="Upload text files")
-             chatbot_button = gr.Button("Submit")
-         chatbot_button.click(rag_chain, inputs=[chatbot_input, chatbot_files], outputs=chatbot_output)
-     with gr.Tab("Visualization"):
-         visualization_files = gr.File(file_types=['.txt'], file_count="multiple", label="Upload text files")
-         with gr.Row():
-             submit_button = gr.Button("Visualize data")
-         with gr.Row():
-             plotly_output_pca = gr.Plot()
-         with gr.Row():
-             plotly_output_tsne = gr.Plot()
-
-         submit_button.click(visualize_data, inputs=visualization_files, outputs=[plotly_output_pca, plotly_output_tsne])
-
- demo.launch(server_name="0.0.0.0")
-
- # --- Upload embeddings to the Hub (one-time execution) ---
- # local_data_path = "data"  # Please set this path to where your data is!
- # texts = get_text_from_files(os.listdir(local_data_path))
- # embeddings = get_embeddings(texts)
- # upload_embeddings_to_hub(texts, embeddings, dataset_id=DATASET_ID)

  import os
+ import gradio as gr
+ import logging
+ import traceback

  import spaces
+ from typing import Optional, List
+ from dataclasses import dataclass
+ from datetime import datetime
+ from pathlib import Path
+ import gc

  import torch
+ from torch.cuda.amp import autocast
+ from transformers import AutoModel, AutoTokenizer
+ from sentence_transformers import SentenceTransformer
+ from charset_normalizer import from_bytes
  import numpy as np
+ import requests
+
+ # Custom Exception Class
+ class GPUQuotaExceededError(Exception):
+     pass
+
+ # Constants
+ EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
+ CHUNK_SIZE = 500
+ BATCH_SIZE = 32
+ CACHE_DIR = os.getenv("CACHE_DIR", "/tmp/cache")
+ PERSISTENT_PATH = os.getenv("PERSISTENT_PATH", "/data")
+
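+ # Note: CHUNK_SIZE is measured in characters, not tokens, and BATCH_SIZE
+ # bounds how many chunks are sent to the encoder per GPU call.
+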
+ # Create directories
+ os.makedirs(CACHE_DIR, exist_ok=True)
+ os.makedirs(PERSISTENT_PATH, exist_ok=True)
+
+ # Logging Setup
+ LOG_DIR = os.getenv("LOG_DIR", "/data/logs")
+ os.makedirs(LOG_DIR, exist_ok=True)
+ LOG_FILE = Path(LOG_DIR) / "app.log"
+
+ logging.basicConfig(
+     filename=str(LOG_FILE),
+     level=logging.INFO,
+     format="%(asctime)s - %(levelname)s - %(message)s",
+ )
+ logger = logging.getLogger(__name__)
+
+ # Model initialization
+ model = None
+
+ def initialize_model():
+     global model
+     try:
+         if model is None:
+             model = SentenceTransformer(EMBEDDING_MODEL_NAME, cache_folder=CACHE_DIR)
+             logger.info(f"Initialized model: {EMBEDDING_MODEL_NAME}")
+         return True
+     except requests.exceptions.ConnectionError as e:
+         logger.error(f"Connection error during model download: {str(e)}\n{traceback.format_exc()}")
+         return False
+     except Exception as e:
+         logger.error(f"Model initialization failed: {str(e)}\n{traceback.format_exc()}")
+         return False

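+ # All GPU work below funnels through handle_gpu_operation, which wraps a
+ # zero-argument callable with timing, autocast, and OOM/quota error handling.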
+ @spaces.GPU
+ def handle_gpu_operation(func):
      try:
+         start_time = datetime.now()
+         with autocast(enabled=torch.cuda.is_available()):
+             result = func()
+         end_time = datetime.now()
+         duration = (end_time - start_time).total_seconds()
+         logger.info(f"GPU operation completed in {duration:.2f}s")
+         return result
+     except RuntimeError as e:
+         if "CUDA out of memory" in str(e):
+             torch.cuda.empty_cache()
+             logger.error(f"GPU memory error: {str(e)}")
+             raise GPUQuotaExceededError("GPU memory limit exceeded. Please try with a smaller batch.")
+         else:
+             logger.error(f"GPU runtime error: {str(e)}")
+             raise
      except Exception as e:
+         if "quota exceeded" in str(e).lower():
+             logger.error(f"GPU quota exceeded: {str(e)}")
+             raise GPUQuotaExceededError("GPU quota exceeded. Please wait a few minutes before trying again.")
+         else:
+             logger.error(f"Unexpected GPU error: {str(e)}")
+             raise
+
+ def get_model():
+     global model
+     if model is None:
+         if torch.cuda.is_available():
+             initialize_model()
+         else:
+             logger.warning("Attempted to initialize model outside GPU context, deferring.")
+             return None
+     return model
+
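+ # On ZeroGPU Spaces, CUDA is only visible inside @spaces.GPU calls, so
+ # get_model() defers loading rather than initializing on CPU.
+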
+ @spaces.GPU
+ def process_files(files):
+     if not files:
+         return "Please upload one or more .txt files.", "", ""

      try:
+         if not initialize_model():
+             return "Failed to initialize the model. Please try again.", "", ""
+
+         valid_files = [f for f in files if f.name.lower().endswith('.txt')]
+         if not valid_files:
+             return "No .txt files found in upload. Please ensure you upload .txt files.", "", ""
+
+         all_chunks = []
+         processed_files = 0
+
+         for file in valid_files:
+             try:
+                 with open(file.name, 'rb') as f:
+                     content = f.read()
+                 best_match = from_bytes(content).best()
+                 detected_encoding = best_match.encoding if best_match else "utf-8"
+                 decoded_content = content.decode(detected_encoding, errors='ignore')
+
+                 chunks = [decoded_content[i:i+CHUNK_SIZE] for i in range(0, len(decoded_content), CHUNK_SIZE)]
+                 all_chunks.extend(chunks)
+                 processed_files += 1
+                 logger.info(f"Processed file: {file.name}")
+             except Exception as e:
+                 logger.error(f"Error processing file {file.name}: {str(e)}")
+
+         if not all_chunks:
+             return "No valid content found in the uploaded .txt files.", "", ""
+
+         # Generate embeddings in batches
+         all_embeddings = []
+         for i in range(0, len(all_chunks), BATCH_SIZE):
+             batch = all_chunks[i:i+BATCH_SIZE]
+             embeddings = handle_gpu_operation(lambda: get_model().encode(batch))
+             all_embeddings.extend(embeddings)
+
+         # Save results
+         np.save(f"{PERSISTENT_PATH}/embeddings.npy", np.array(all_embeddings))
+
+         with open(f"{PERSISTENT_PATH}/chunks.txt", "w", encoding="utf-8") as f:
+             for chunk in all_chunks:
+                 f.write(chunk + "\n===CHUNK_SEPARATOR===\n")
+
+         return (
+             f"Successfully processed {processed_files} files. Generated {len(all_embeddings)} embeddings from {len(all_chunks)} chunks.",
+             "",
+             ""
          )

      except Exception as e:
+         logger.error(f"Processing failed: {str(e)}")
+         return f"Error processing files: {str(e)}", "", ""
+
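+ # Storage layout under PERSISTENT_PATH: embeddings.npy holds one vector per
+ # chunk, and chunks.txt stores the chunk texts joined by ===CHUNK_SEPARATOR===.
+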
+ @spaces.GPU
+ def semantic_search(query, top_k=5):
+     global model
+     if model is None:  # Check if model is initialized
+         if not initialize_model():  # Initialize only if needed and within GPU context
+             return "Model initialization failed. Please try again."
+
+     try:
+         # Load saved embeddings
+         stored_embeddings = np.load(f"{PERSISTENT_PATH}/embeddings.npy")
+
+         # Load stored chunks
+         with open(f"{PERSISTENT_PATH}/chunks.txt", "r", encoding="utf-8") as f:
+             chunks = f.read().split("\n===CHUNK_SEPARATOR===\n")
+         chunks = [c for c in chunks if c.strip()]  # Remove empty chunks
+
+         # Get query embedding
+         query_embedding = handle_gpu_operation(lambda: get_model().encode([query]))[0]  # Use get_model() to get the model
+
+         # Calculate similarities (cosine: dot product over the L2 norms)
+         similarities = np.dot(stored_embeddings, query_embedding) / (
+             np.linalg.norm(stored_embeddings, axis=1) * np.linalg.norm(query_embedding)
+         )
+
+         # Get top results
+         top_indices = np.argsort(similarities)[-top_k:][::-1]
+
+         # Format results
+         results = []
+         for idx in top_indices:
+             results.append(f"""
+ Similarity: {similarities[idx]:.3f}
+ Content: {chunks[idx]}
+ -------------------
+ """)
+
+         return "\n".join(results)

      except Exception as e:
+         logger.error(f"Search error: {str(e)}")
+         return f"Search error occurred: {str(e)}"
+
+ def search_and_format(query, num_results):
+     if not query.strip():
+         return "Please enter a search query"
+     return semantic_search(query, top_k=int(num_results))  # slider values may arrive as floats
+
+ def download_results(text):
+     if not text:
+         return None
+     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+     filename = f"search_results_{timestamp}.txt"
+     with open(filename, "w", encoding="utf-8") as f:
+         f.write(text)
+     return filename

  @spaces.GPU
+ def safe_generate_embedding(text):
+     global model
+     if model is None:  # Check if model is initialized
+         initialize_model()  # Initialize only if needed and within GPU context

      try:
+         embedding = handle_gpu_operation(
+             lambda: get_model().encode([text])[0].tolist()  # Use get_model() to get the model
+         )
+         return embedding, "", False
+     except GPUQuotaExceededError as e:
+         error_msg = str(e)
+         logger.error(error_msg)
+         return "", error_msg, True
      except Exception as e:
+         error_msg = f"Error generating embedding: {str(e)}"
+         logger.error(error_msg)
+         return "", error_msg, True
+
+ def download_embeddings():
+     embeddings_path = f"{PERSISTENT_PATH}/embeddings.npy"
+     if not os.path.exists(embeddings_path):
+         return None
+     return embeddings_path
+
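+ # Note: process_files and safe_generate_embedding return three values
+ # (result, error message, error flag) to match the three output slots wired below.
+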
+ def create_gradio_interface():
+     with gr.Blocks() as demo:
+         gr.Markdown("## Text Chunk Embeddings Generator")
+
+         error_box = gr.Textbox(visible=False, label="Status/Error Messages")
+
+         with gr.Row():
+             file_input = gr.File(
+                 label="Upload Text Files",
+                 file_count="multiple",
+                 file_types=[".txt"]
+             )
+
+         process_button = gr.Button("Generate Embeddings")
+         output_text = gr.Textbox(label="Status")
+
+         with gr.Tab("Search"):
+             query_input = gr.Textbox(
+                 label="Enter your search query",
+                 placeholder="Enter text to search through your documents..."
+             )
+             top_k = gr.Slider(
+                 minimum=1,
+                 maximum=20,
+                 value=5,
+                 step=1,
+                 label="Number of results to return"
+             )
+             search_button = gr.Button("🔍 Search")
+             results_output = gr.Textbox(
+                 label="Search Results",
+                 lines=10,
+                 show_copy_button=True
+             )
+             download_button = gr.Button("⬇️ Download Results")
+
+             search_button.click(
+                 fn=search_and_format,
+                 inputs=[query_input, top_k],
+                 outputs=results_output
+             )
+
+             download_button.click(
+                 fn=download_results,
+                 inputs=[results_output],
+                 outputs=[gr.File(label="Download Search Results")]
+             )
+
+         with gr.Tab("Inspect Embeddings"):
+             embed_input = gr.Textbox(label="Enter Text for Embedding")
+             embed_button = gr.Button("Generate Embedding")
+             embed_output = gr.Textbox(label="Embedding Vector", lines=5)
+
+             embed_button.click(
+                 safe_generate_embedding,
+                 inputs=[embed_input],
+                 outputs=[embed_output, error_box, error_box]
+             )
+
+             download_embeddings_button = gr.Button("⬇️ Download Embeddings")
+             download_embeddings_button.click(
+                 fn=download_embeddings,
+                 outputs=[gr.File(label="Download Embeddings")]
+             )
+
+         process_button.click(
+             process_files,
+             inputs=[file_input],
+             outputs=[output_text, error_box, error_box]
+         )
+
+     return demo
+
+ if __name__ == "__main__":
+     demo = create_gradio_interface()
+     demo.launch(server_name="0.0.0.0")
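
A quick way to sanity-check the new retrieval path outside Gradio is to rerun the same cosine-similarity ranking against the files the app writes. This is a minimal sketch, not part of the commit: it assumes embeddings.npy and chunks.txt already exist under PERSISTENT_PATH (i.e. process_files has run), and the query string is only an example.

    import numpy as np
    from sentence_transformers import SentenceTransformer

    PERSISTENT_PATH = "/data"  # same default the app uses
    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

    stored = np.load(f"{PERSISTENT_PATH}/embeddings.npy")  # one row per chunk
    with open(f"{PERSISTENT_PATH}/chunks.txt", encoding="utf-8") as f:
        chunks = [c for c in f.read().split("\n===CHUNK_SEPARATOR===\n") if c.strip()]

    query = model.encode(["example query"])[0]
    sims = stored @ query / (np.linalg.norm(stored, axis=1) * np.linalg.norm(query))
    for idx in np.argsort(sims)[-5:][::-1]:  # top 5, best match first
        print(f"{sims[idx]:.3f}  {chunks[idx][:80]!r}")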