Ultronprime commited on
Commit
160e875
·
verified ·
1 Parent(s): d4cee85

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -37
app.py CHANGED
@@ -8,8 +8,14 @@ from dataclasses import dataclass
8
  from datetime import datetime
9
  from pathlib import Path
10
  import gc
 
 
 
 
 
 
 
11
  import zipfile
12
- import shutil
13
  import tempfile
14
 
15
  # Custom Exception Class
@@ -20,23 +26,27 @@ class GPUQuotaExceededError(Exception):
20
  EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
21
  CHUNK_SIZE = 500
22
  BATCH_SIZE = 32
23
- CACHE_DIR = os.getenv("CACHE_DIR", "/tmp/cache")
24
- PERSISTENT_PATH = os.getenv("PERSISTENT_PATH", "/workspace")
25
 
26
- # Directories setup
27
- os.makedirs(PERSISTENT_PATH, exist_ok=True)
 
 
28
  TEMP_DIR = os.path.join(PERSISTENT_PATH, "temp")
29
- os.makedirs(TEMP_DIR, exist_ok=True)
 
30
  OUTPUTS_DIR = os.path.join(PERSISTENT_PATH, "outputs")
31
- os.makedirs(OUTPUTS_DIR, exist_ok=True)
32
 
33
- # Logging Setup
34
  LOG_DIR = os.getenv("LOG_DIR", os.path.join(PERSISTENT_PATH, "logs"))
35
- os.makedirs(LOG_DIR, exist_ok=True)
36
- LOG_FILE = os.path.join(LOG_DIR, "app.log")
 
 
 
37
 
 
38
  logging.basicConfig(
39
- filename=LOG_FILE,
40
  level=logging.INFO,
41
  format="%(asctime)s - %(levelname)s - %(message)s",
42
  )
@@ -49,7 +59,7 @@ def initialize_model():
49
  global model
50
  try:
51
  if model is None:
52
- model = SentenceTransformer(EMBEDDING_MODEL_NAME, cache_folder=CACHE_DIR)
53
  logger.info(f"Initialized model: {EMBEDDING_MODEL_NAME}")
54
  return True
55
  except requests.exceptions.ConnectionError as e:
@@ -132,7 +142,7 @@ def process_files(files):
132
  all_embeddings = []
133
  for i in range(0, len(all_chunks), BATCH_SIZE):
134
  batch = all_chunks[i:i+BATCH_SIZE]
135
- embeddings = handle_gpu_operation(lambda: get_model().encode(batch))
136
  all_embeddings.extend(embeddings)
137
 
138
  # Save results to OUTPUTS_DIR
@@ -157,8 +167,8 @@ def process_files(files):
157
  @spaces.GPU
158
  def semantic_search(query, top_k=5):
159
  global model
160
- if model is None: # Check if model is initialized
161
- if not initialize_model(): # Initialize only if needed and within GPU context
162
  return "Model initialization failed. Please try again."
163
 
164
  try:
@@ -168,10 +178,13 @@ def semantic_search(query, top_k=5):
168
  # Load stored chunks from OUTPUTS_DIR
169
  with open(os.path.join(OUTPUTS_DIR, "chunks.txt"), "r", encoding="utf-8") as f:
170
  chunks = f.read().split("\n===CHUNK_SEPARATOR===\n")
171
- chunks = [c for c in chunks if c.strip()] # Remove empty chunks
172
 
173
  # Get query embedding
174
- query_embedding = handle_gpu_operation(lambda: get_model().encode([query]))[0]
 
 
 
175
 
176
  # Calculate similarities
177
  similarities = np.dot(stored_embeddings, query_embedding) / (
@@ -203,7 +216,9 @@ def search_and_format(query, num_results):
203
 
204
  def browse_outputs():
205
  try:
206
- os.startfile(OUTPUTS_DIR) # For Windows, on Linux use subprocess.run(['xdg-open', OUTPUTS_DIR])
 
 
207
  except Exception as e:
208
  logger.error(f"Error opening file browser: {str(e)}")
209
  return "Error opening file browser"
@@ -215,16 +230,13 @@ def download_results_from_disk():
215
  os.path.join(OUTPUTS_DIR, "chunks.txt")
216
  ]
217
 
218
- # Create a temporary zip file
219
- temp_dir = tempfile.gettempdir()
220
- zip_path = os.path.join(temp_dir, "results.zip")
221
-
222
- with zipfile.ZipFile(zip_path, 'w') as zipf:
223
- for file in output_files:
224
- if os.path.exists(file):
225
- zipf.write(file, os.path.basename(file))
226
-
227
- return zip_path
228
  except Exception as e:
229
  logger.error(f"Error creating download: {str(e)}")
230
  return "Error creating download file"
@@ -271,13 +283,13 @@ def create_gradio_interface():
271
  )
272
 
273
  # Download Results Button
274
- download_results_button = gr.Button("⬇️ Download Search Results")
275
  download_results_button.click(
276
  fn=download_results_from_disk,
277
  outputs=[gr.File(label="Download Results")]
278
  )
279
 
280
- with gr.Tab("_FILES_"):
281
  # Browse Outputs Button
282
  browse_button = gr.Button("📁 Browse Outputs", variant="primary")
283
  browse_button.click(
@@ -285,13 +297,6 @@ def create_gradio_interface():
285
  outputs=None
286
  )
287
 
