import json
import os
import logging
import traceback

import gradio as gr
import numpy as np
import requests
import spaces  # Hugging Face Spaces helper; kept from the original for Spaces deployments
import torch
from charset_normalizer import from_bytes
from sentence_transformers import SentenceTransformer
# Custom exception for Spaces GPU quota errors (available to callers; not raised below)
class GPUQuotaExceededError(Exception):
    pass
# Constants
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
CHUNK_SIZE = 500
BATCH_SIZE = 32

# Persistent storage root (/data is the mount point for Spaces persistent storage)
PERSISTENT_PATH = os.getenv("PERSISTENT_PATH", "/data")
os.makedirs(PERSISTENT_PATH, exist_ok=True, mode=0o777)

# Subdirectories under the persistent root
TEMP_DIR = os.path.join(PERSISTENT_PATH, "temp")
os.makedirs(TEMP_DIR, exist_ok=True, mode=0o777)
OUTPUTS_DIR = os.path.join(PERSISTENT_PATH, "outputs")
os.makedirs(OUTPUTS_DIR, exist_ok=True, mode=0o777)
NPY_CACHE = os.path.join(PERSISTENT_PATH, "npy_cache")
os.makedirs(NPY_CACHE, exist_ok=True, mode=0o777)
LOG_DIR = os.getenv("LOG_DIR", os.path.join(PERSISTENT_PATH, "logs"))
os.makedirs(LOG_DIR, exist_ok=True, mode=0o777)

# Point the Hugging Face cache at persistent storage so model downloads survive restarts.
# Note: HF_HOME is read when huggingface_hub is first imported, so setting it this late
# does not affect this process's default cache; the explicit cache_folder passed to
# SentenceTransformer below keeps the model cache persistent regardless.
os.environ["HF_HOME"] = os.path.join(PERSISTENT_PATH, ".huggingface")
os.makedirs(os.environ["HF_HOME"], exist_ok=True, mode=0o777)

# Hugging Face token (read from the environment; never hard-code it)
HF_TOKEN = os.getenv("HF_TOKEN")
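
# CHUNK_SIZE and BATCH_SIZE are declared above but never consumed by the handlers
# below. A minimal sketch, assuming simple character-based chunking, of how
# CHUNK_SIZE could split long documents before encoding (hypothetical helper,
# not wired into the UI):
def chunk_text(text: str, chunk_size: int = CHUNK_SIZE) -> list:
    """Split text into fixed-size character chunks (illustrative helper)."""
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]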
# Logging Setup (Keep this - helpful for debugging) | |
logging.basicConfig( | |
filename=os.path.join(LOG_DIR, "app.log"), # Use os.path.join for log file path | |
level=logging.INFO, | |
format="%(asctime)s - %(levelname)s - %(message)s", | |
) | |
logger = logging.getLogger(__name__) | |
# Model state (lazily initialized)
model = None
model_initialization_error = ""  # Holds the most recent initialization error, if any
def initialize_model():
    """
    Initialize the sentence transformer model with an explicit cache path and error handling.

    Returns:
        tuple[bool, str]: (success, error message; empty string on success).
    """
    global model, model_initialization_error
    try:
        if model is None:
            model_cache = os.path.join(PERSISTENT_PATH, "models")  # Explicit model cache on persistent storage
            os.makedirs(model_cache, exist_ok=True, mode=0o777)
            # `token` replaces the deprecated `use_auth_token` in recent sentence-transformers releases
            model = SentenceTransformer(EMBEDDING_MODEL_NAME, cache_folder=model_cache, token=HF_TOKEN)
            logger.info(f"Initialized model: {EMBEDDING_MODEL_NAME}")
            model_initialization_error = ""  # Clear any previous error
        return True, ""  # Initialized (now or previously)
    except requests.exceptions.RequestException as e:  # Network failures during model download
        error_msg = f"Connection error during model download: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        model_initialization_error = error_msg
        return False, error_msg
    except Exception as e:
        error_msg = f"Model initialization failed: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        model_initialization_error = error_msg
        return False, error_msg
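
# `spaces` is imported above, but no handler is decorated. On ZeroGPU hardware,
# GPU access is granted per call via the `spaces.GPU` decorator; a sketch,
# assuming this Space were moved to ZeroGPU (left commented out so the app
# keeps running unchanged on CPU or standard GPU Spaces):
#
#     @spaces.GPU
#     def gpu_encode(texts):
#         return model.encode(texts)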
def generate_embedding(text, focus):
    """Encode a single text; `focus` is currently unused but kept for the UI signature."""
    global model, model_initialization_error
    if model is None:
        success, error_message = initialize_model()
        if not success:
            return "", error_message  # Surface the initialization error in the UI
    try:
        # Mixed precision only helps (and only works) when CUDA is available
        with torch.amp.autocast("cuda", enabled=torch.cuda.is_available()):
            embedding_vector = model.encode([text])[0].tolist()
        embedding_json_str = json.dumps(embedding_vector)  # JSON string for direct display in the UI
        return embedding_json_str, ""
    except Exception as e:
        error_msg = f"Error generating embedding: {str(e)}"
        logger.error(error_msg)
        return "", error_msg
def save_embedding(embedding_json, name):  # Expects the JSON string shown in the UI
    try:
        embedding = json.loads(embedding_json)  # Parse the JSON string back to a list
        filepath = os.path.join(NPY_CACHE, f"{name}.npy")  # Use the dedicated .npy cache directory
        np.save(filepath, np.array(embedding))
        return f"Embedding saved to: {filepath}"
    except Exception as e:
        error_msg = f"Error saving embedding: {str(e)}"
        logger.error(error_msg)
        return error_msg
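
# Round-trip counterpart to save_embedding (hypothetical helper, not used by the
# UI): a saved .npy file can be reloaded with np.load for later reuse.
def load_embedding(name: str) -> np.ndarray:
    """Reload an embedding previously written by save_embedding (illustrative)."""
    return np.load(os.path.join(NPY_CACHE, f"{name}.npy"))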
def convert_to_json(embedding_json, name):  # Expects the JSON string shown in the UI
    try:
        filepath = os.path.join(OUTPUTS_DIR, f"{name}.json")  # Use the dedicated outputs directory
        with open(filepath, "w") as f:
            f.write(embedding_json)  # The UI value is already a JSON string
        return f"Embedding saved as JSON to: {filepath}"
    except Exception as e:
        error_msg = f"Error converting to JSON: {str(e)}"
        logger.error(error_msg)
        return error_msg
def process_files(files, focus):
    global model, model_initialization_error
    if model is None:
        success, error_message = initialize_model()
        if not success:
            return "", error_message  # Surface the initialization error in the UI
    try:
        all_embeddings = []
        file_statuses = []  # Per-file status messages
        for file in files:
            try:
                with open(file.name, "rb") as f:
                    raw = f.read()
                # Detect the encoding and decode, rather than passing raw bytes to the model
                match = from_bytes(raw).best()
                text = str(match) if match is not None else raw.decode("utf-8", errors="replace")
                with torch.amp.autocast("cuda", enabled=torch.cuda.is_available()):
                    embedding = model.encode([text])[0].tolist()
                all_embeddings.append(embedding)
                file_statuses.append(f"File '{file.name}' processed successfully.")
            except Exception as file_e:
                error_msg = f"Error processing file '{file.name}': {str(file_e)}"
                logger.error(error_msg)
                file_statuses.append(error_msg)
        status_message = "\n".join(file_statuses)
        # JSON string for UI display; may be very long for many or large files
        all_embeddings_json = json.dumps(all_embeddings)
        return all_embeddings_json, status_message
    except Exception as e:
        error_msg = f"Error in process_files function: {str(e)}"
        logger.error(error_msg)
        return "", error_msg
def create_gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("## Text Embedding Generator")
        # Hidden box that surfaces any model initialization error
        initialization_status_box = gr.Textbox(label="Initialization Status", value=model_initialization_error, visible=False)
        with gr.Row():
            text_input = gr.Textbox(label="Enter Text")
            focus_input = gr.Textbox(label="Main Focus of Embedding (e.g., company structure, staff positions)")
        with gr.Row():
            file_input = gr.File(label="Upload Files", file_count="multiple")
        generate_button = gr.Button("Generate Embedding")
        embedding_output = gr.Textbox(label="Embedding Vector (JSON)", lines=5)
        status_box = gr.Textbox(label="Status/Messages")
        with gr.Accordion("Save and Download Options", open=False):
            save_name_input = gr.Textbox(label="Save Embedding As (Name without extension)")
            with gr.Row():
                save_button = gr.Button("Save as .npy")
                convert_button = gr.Button("Save as .json")
            with gr.Row():
                save_status = gr.Textbox(label="Save Status")
                convert_status = gr.Textbox(label="Convert Status")
            download_button = gr.Button("Download JSON")
            download_output = gr.File(label="Download JSON File")
        process_button = gr.Button("Process Files")
        process_output = gr.Textbox(label="Processed Files (Embeddings JSON, truncated display)", lines=3)
        process_status = gr.Textbox(label="File Processing Status")
        # On page load, surface any initialization error recorded at startup
        # (the model itself is initialized in the __main__ block below)
        demo.load(
            lambda: ("", model_initialization_error),
            outputs=[status_box, initialization_status_box],
        )
        generate_button.click(
            generate_embedding,
            inputs=[text_input, focus_input],
            outputs=[embedding_output, status_box],
        )
        save_button.click(
            save_embedding,
            inputs=[embedding_output, save_name_input],  # embedding_output holds the JSON string
            outputs=[save_status],
        )
        convert_button.click(
            convert_to_json,
            inputs=[embedding_output, save_name_input],
            outputs=[convert_status],
        )
        download_button.click(
            # Return the saved JSON path only if it exists; None leaves the widget empty
            lambda name: (
                os.path.join(OUTPUTS_DIR, f"{name}.json")
                if name and os.path.exists(os.path.join(OUTPUTS_DIR, f"{name}.json"))
                else None
            ),
            inputs=[save_name_input],
            outputs=[download_output],
        )
        process_button.click(
            process_files,
            inputs=[file_input, focus_input],
            outputs=[process_output, process_status],
        )
    return demo
if __name__ == "__main__":
    # Initialize the model at startup and report any error to the console
    initialization_success, initialization_error_message = initialize_model()
    if not initialization_success:
        print(f"App startup failed due to model initialization error:\n{initialization_error_message}")
    demo = create_gradio_interface()
    demo.launch(server_name="0.0.0.0", allowed_paths=[PERSISTENT_PATH])