import json
import os
import logging
import traceback

import gradio as gr
import numpy as np
import requests
import spaces
import torch
from charset_normalizer import from_bytes
from sentence_transformers import SentenceTransformer
# Custom exception used to surface GPU quota exhaustion to the UI.
class GPUQuotaExceededError(Exception):
    pass
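
# Hedged sketch (not part of the original app): GPUQuotaExceededError is defined
# above but never raised. One plausible use, shown here as a hypothetical helper,
# is to inspect a caught exception and re-raise it with a clearer type when its
# message suggests the ZeroGPU quota was exhausted. The "quota" substring check
# is an assumption about the error text, not a documented API.
def raise_if_quota_exceeded(exc: Exception) -> None:
    if "quota" in str(exc).lower():
        raise GPUQuotaExceededError(str(exc)) from exc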
# Constants
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
CHUNK_SIZE = 500  # chunking parameters for long texts (see the chunked-encoding sketch below)
BATCH_SIZE = 32

# Persistent storage root (/data is the mount point for Spaces persistent storage)
PERSISTENT_PATH = os.getenv("PERSISTENT_PATH", "/data")
os.makedirs(PERSISTENT_PATH, exist_ok=True, mode=0o777)

# Subdirectories under persistent storage
TEMP_DIR = os.path.join(PERSISTENT_PATH, "temp")
os.makedirs(TEMP_DIR, exist_ok=True, mode=0o777)
OUTPUTS_DIR = os.path.join(PERSISTENT_PATH, "outputs")
os.makedirs(OUTPUTS_DIR, exist_ok=True, mode=0o777)
NPY_CACHE = os.path.join(PERSISTENT_PATH, "npy_cache")
os.makedirs(NPY_CACHE, exist_ok=True, mode=0o777)
LOG_DIR = os.getenv("LOG_DIR", os.path.join(PERSISTENT_PATH, "logs"))
os.makedirs(LOG_DIR, exist_ok=True, mode=0o777)

# Point the Hugging Face cache at persistent storage so model downloads survive restarts.
os.environ["HF_HOME"] = os.path.join(PERSISTENT_PATH, ".huggingface")
os.makedirs(os.environ["HF_HOME"], exist_ok=True, mode=0o777)

# Hugging Face token (set this as a Space secret rather than hard-coding it)
HF_TOKEN = os.getenv("HF_TOKEN")
# Logging setup (file-based; helpful for debugging on Spaces)
logging.basicConfig(
    filename=os.path.join(LOG_DIR, "app.log"),
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
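
# Added here as a small convenience (not in the original app): mirror log records
# to stdout so they also show up in the Space's container logs, where the file
# under LOG_DIR is not directly visible.
_console = logging.StreamHandler()
_console.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
logger.addHandler(_console)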
# Model state (lazily initialized)
model = None
model_initialization_error = ""  # last initialization error, surfaced in the UI
def initialize_model():
    """
    Initialize the sentence-transformer model with an explicit cache path.

    Returns:
        bool: Whether the model was successfully initialized.
        str: Error message if initialization failed, otherwise an empty string.
    """
    global model, model_initialization_error
    try:
        if model is None:
            model_cache = os.path.join(PERSISTENT_PATH, "models")  # explicit model cache on persistent storage
            os.makedirs(model_cache, exist_ok=True, mode=0o777)
            model = SentenceTransformer(
                EMBEDDING_MODEL_NAME,
                cache_folder=model_cache,
                use_auth_token=HF_TOKEN,
            )
            logger.info(f"Initialized model: {EMBEDDING_MODEL_NAME}")
            model_initialization_error = ""  # clear any previous error
        return True, ""  # success, whether freshly initialized or already loaded
    except requests.exceptions.RequestException as e:  # network errors during download
        error_msg = f"Connection error during model download: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        model_initialization_error = error_msg
        return False, error_msg
    except Exception as e:
        error_msg = f"Model initialization failed: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        model_initialization_error = error_msg
        return False, error_msg
@spaces.GPU
def generate_embedding(text, focus):
    """Encode a single text into an embedding; `focus` is currently informational only."""
    global model, model_initialization_error
    if model is None:
        success, error_message = initialize_model()
        if not success:
            return "", error_message  # surface the initialization error in the UI
    try:
        with torch.amp.autocast("cuda"):
            embedding_vector = model.encode([text])[0].tolist()
        # Serialize the embedding to JSON for direct display in the UI.
        embedding_json_str = json.dumps(embedding_vector)
        return embedding_json_str, ""
    except Exception as e:
        error_msg = f"Error generating embedding: {str(e)}"
        logger.error(error_msg)
        return "", error_msg
def save_embedding(embedding_json, name):  # expects the JSON string shown in the UI
    """Persist an embedding as a .npy file. Pure file I/O, so no @spaces.GPU needed."""
    try:
        embedding = json.loads(embedding_json)  # parse the JSON string back to a list
        filepath = os.path.join(NPY_CACHE, f"{name}.npy")  # use the npy cache directory created above
        np.save(filepath, np.array(embedding))
        return f"Embedding saved to: {filepath}"
    except Exception as e:
        error_msg = f"Error saving embedding: {str(e)}"
        logger.error(error_msg)
        return error_msg
def convert_to_json(embedding_json, name):  # expects the JSON string shown in the UI
    """Write the embedding JSON string to disk. Pure file I/O, so no @spaces.GPU needed."""
    try:
        filepath = os.path.join(OUTPUTS_DIR, f"{name}.json")  # use the outputs directory created above
        with open(filepath, "w") as f:
            f.write(embedding_json)  # the UI already holds a JSON string, so write it directly
        return f"Embedding saved as JSON to: {filepath}"
    except Exception as e:
        error_msg = f"Error converting to JSON: {str(e)}"
        logger.error(error_msg)
        return error_msg
@spaces.GPU
def process_files(files, focus):
    global model, model_initialization_error
    if model is None:
        success, error_message = initialize_model()
        if not success:
            return "", error_message  # surface the initialization error in the UI
    try:
        all_embeddings = []
        file_statuses = []  # per-file status messages
        for file in files:
            try:
                with open(file.name, "rb") as f:
                    raw = f.read()
                # Decode bytes to text before encoding; charset_normalizer guesses the encoding,
                # with a UTF-8 fallback if no match is found.
                match = from_bytes(raw).best()
                text = str(match) if match is not None else raw.decode("utf-8", errors="replace")
                with torch.amp.autocast("cuda"):
                    embedding = model.encode([text])[0].tolist()
                all_embeddings.append(embedding)
                file_statuses.append(f"File '{file.name}' processed successfully.")
            except Exception as file_e:
                error_msg = f"Error processing file '{file.name}': {str(file_e)}"
                logger.error(error_msg)
                file_statuses.append(error_msg)
        status_message = "\n".join(file_statuses)
        # JSON string for UI display; may be very long for many or large files.
        all_embeddings_json = json.dumps(all_embeddings)
        return all_embeddings_json, status_message
    except Exception as e:
        error_msg = f"Error in process_files: {str(e)}"
        logger.error(error_msg)
        return "", error_msg
def create_gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("## Text Embedding Generator")
        # Hidden box that carries any model-initialization error.
        initialization_status_box = gr.Textbox(label="Initialization Status", value=model_initialization_error, visible=False)
        with gr.Row():
            text_input = gr.Textbox(label="Enter Text")
            focus_input = gr.Textbox(label="Main Focus of Embedding (e.g., company structure, staff positions, etc.)")
        with gr.Row():
            file_input = gr.File(label="Upload Files", file_count="multiple")
        generate_button = gr.Button("Generate Embedding")
        embedding_output = gr.Textbox(label="Embedding Vector (JSON)", lines=5)
        status_box = gr.Textbox(label="Status/Messages")
        with gr.Accordion("Save and Download Options", open=False):
            save_name_input = gr.Textbox(label="Save Embedding As (Name without extension)")
            with gr.Row():
                save_button = gr.Button("Save as .npy")
                convert_button = gr.Button("Save as .json")
            with gr.Row():
                save_status = gr.Textbox(label="Save Status")
                convert_status = gr.Textbox(label="Convert Status")
            download_button = gr.Button("Download JSON")
            download_output = gr.File(label="Download JSON File")
        process_button = gr.Button("Process Files")
        process_output = gr.Textbox(label="Processed Files (Embeddings JSON - limited display)", lines=3)
        process_status = gr.Textbox(label="File Processing Status")

        # Surface any startup initialization error when the app loads.
        demo.load(
            lambda: ("", model_initialization_error),
            outputs=[status_box, initialization_status_box],
        )
        generate_button.click(
            generate_embedding,
            inputs=[text_input, focus_input],
            outputs=[embedding_output, status_box],
        )
        save_button.click(
            save_embedding,
            inputs=[embedding_output, save_name_input],  # input is the JSON string shown in the UI
            outputs=[save_status],
        )
        convert_button.click(
            convert_to_json,
            inputs=[embedding_output, save_name_input],
            outputs=[convert_status],
        )
        download_button.click(
            # Return the saved JSON path only if it exists; otherwise there is nothing to download.
            lambda name: (
                os.path.join(OUTPUTS_DIR, f"{name}.json")
                if name and os.path.exists(os.path.join(OUTPUTS_DIR, f"{name}.json"))
                else None
            ),
            inputs=[save_name_input],
            outputs=[download_output],
        )
        process_button.click(
            process_files,
            inputs=[file_input, focus_input],
            outputs=[process_output, process_status],
        )
    return demo
if __name__ == "__main__":
    # Initialize the model at startup and report any error to the console.
    initialization_success, initialization_error_message = initialize_model()
    if not initialization_success:
        print(f"App startup failed due to model initialization error:\n{initialization_error_message}")
    demo = create_gradio_interface()
    demo.launch(server_name="0.0.0.0", allowed_paths=["/data"])
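
# Hedged usage sketch: with the app running, the handlers can also be called
# programmatically via gradio_client. The api_name below assumes Gradio's
# default naming (derived from the handler function name); adjust it if the
# actual route differs.
#
#   from gradio_client import Client
#   client = Client("http://localhost:7860")
#   vec_json, status = client.predict("hello world", "general", api_name="/generate_embedding")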