Ultronprime committed on
Commit d1cdc5f · verified · 1 Parent(s): 4be0978

Update app.py

Files changed (1)
  1. app.py +74 -30
app.py CHANGED
@@ -1,30 +1,60 @@
 import os
 import gradio as gr
 import logging
-import numpy as np
-from sentence_transformers import SentenceTransformer
+import traceback
+import spaces
+from typing import Optional, List
+from dataclasses import dataclass
+from datetime import datetime
+from pathlib import Path
+import gc
 import torch
 from torch.amp import autocast
-from spaces import GPU
-import json # Import json for direct JSON output in UI
+from transformers import AutoModel, AutoTokenizer
+from sentence_transformers import SentenceTransformer
+import numpy as np
+import requests
+from charset_normalizer import from_bytes
+import zipfile
+import tempfile
+import shutil
+
+# Custom Exception Class (Keep this)
+class GPUQuotaExceededError(Exception):
+    pass

-# Constants (Keep your HF token secure - use environment variables if possible for real deployments)
+# Constants (Modified Persistent Paths and Cache)
 EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
-CACHE_DIR = os.getenv("CACHE_DIR", "/tmp/cache")
-PERSISTENT_PATH = os.getenv("PERSISTENT_PATH", "/tmp/data")
-HF_TOKEN = "YOUR_HF_TOKEN" # REMEMBER TO REPLACE THIS - BEST TO USE ENVIRONMENT VARIABLE
+CHUNK_SIZE = 500
+BATCH_SIZE = 32
+
+# Set Persistent Storage Path (More Explicit Paths - from Worked Code)
+PERSISTENT_PATH = os.getenv("PERSISTENT_PATH", "/data") # Keep this as /data for Spaces persistent storage
+os.makedirs(PERSISTENT_PATH, exist_ok=True, mode=0o777)

-# Create directories (still useful to try, even if /tmp/ is ephemeral)
-os.makedirs(CACHE_DIR, exist_ok=True)
-os.makedirs(PERSISTENT_PATH, exist_ok=True)
+# Define Subdirectories (More Explicit Paths)
+TEMP_DIR = os.path.join(PERSISTENT_PATH, "temp")
+os.makedirs(TEMP_DIR, exist_ok=True, mode=0o777)

-# Logging Setup (keep logging - it's helpful for debugging)
-LOG_DIR = os.getenv("LOG_DIR", "/data/logs")
-os.makedirs(LOG_DIR, exist_ok=True)
-LOG_FILE = LOG_DIR + "/app.log"
+OUTPUTS_DIR = os.path.join(PERSISTENT_PATH, "outputs")
+os.makedirs(OUTPUTS_DIR, exist_ok=True, mode=0o777)

+NPY_CACHE = os.path.join(PERSISTENT_PATH, "npy_cache")
+os.makedirs(NPY_CACHE, exist_ok=True, mode=0o777)
+
+LOG_DIR = os.getenv("LOG_DIR", os.path.join(PERSISTENT_PATH, "logs"))
+os.makedirs(LOG_DIR, exist_ok=True, mode=0o777)
+
+# Set Hugging Face cache directory to persistent storage (From Worked Code - Important!)
+os.environ["HF_HOME"] = os.path.join(PERSISTENT_PATH, ".huggingface")
+os.makedirs(os.environ["HF_HOME"], exist_ok=True, mode=0o777)
+
+# Set Hugging Face token (Keep this - best to use environment variable)
+HF_TOKEN = os.getenv("HF_TOKEN")
+
+# Logging Setup (Keep this - helpful for debugging)
 logging.basicConfig(
-    filename=LOG_FILE,
+    filename=os.path.join(LOG_DIR, "app.log"), # Use os.path.join for log file path
     level=logging.INFO,
     format="%(asctime)s - %(levelname)s - %(message)s",
 )
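Note: the revised import block no longer imports json, yet save_embedding
below still calls json.loads, which would raise a NameError at runtime. A
one-line sketch of the likely missing import (an assumption, not part of this
commit):

import json  # still needed by save_embedding's json.loads call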
@@ -32,25 +62,39 @@ logger = logging.getLogger(__name__)

 # Model initialization
 model = None
-model_initialization_error = "" # Global variable to store initialization error
+model_initialization_error = "" # Global variable for initialization error

 def initialize_model():
+    """
+    Initialize the sentence transformer model with explicit cache path and error handling.
+    Returns:
+        bool: Whether the model was successfully initialized.
+        str: Error message if initialization failed, otherwise empty string.
+    """
     global model, model_initialization_error
     try:
         if model is None:
-            model = SentenceTransformer(EMBEDDING_MODEL_NAME, cache_folder=CACHE_DIR, use_auth_token=HF_TOKEN)
+            model_cache = os.path.join(PERSISTENT_PATH, "models") # Explicit model cache path (from worked code)
+            os.makedirs(model_cache, exist_ok=True, mode=0o777) # Ensure cache directory exists
+            # Use the HF_TOKEN to load the model (as in worked code)
+            model = SentenceTransformer(EMBEDDING_MODEL_NAME, cache_folder=model_cache, use_auth_token=HF_TOKEN)
             logger.info(f"Initialized model: {EMBEDDING_MODEL_NAME}")
             model_initialization_error = "" # Clear any previous error
             return True, "" # Return success and no error message
         return True, "" # Already initialized, return success and no error
-    except Exception as e:
-        error_msg = f"Model initialization failed: {str(e)}"
+    except requests.exceptions.RequestException as e: # Specific network error handling (from worked code)
+        error_msg = f"Connection error during model download: {str(e)}\n{traceback.format_exc()}"
+        logger.error(error_msg)
+        model_initialization_error = error_msg
+        return False, error_msg
+    except Exception as e: # General error handling (from worked code)
+        error_msg = f"Model initialization failed: {str(e)}\n{traceback.format_exc()}"
         logger.error(error_msg)
-        model_initialization_error = error_msg # Store error message
-        return False, error_msg # Return failure and error message
+        model_initialization_error = error_msg
+        return False, error_msg


-@GPU()
+@spaces.GPU
 def generate_embedding(text, focus):
     global model, model_initialization_error
     if model is None:
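Note: initialize_model now returns a (success, error_message) pair on every
path. A hypothetical caller sketch (names assumed, not from this commit)
showing how startup code might consume it:

ok, err = initialize_model()
if not ok:
    # Surface the stored error instead of letting the Space crash on import
    logger.error(f"Startup failed: {err}")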
@@ -69,11 +113,11 @@ def generate_embedding(text, focus):
         logger.error(error_msg)
         return "", error_msg

-@GPU()
+@spaces.GPU
 def save_embedding(embedding_json, name): # Expect JSON string as input from UI
     try:
         embedding = json.loads(embedding_json) # Parse JSON string back to list
-        filepath = f"{PERSISTENT_PATH}/{name}.npy" # Construct full filepath
+        filepath = os.path.join(PERSISTENT_PATH, f"{name}.npy") # Use os.path.join for filepath
         np.save(filepath, np.array(embedding))
         return f"Embedding saved to: {filepath}" # Return filepath in status
     except Exception as e:
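Note: os.path.join makes the path construction platform-safe, but it does not
sanitize the user-supplied name (a value like "../evil" can still escape
PERSISTENT_PATH). A purely illustrative hardening sketch:

safe_name = os.path.basename(name).strip() or "embedding"  # drop directory components
filepath = os.path.join(PERSISTENT_PATH, f"{safe_name}.npy")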
@@ -81,10 +125,10 @@ def save_embedding(embedding_json, name): # Expect JSON string as input from UI
         logger.error(error_msg)
         return error_msg

-@GPU()
+@spaces.GPU
 def convert_to_json(embedding_json, name): # Expect JSON string as input
     try:
-        filepath = f"{PERSISTENT_PATH}/{name}.json" # Construct full filepath
+        filepath = os.path.join(PERSISTENT_PATH, f"{name}.json") # Use os.path.join for filepath
         with open(filepath, "w") as f:
             f.write(embedding_json) # Directly write the JSON string
         return f"Embedding saved as JSON to: {filepath}" # Return filepath in status
@@ -93,7 +137,7 @@ def convert_to_json(embedding_json, name): # Expect JSON string as input
         logger.error(error_msg)
         return error_msg

-@GPU()
+@spaces.GPU
 def process_files(files, focus):
     global model, model_initialization_error
     if model is None:
@@ -106,7 +150,7 @@ def process_files(files, focus):
     file_statuses = [] # To track status for each file
     for file in files:
         try:
-            with open(file.name, 'r') as f:
+            with open(file.name, 'rb') as f:
                 text = f.read()
             with torch.amp.autocast('cuda'):
                 embedding = model.encode([text])[0].tolist()
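Note: with 'rb', f.read() returns bytes, while SentenceTransformer.encode
expects strings; the newly imported charset_normalizer.from_bytes suggests a
decode step was intended. A sketch of that step (an assumption about intent,
not code from this commit):

raw = f.read()
best = from_bytes(raw).best()  # detect the most likely encoding
text = str(best) if best is not None else raw.decode("utf-8", errors="replace")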
@@ -186,7 +230,7 @@ def create_gradio_interface():
         )

         download_button.click(
-            lambda name: f"{PERSISTENT_PATH}/{name}.json" if name else None, # Handle empty name
+            lambda name: os.path.join(PERSISTENT_PATH, f"{name}.json") if name else None, # Handle empty name, use os.path.join
             inputs=[save_name_input],
             outputs=[download_output]
         )
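Note: the download lambda returns a path whether or not the file was ever
written. A guarded variant (hypothetical helper, illustrative only):

def get_download_path(name):
    # Hand Gradio a path only if the saved JSON actually exists
    path = os.path.join(PERSISTENT_PATH, f"{name}.json") if name else None
    return path if path and os.path.exists(path) else None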
 