288
- # Download All Results Button
289
- download_all_button = gr.Button("⬇️ Download All Results", variant="primary")
290
- download_all_button.click(
291
- fn=download_results_from_disk,
292
- outputs=[gr.File(label="Download All Results")]
293
- )
294
-
295
  process_button.click(
296
  process_files,
297
  inputs=[file_input],
 
8
  from datetime import datetime
9
  from pathlib import Path
10
  import gc
11
+ import torch
12
+ from torch.cuda.amp import autocast
13
+ from transformers import AutoModel, AutoTokenizer
14
+ from sentence_transformers import SentenceTransformer
15
+ import numpy as np
16
+ import requests
17
+ from charset_normalizer import from_bytes
18
  import zipfile
 
19
  import tempfile
20
 
21
  # Custom Exception Class
 
26
  EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
27
  CHUNK_SIZE = 500
28
  BATCH_SIZE = 32
 
 
29
 
30
+ # Persistent storage directories
31
+ PERSISTENT_PATH = os.getenv("PERSISTENT_PATH", "/data")
32
+ os.makedirs(PERSISTENT_PATH, exist_ok=True, mode=0o777)
33
+
34
  TEMP_DIR = os.path.join(PERSISTENT_PATH, "temp")
35
+ os.makedirs(TEMP_DIR, exist_ok=True, mode=0o777)
36
+
37
  OUTPUTS_DIR = os.path.join(PERSISTENT_PATH, "outputs")
38
+ os.makedirs(OUTPUTS_DIR, exist_ok=True, mode=0o777)
39
 
 
40
  LOG_DIR = os.getenv("LOG_DIR", os.path.join(PERSISTENT_PATH, "logs"))
41
+ os.makedirs(LOG_DIR, exist_ok=True, mode=0o777)
42
+
43
+ # Set Hugging Face cache directory to PERSISTENT_PATH
44
+ os.environ["HF_HOME"] = os.path.join(PERSISTENT_PATH, ".huggingface")
45
+ os.makedirs(os.environ["HF_HOME"], exist_ok=True, mode=0o777)
46
 
47
+ # Logging Setup
48
  logging.basicConfig(
49
+ filename=os.path.join(LOG_DIR, "app.log"),
50
  level=logging.INFO,
51
  format="%(asctime)s - %(levelname)s - %(message)s",
52
  )
 
59
  global model
60
  try:
61
  if model is None:
62
+ model = SentenceTransformer(EMBEDDING_MODEL_NAME, cache_folder=os.path.join(PERSISTENT_PATH, "models"))
63
  logger.info(f"Initialized model: {EMBEDDING_MODEL_NAME}")
64
  return True
65
  except requests.exceptions.ConnectionError as e:
 
142
  all_embeddings = []
143
  for i in range(0, len(all_chunks), BATCH_SIZE):
144
  batch = all_chunks[i:i+BATCH_SIZE]
145
+ embeddings = handle_gpu_operation(lambda: get_model().encode(batch)) if model else []
146
  all_embeddings.extend(embeddings)
147
 
148
  # Save results to OUTPUTS_DIR
 
167
  @spaces.GPU
168
  def semantic_search(query, top_k=5):
169
  global model
170
+ if model is None:
171
+ if not initialize_model():
172
  return "Model initialization failed. Please try again."
173
 
174
  try:
 
178
  # Load stored chunks from OUTPUTS_DIR
179
  with open(os.path.join(OUTPUTS_DIR, "chunks.txt"), "r", encoding="utf-8") as f:
180
  chunks = f.read().split("\n===CHUNK_SEPARATOR===\n")
181
+ chunks = [c for c in chunks if c.strip()]
182
 
183
  # Get query embedding
184
+ if model:
185
+ query_embedding = handle_gpu_operation(lambda: get_model().encode([query]))[0]
186
+ else:
187
+ return "Model not initialized. Please process files first."
188
 
189
  # Calculate similarities
190
  similarities = np.dot(stored_embeddings, query_embedding) / (
 
216
 
217
  def browse_outputs():
218
  try:
219
+ # Attempt to open the OUTPUTS_DIR
220
+ os.startfile(OUTPUTS_DIR)
221
+ return "Opened outputs directory successfully"
222
  except Exception as e:
223
  logger.error(f"Error opening file browser: {str(e)}")
224
  return "Error opening file browser"
 
230
  os.path.join(OUTPUTS_DIR, "chunks.txt")
231
  ]
232
 
233
+ with tempfile.TemporaryDirectory() as temp_dir:
234
+ zip_path = os.path.join(temp_dir, "results.zip")
235
+ with zipfile.ZipFile(zip_path, 'w') as zipf:
236
+ for file in output_files:
237
+ if os.path.exists(file):
238
+ zipf.write(file, os.path.basename(file))
239
+ return zip_path
 
 
 
240
  except Exception as e:
241
  logger.error(f"Error creating download: {str(e)}")
242
  return "Error creating download file"
 
283
  )
284
 
285
  # Download Results Button
286
+ download_results_button = gr.Button("⬇️ Download Results")
287
  download_results_button.click(
288
  fn=download_results_from_disk,
289
  outputs=[gr.File(label="Download Results")]
290
  )
291
 
292
+ with gr.Tab("Outputs"):
293
  # Browse Outputs Button
294
  browse_button = gr.Button("📁 Browse Outputs", variant="primary")
295
  browse_button.click(
 
297
  outputs=None
298
  )
299
 
 
 
 
 
 
 
 
300
  process_button.click(
301
  process_files,
302
  inputs=[file_input],