import json
import logging
import os
import traceback

# Persistent storage must be configured before the Hugging Face libraries are
# imported, because huggingface_hub resolves HF_HOME at import time.
PERSISTENT_PATH = os.getenv("PERSISTENT_PATH", "/data")  # /data is Spaces persistent storage
os.makedirs(PERSISTENT_PATH, exist_ok=True, mode=0o777)
os.environ["HF_HOME"] = os.path.join(PERSISTENT_PATH, ".huggingface")
os.makedirs(os.environ["HF_HOME"], exist_ok=True, mode=0o777)

import gradio as gr
import numpy as np
import requests
import spaces
from charset_normalizer import from_bytes
from sentence_transformers import SentenceTransformer
from torch.amp import autocast


# Custom exception for surfacing ZeroGPU quota errors to callers
class GPUQuotaExceededError(Exception):
    pass


# Constants
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
CHUNK_SIZE = 500  # characters per chunk (not yet wired into the pipeline)
BATCH_SIZE = 32   # encode batch size (not yet wired into the pipeline)

# Subdirectories under persistent storage
TEMP_DIR = os.path.join(PERSISTENT_PATH, "temp")
OUTPUTS_DIR = os.path.join(PERSISTENT_PATH, "outputs")
NPY_CACHE = os.path.join(PERSISTENT_PATH, "npy_cache")
LOG_DIR = os.getenv("LOG_DIR", os.path.join(PERSISTENT_PATH, "logs"))
for _subdir in (TEMP_DIR, OUTPUTS_DIR, NPY_CACHE, LOG_DIR):
    os.makedirs(_subdir, exist_ok=True, mode=0o777)

# Hugging Face token (set this as a Space secret rather than hard-coding it)
HF_TOKEN = os.getenv("HF_TOKEN")

# Logging setup (helpful for debugging on Spaces)
logging.basicConfig(
    filename=os.path.join(LOG_DIR, "app.log"),
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Model state
model = None
model_initialization_error = ""  # populated if initialization fails


def initialize_model():
    """Initialize the sentence transformer with an explicit cache path.

    Returns:
        tuple[bool, str]: (success, error message; empty string on success).
    """
    global model, model_initialization_error
    try:
        if model is None:
            model_cache = os.path.join(PERSISTENT_PATH, "models")  # explicit model cache path
            os.makedirs(model_cache, exist_ok=True, mode=0o777)
            model = SentenceTransformer(
                EMBEDDING_MODEL_NAME,
                cache_folder=model_cache,
                use_auth_token=HF_TOKEN,
            )
            logger.info(f"Initialized model: {EMBEDDING_MODEL_NAME}")
            model_initialization_error = ""  # clear any previous error
        return True, ""
    except requests.exceptions.RequestException as e:
        # Network failures during model download
        error_msg = f"Connection error during model download: {e}\n{traceback.format_exc()}"
        logger.error(error_msg)
        model_initialization_error = error_msg
        return False, error_msg
    except Exception as e:
        error_msg = f"Model initialization failed: {e}\n{traceback.format_exc()}"
        logger.error(error_msg)
        model_initialization_error = error_msg
        return False, error_msg


@spaces.GPU
def generate_embedding(text, focus):
    """Encode `text` and return (embedding JSON string, status message).

    `focus` is accepted from the UI but not yet used to steer the embedding.
    """
    global model
    if model is None:
        success, error_message = initialize_model()
        if not success:
            return "", error_message  # surface the initialization error in the UI
    try:
        with autocast("cuda"):
            embedding_vector = model.encode([text])[0].tolist()
        # Serialize to JSON for direct display in the UI
        return json.dumps(embedding_vector), ""
    except Exception as e:
        error_msg = f"Error generating embedding: {e}"
        logger.error(error_msg)
        return "", error_msg


def save_embedding(embedding_json, name):
    """Persist a JSON-encoded embedding as .npy (file I/O only, no GPU needed)."""
    try:
        embedding = json.loads(embedding_json)  # parse the JSON string back to a list
        filepath = os.path.join(NPY_CACHE, f"{name}.npy")
        np.save(filepath, np.array(embedding))
        return f"Embedding saved to: {filepath}"
    except Exception as e:
        error_msg = f"Error saving embedding: {e}"
        logger.error(error_msg)
        return error_msg


def convert_to_json(embedding_json, name):
    """Write the embedding JSON string to a .json file (no GPU needed)."""
    try:
        filepath = os.path.join(OUTPUTS_DIR, f"{name}.json")
        with open(filepath, "w") as f:
            f.write(embedding_json)  # the UI already holds a JSON string
        return f"Embedding saved as JSON to: {filepath}"
    except Exception as e:
        error_msg = f"Error converting to JSON: {e}"
        logger.error(error_msg)
        return error_msg


@spaces.GPU
def process_files(files, focus):
    """Encode each uploaded file; return (embeddings JSON, per-file status)."""
    global model
    if model is None:
        success, error_message = initialize_model()
        if not success:
            return "", error_message
    try:
        all_embeddings = []
        file_statuses = []  # one status line per file
        for file in files:
            try:
                with open(file.name, "rb") as f:
                    raw = f.read()
                # model.encode expects str, not bytes: decode with charset detection
                match = from_bytes(raw).best()
                if match is None:
                    raise ValueError("could not detect text encoding")
                text = str(match)
                with autocast("cuda"):
                    embedding = model.encode([text])[0].tolist()
                all_embeddings.append(embedding)
                file_statuses.append(f"File '{file.name}' processed successfully.")
            except Exception as file_e:
                error_msg = f"Error processing file '{file.name}': {file_e}"
                logger.error(error_msg)
                file_statuses.append(error_msg)
        status_message = "\n".join(file_statuses)
        # JSON-encode all embeddings for UI display (may be long for large inputs)
        return json.dumps(all_embeddings), status_message
    except Exception as e:
        error_msg = f"Error in process_files: {e}"
        logger.error(error_msg)
        return "", error_msg
def create_gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("## Text Embedding Generator")

        # Hidden box that holds any model-initialization error
        initialization_status_box = gr.Textbox(
            label="Initialization Status", value=model_initialization_error, visible=False
        )

        with gr.Row():
            text_input = gr.Textbox(label="Enter Text")
            focus_input = gr.Textbox(
                label="Main Focus of Embedding (e.g., company structure, staff positions)"
            )
        with gr.Row():
            file_input = gr.File(label="Upload Files", file_count="multiple")

        generate_button = gr.Button("Generate Embedding")
        embedding_output = gr.Textbox(label="Embedding Vector (JSON)", lines=5)
        status_box = gr.Textbox(label="Status/Messages")

        with gr.Accordion("Save and Download Options", open=False):
            save_name_input = gr.Textbox(label="Save Embedding As (Name without extension)")
            with gr.Row():
                save_button = gr.Button("Save as .npy")
                convert_button = gr.Button("Save as .json")
            with gr.Row():
                save_status = gr.Textbox(label="Save Status")
                convert_status = gr.Textbox(label="Convert Status")
            download_button = gr.Button("Download JSON")
            download_output = gr.File(label="Download JSON File")

        process_button = gr.Button("Process Files")
        process_output = gr.Textbox(
            label="Processed Files (Embeddings JSON - limited display)", lines=3
        )
        process_status = gr.Textbox(label="File Processing Status")

        # On app load, surface any initialization error
        demo.load(
            lambda: ("", model_initialization_error),
            outputs=[status_box, initialization_status_box],
        )

        generate_button.click(
            generate_embedding,
            inputs=[text_input, focus_input],
            outputs=[embedding_output, status_box],
        )
        save_button.click(
            save_embedding,
            inputs=[embedding_output, save_name_input],  # embedding_output holds the JSON string
            outputs=[save_status],
        )
        convert_button.click(
            convert_to_json,
            inputs=[embedding_output, save_name_input],
            outputs=[convert_status],
        )
        download_button.click(
            # Only offer the file if a name was given and the file actually exists
            lambda name: (
                os.path.join(OUTPUTS_DIR, f"{name}.json")
                if name and os.path.exists(os.path.join(OUTPUTS_DIR, f"{name}.json"))
                else None
            ),
            inputs=[save_name_input],
            outputs=[download_output],
        )
        process_button.click(
            process_files,
            inputs=[file_input, focus_input],
            outputs=[process_output, process_status],
        )
    return demo


if __name__ == "__main__":
    # Initialize the model at startup and report any failure to the console
    initialization_success, initialization_error_message = initialize_model()
    if not initialization_success:
        print(f"App startup failed due to model initialization error:\n{initialization_error_message}")
    demo = create_gradio_interface()
    demo.launch(server_name="0.0.0.0", allowed_paths=[PERSISTENT_PATH])
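

# ---------------------------------------------------------------------------
# Appendix sketch (assumption, not part of the app above): CHUNK_SIZE and
# BATCH_SIZE are defined but never used. One plausible intent is to split long
# documents into fixed-size character chunks and encode them in batches;
# `encode_in_chunks` is a hypothetical helper illustrating that. It assumes
# initialize_model() has already succeeded.
def encode_in_chunks(text: str):
    """Split `text` into CHUNK_SIZE-character chunks and encode them in batches."""
    chunks = [text[i:i + CHUNK_SIZE] for i in range(0, len(text), CHUNK_SIZE)]
    # SentenceTransformer.encode handles batching internally via `batch_size`
    return model.encode(chunks, batch_size=BATCH_SIZE)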
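

# Appendix sketch (assumption, not part of the app above): consuming embeddings
# saved by `save_embedding`. The names are whatever was typed into the
# "Save Embedding As" field in the UI.
def compare_saved_embeddings(a_name: str, b_name: str) -> float:
    """Load two saved .npy embeddings and return their cosine similarity."""
    a = np.load(os.path.join(NPY_CACHE, f"{a_name}.npy"))
    b = np.load(os.path.join(NPY_CACHE, f"{b_name}.npy"))
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))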
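

# Appendix sketch (assumption, not part of the app above): querying the running
# app programmatically with gradio_client. The endpoint name
# "/generate_embedding" assumes Gradio's default auto-generated api_name for
# the generate_embedding click handler; check `demo.view_api()` if it differs.
def example_client_call(url: str = "http://127.0.0.1:7860/"):
    """Send one text to the running app; returns (embedding JSON, status)."""
    from gradio_client import Client  # imported here; only needed for this sketch

    client = Client(url)
    return client.predict("hello world", "general", api_name="/generate_embedding")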