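"""Text Embedding Generator: a Gradio app for Hugging Face Spaces.

Generates sentence embeddings with sentence-transformers/all-MiniLM-L6-v2
(384-dimensional vectors), saves them to persistent storage as .npy or .json,
and batch-processes uploaded files.
"""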
import json
import logging
import os
import traceback

import gradio as gr
import numpy as np
import requests
import spaces
import torch
from charset_normalizer import from_bytes
from sentence_transformers import SentenceTransformer

# Custom exception for surfacing GPU quota errors on Spaces.
class GPUQuotaExceededError(Exception):
    pass

# Constants
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
CHUNK_SIZE = 500  # Reserved for future text chunking (not used yet).
BATCH_SIZE = 32   # Reserved for future batched encoding (not used yet).

# Persistent storage root (/data is the Spaces persistent-storage mount).
PERSISTENT_PATH = os.getenv("PERSISTENT_PATH", "/data")
os.makedirs(PERSISTENT_PATH, exist_ok=True, mode=0o777)

# Subdirectories under persistent storage.
TEMP_DIR = os.path.join(PERSISTENT_PATH, "temp")
os.makedirs(TEMP_DIR, exist_ok=True, mode=0o777)
OUTPUTS_DIR = os.path.join(PERSISTENT_PATH, "outputs")
os.makedirs(OUTPUTS_DIR, exist_ok=True, mode=0o777)
NPY_CACHE = os.path.join(PERSISTENT_PATH, "npy_cache")
os.makedirs(NPY_CACHE, exist_ok=True, mode=0o777)
LOG_DIR = os.getenv("LOG_DIR", os.path.join(PERSISTENT_PATH, "logs"))
os.makedirs(LOG_DIR, exist_ok=True, mode=0o777)

# Point the Hugging Face cache at persistent storage so model downloads
# survive Space restarts.
os.environ["HF_HOME"] = os.path.join(PERSISTENT_PATH, ".huggingface")
os.makedirs(os.environ["HF_HOME"], exist_ok=True, mode=0o777)

# Hugging Face token, read from the environment (set it as a Space secret).
HF_TOKEN = os.getenv("HF_TOKEN")

# Logging: write to a file under persistent storage.
logging.basicConfig(
    filename=os.path.join(LOG_DIR, "app.log"),
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Model state (initialized lazily).
model = None
model_initialization_error = ""  # Holds the last initialization error, if any.

def initialize_model():
"""
Initialize the sentence transformer model with explicit cache path and error handling.
Returns:
bool: Whether the model was successfully initialized.
str: Error message if initialization failed, otherwise empty string.
"""
global model, model_initialization_error
try:
if model is None:
model_cache = os.path.join(PERSISTENT_PATH, "models") # Explicit model cache path (from worked code)
os.makedirs(model_cache, exist_ok=True, mode=0o777) # Ensure cache directory exists
# Use the HF_TOKEN to load the model (as in worked code)
model = SentenceTransformer(EMBEDDING_MODEL_NAME, cache_folder=model_cache, use_auth_token=HF_TOKEN)
logger.info(f"Initialized model: {EMBEDDING_MODEL_NAME}")
model_initialization_error = "" # Clear any previous error
return True, "" # Return success and no error message
return True, "" # Already initialized, return success and no error
except requests.exceptions.RequestException as e: # Specific network error handling (from worked code)
error_msg = f"Connection error during model download: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
model_initialization_error = error_msg
return False, error_msg
except Exception as e: # General error handling (from worked code)
error_msg = f"Model initialization failed: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
model_initialization_error = error_msg
return False, error_msg
@spaces.GPU
def generate_embedding(text, focus):
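    """Encode `text` into a single embedding and return it as a JSON string.

    Note: `focus` is collected from the UI but not currently used to steer
    the embedding.
    """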
    global model, model_initialization_error
    if model is None:
        success, error_message = initialize_model()
        if not success:
            return "", error_message  # Surface the initialization error in the UI.
    try:
        # Autocast to mixed precision on the GPU for faster encoding.
        with torch.amp.autocast("cuda"):
            embedding_vector = model.encode([text])[0].tolist()
        # Serialize to JSON for direct display in the UI.
        embedding_json_str = json.dumps(embedding_vector)
        return embedding_json_str, ""
    except Exception as e:
        error_msg = f"Error generating embedding: {str(e)}"
        logger.error(error_msg)
        return "", error_msg

def save_embedding(embedding_json, name):
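    """Parse the embedding JSON string and save it to persistent storage as .npy."""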
    try:
        embedding = json.loads(embedding_json)  # JSON string back to a list of floats.
        filepath = os.path.join(PERSISTENT_PATH, f"{name}.npy")
        np.save(filepath, np.array(embedding))
        return f"Embedding saved to: {filepath}"
    except Exception as e:
        error_msg = f"Error saving embedding: {str(e)}"
        logger.error(error_msg)
        return error_msg
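
# A minimal sketch (hypothetical filenames) of reusing saved embeddings, e.g.
# for cosine similarity between two previously saved vectors:
#     a = np.load(os.path.join(PERSISTENT_PATH, "doc_a.npy"))
#     b = np.load(os.path.join(PERSISTENT_PATH, "doc_b.npy"))
#     cosine = float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
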
def convert_to_json(embedding_json, name):
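    """Write the embedding JSON string to persistent storage as a .json file."""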
    try:
        filepath = os.path.join(PERSISTENT_PATH, f"{name}.json")
        with open(filepath, "w") as f:
            f.write(embedding_json)  # The UI already holds a JSON string; write it directly.
        return f"Embedding saved as JSON to: {filepath}"
    except Exception as e:
        error_msg = f"Error converting to JSON: {str(e)}"
        logger.error(error_msg)
        return error_msg

@spaces.GPU
def process_files(files, focus):
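    """Embed each uploaded file and return (embeddings JSON, per-file status).

    Note: `focus` is collected from the UI but not currently used.
    """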
    global model, model_initialization_error
    if model is None:
        success, error_message = initialize_model()
        if not success:
            return "", error_message  # Surface the initialization error in the UI.
    try:
        all_embeddings = []
        file_statuses = []  # Per-file success/error messages.
        for file in files:
            try:
                # Read raw bytes, then detect the encoding so the model
                # receives decoded text rather than bytes.
                with open(file.name, "rb") as f:
                    raw = f.read()
                match = from_bytes(raw).best()
                if match is None:
                    raise ValueError("Could not detect a text encoding.")
                text = str(match)
                with torch.amp.autocast("cuda"):
                    embedding = model.encode([text])[0].tolist()
                all_embeddings.append(embedding)
                file_statuses.append(f"File '{file.name}' processed successfully.")
            except Exception as file_e:
                error_msg = f"Error processing file '{file.name}': {str(file_e)}"
                logger.error(error_msg)
                file_statuses.append(error_msg)
        status_message = "\n".join(file_statuses)
        # JSON string for UI display; may be very long for many or large files.
        all_embeddings_json = json.dumps(all_embeddings)
        return all_embeddings_json, status_message
    except Exception as e:
        error_msg = f"Error in process_files: {str(e)}"
        logger.error(error_msg)
        return "", error_msg
def create_gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("## Text Embedding Generator")
        # Hidden box that holds any model-initialization error.
        initialization_status_box = gr.Textbox(
            label="Initialization Status", value=model_initialization_error, visible=False
        )
        with gr.Row():
            text_input = gr.Textbox(label="Enter Text")
            focus_input = gr.Textbox(
                label="Main Focus of Embedding (e.g., company structure, staff positions, etc.)"
            )
        with gr.Row():
            file_input = gr.File(label="Upload Files", file_count="multiple")
        generate_button = gr.Button("Generate Embedding")
        embedding_output = gr.Textbox(label="Embedding Vector (JSON)", lines=5)
        status_box = gr.Textbox(label="Status/Messages")
        with gr.Accordion("Save and Download Options", open=False):
            save_name_input = gr.Textbox(label="Save Embedding As (Name without extension)")
            with gr.Row():
                save_button = gr.Button("Save as .npy")
                convert_button = gr.Button("Save as .json")
            with gr.Row():
                save_status = gr.Textbox(label="Save Status")
                convert_status = gr.Textbox(label="Convert Status")
            download_button = gr.Button("Download JSON")
            download_output = gr.File(label="Download JSON File")
        process_button = gr.Button("Process Files")
        process_output = gr.Textbox(
            label="Processed Files (Embeddings JSON - limited display)", lines=3
        )
        process_status = gr.Textbox(label="File Processing Status")

        def get_json_download_path(name):
            """Return the saved JSON path if it exists, else None."""
            if not name:
                return None
            filepath = os.path.join(PERSISTENT_PATH, f"{name}.json")
            return filepath if os.path.exists(filepath) else None

        # Surface any initialization error as soon as the app loads.
        demo.load(
            lambda: ("", model_initialization_error),
            outputs=[status_box, initialization_status_box],
        )
        generate_button.click(
            generate_embedding,
            inputs=[text_input, focus_input],
            outputs=[embedding_output, status_box],
        )
        save_button.click(
            save_embedding,
            inputs=[embedding_output, save_name_input],  # embedding_output holds the JSON string.
            outputs=[save_status],
        )
        convert_button.click(
            convert_to_json,
            inputs=[embedding_output, save_name_input],
            outputs=[convert_status],
        )
        download_button.click(
            get_json_download_path,
            inputs=[save_name_input],
            outputs=[download_output],
        )
        process_button.click(
            process_files,
            inputs=[file_input, focus_input],
            outputs=[process_output, process_status],
        )
    return demo

if __name__ == "__main__":
    # Initialize the model at startup and report any errors to the console.
    initialization_success, initialization_error_message = initialize_model()
    if not initialization_success:
        print(f"App startup failed due to model initialization error:\n{initialization_error_message}")
    demo = create_gradio_interface()
    demo.launch(server_name="0.0.0.0", allowed_paths=[PERSISTENT_PATH])