# Emails2go / app.py
import os
import gradio as gr
import logging
import traceback
import spaces
from datetime import datetime
import torch
from torch.amp import autocast
from sentence_transformers import SentenceTransformer
import numpy as np
import requests
from charset_normalizer import from_bytes
import zipfile
import webbrowser
# Custom Exception Class
class GPUQuotaExceededError(Exception):
    """Raised when the GPU memory or usage quota is exhausted."""
# Constants
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
CHUNK_SIZE = 500  # characters per text chunk
BATCH_SIZE = 32   # chunks encoded per batch
# Set Persistent Storage Path
PERSISTENT_PATH = os.getenv("PERSISTENT_PATH", "/data")
os.makedirs(PERSISTENT_PATH, exist_ok=True, mode=0o777)
# Define Subdirectories
TEMP_DIR = os.path.join(PERSISTENT_PATH, "temp")
os.makedirs(TEMP_DIR, exist_ok=True, mode=0o777)
OUTPUTS_DIR = os.path.join(PERSISTENT_PATH, "outputs")
os.makedirs(OUTPUTS_DIR, exist_ok=True, mode=0o777)
NPY_CACHE = os.path.join(PERSISTENT_PATH, "npy_cache")
os.makedirs(NPY_CACHE, exist_ok=True, mode=0o777)
LOG_DIR = os.getenv("LOG_DIR", os.path.join(PERSISTENT_PATH, "logs"))
os.makedirs(LOG_DIR, exist_ok=True, mode=0o777)
# Set Hugging Face cache directory to persistent storage
os.environ["HF_HOME"] = os.path.join(PERSISTENT_PATH, ".huggingface")
os.makedirs(os.environ["HF_HOME"], exist_ok=True, mode=0o777)
# Set Hugging Face token
HF_TOKEN = os.getenv("HF_TOKEN")
# Logging Setup
logging.basicConfig(
filename=os.path.join(LOG_DIR, "app.log"),
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# Shared model instance, initialized lazily
model = None
def initialize_model():
"""
Initialize the sentence transformer model.
Returns:
bool: Whether the model was successfully initialized.
"""
global model
try:
if model is None:
model_cache = os.path.join(PERSISTENT_PATH, "models")
os.makedirs(model_cache, exist_ok=True, mode=0o777)
            # Load the model with the HF token; "token" supersedes the
            # deprecated "use_auth_token" argument in recent sentence-transformers
            model = SentenceTransformer(EMBEDDING_MODEL_NAME, cache_folder=model_cache, token=HF_TOKEN)
logger.info(f"Initialized model: {EMBEDDING_MODEL_NAME}")
return True
except requests.exceptions.RequestException as e:
logger.error(f"Connection error during model download: {str(e)}\n{traceback.format_exc()}")
return False
except Exception as e:
logger.error(f"Model initialization failed: {str(e)}\n{traceback.format_exc()}")
return False
@spaces.GPU
def handle_gpu_operation(func):
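    """
    Run `func` under CUDA float16 autocast and log the elapsed time.

    CUDA out-of-memory and quota errors are re-raised as
    GPUQuotaExceededError so the UI can show a friendly message.
    """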
try:
start_time = datetime.now()
# Updated autocast usage as per deprecation notice
with autocast(device_type='cuda', dtype=torch.float16):
result = func()
end_time = datetime.now()
duration = (end_time - start_time).total_seconds()
logger.info(f"GPU operation completed in {duration:.2f}s")
return result
except RuntimeError as e:
if "CUDA out of memory" in str(e):
torch.cuda.empty_cache()
logger.error(f"GPU memory error: {str(e)}")
raise GPUQuotaExceededError("GPU memory limit exceeded. Please try with a smaller batch.")
else:
logger.error(f"GPU runtime error: {str(e)}")
raise
except Exception as e:
if "quota exceeded" in str(e).lower():
logger.error(f"GPU quota exceeded: {str(e)}")
raise GPUQuotaExceededError("GPU quota exceeded. Please wait a few minutes before trying again.")
else:
logger.error(f"Unexpected GPU error: {str(e)}")
raise
def get_model():
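    """
    Return the shared SentenceTransformer instance, initializing it lazily.

    Returns None when CUDA is not yet available, deferring initialization
    to a GPU context.
    """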
global model
if model is None:
if torch.cuda.is_available():
initialize_model()
else:
logger.warning("Attempted to initialize model outside GPU context, deferring.")
return None
return model
@spaces.GPU
def process_files(files):
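    """
    Chunk the uploaded .txt files, embed the chunks in batches, and save
    embeddings.npy and chunks.txt to OUTPUTS_DIR.

    Returns a (status, error, error) tuple matching the Gradio outputs.
    """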
if not files:
return "Please upload one or more.txt files.", "", ""
try:
if not initialize_model():
return "Failed to initialize the model. Please try again.", "", ""
valid_files = [f for f in files if f.name.lower().endswith('.txt')]
if not valid_files:
return "No.txt files found. Please upload valid.txt files.", "", ""
all_chunks = []
processed_files = 0
for file in valid_files:
try:
with open(file.name, 'rb') as f:
content = f.read()
                # from_bytes().best() can return None (e.g. empty file); fall back to UTF-8
                best_guess = from_bytes(content).best()
                detected_encoding = best_guess.encoding if best_guess else "utf-8"
                decoded_content = content.decode(detected_encoding, errors='ignore')
# Split content into chunks
chunks = [decoded_content[i:i+CHUNK_SIZE] for i in range(0, len(decoded_content), CHUNK_SIZE)]
all_chunks.extend(chunks)
processed_files += 1
logger.info(f"Processed file: {file.name}")
except Exception as e:
logger.error(f"Error processing file {file.name}: {str(e)}")
if not all_chunks:
return "No valid content found in the uploaded files.", "", ""
        # Generate embeddings in batches
        if model is None:
            return "Model not initialized. Please check model initialization.", "", ""
        all_embeddings = []
        for i in range(0, len(all_chunks), BATCH_SIZE):
            batch = all_chunks[i:i+BATCH_SIZE]
            embeddings = handle_gpu_operation(lambda: model.encode(batch))
            all_embeddings.extend(embeddings)
# Save results to OUTPUTS_DIR
embeddings_path = os.path.join(OUTPUTS_DIR, "embeddings.npy")
np.save(embeddings_path, np.array(all_embeddings))
chunks_path = os.path.join(OUTPUTS_DIR, "chunks.txt")
with open(chunks_path, "w", encoding="utf-8") as f:
for chunk in all_chunks:
f.write(chunk + "\n===CHUNK_SEPARATOR===\n")
return (
f"Successfully processed {processed_files} files. Generated {len(all_embeddings)} embeddings from {len(all_chunks)} chunks.",
"",
""
)
except Exception as e:
logger.error(f"Processing failed: {str(e)}")
return f"Error processing files: {str(e)}", "", ""
@spaces.GPU
def semantic_search(query, top_k=5):
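    """
    Embed `query` and return the `top_k` stored chunks ranked by cosine
    similarity against the saved embeddings.
    """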
global model
if model is None:
return "Model not initialized. Please process files first."
try:
        # Load saved embeddings and chunks from OUTPUTS_DIR
        embeddings_file = os.path.join(OUTPUTS_DIR, "embeddings.npy")
        chunks_file = os.path.join(OUTPUTS_DIR, "chunks.txt")
        if not (os.path.exists(embeddings_file) and os.path.exists(chunks_file)):
            return "No saved embeddings found. Please process files first."
        stored_embeddings = np.load(embeddings_file)
with open(chunks_file, "r", encoding="utf-8") as f:
chunks = f.read().split("\n===CHUNK_SEPARATOR===\n")
chunks = [c for c in chunks if c.strip()]
# Get query embedding
query_embedding = model.encode([query])[0]
        # Cosine similarity between the query and every stored chunk embedding
similarities = np.dot(stored_embeddings, query_embedding) / (
np.linalg.norm(stored_embeddings, axis=1) * np.linalg.norm(query_embedding)
)
# Get top results
top_indices = np.argsort(similarities)[-top_k:][::-1]
results = []
for idx in top_indices:
            results.append(
                f"Similarity: {similarities[idx]:.3f}\n"
                f"Content: {chunks[idx]}\n"
                "-------------------"
            )
return "\n".join(results)
except Exception as e:
logger.error(f"Search error: {str(e)}")
return f"Search error occurred: {str(e)}"
def search_and_format(query, num_results):
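    """Validate the query, then delegate to semantic_search."""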
if not query.strip():
return "Please enter a search query"
return semantic_search(query, top_k=num_results)
def browse_outputs():
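    """
    Best-effort: open OUTPUTS_DIR via a file:// URL. Only useful when the
    app runs on the local machine, not in a hosted Space.
    """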
try:
# Open the outputs directory in a web browser (may work on some systems)
webbrowser.open(f"file://{OUTPUTS_DIR}")
return "Opened outputs directory."
except Exception as e:
logger.error(f"Error opening file browser: {str(e)}")
return "Error opening file browser."
def download_results():
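    """
    Zip embeddings.npy and chunks.txt into results.zip and return its path,
    or None if a required file is missing or archiving fails.
    """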
required_files = ["embeddings.npy", "chunks.txt"]
missing = [f for f in required_files if not os.path.exists(os.path.join(OUTPUTS_DIR, f))]
if missing:
logger.error(f"Missing files: {missing}")
return None
try:
zip_path = os.path.join(OUTPUTS_DIR, "results.zip")
with zipfile.ZipFile(zip_path, 'w') as zipf:
for file in required_files:
file_path = os.path.join(OUTPUTS_DIR, file)
zipf.write(file_path, file)
return zip_path
except Exception as e:
logger.error(f"Error creating download archive: {str(e)}")
return None
def create_gradio_interface():
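    """
    Build the Gradio Blocks UI: file upload/embedding, semantic search,
    and results download/browse.
    """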
with gr.Blocks() as demo:
gr.Markdown("## Text Chunk Embeddings Generator")
error_box = gr.Textbox(visible=False, label="Status/Error Messages")
with gr.Row():
file_input = gr.File(
label="Upload Text Files",
file_count="multiple",
file_types=[".txt"]
)
process_button = gr.Button("Generate Embeddings")
output_text = gr.Textbox(label="Status")
process_button.click(
fn=process_files,
inputs=[file_input],
outputs=[output_text, error_box, error_box]
)
with gr.Tab("Search"):
query_input = gr.Textbox(
label="Enter your search query",
placeholder="Enter text to search through your documents..."
)
top_k_slider = gr.Slider(
minimum=1,
maximum=20,
value=5,
step=1,
label="Number of results to return"
)
            search_button = gr.Button("Search")
results_output = gr.Textbox(
label="Search Results",
lines=10,
show_copy_button=True
)
search_button.click(
fn=search_and_format,
inputs=[query_input, top_k_slider],
outputs=results_output
)
            download_button = gr.Button("Download Results")
download_button.click(
fn=download_results,
outputs=[gr.File(label="Download Results")]
)
with gr.Tab("Outputs"):
            browse_button = gr.Button("Browse Outputs")
browse_button.click(
fn=browse_outputs,
outputs=[gr.Textbox(label="Browse Status")]
)
return demo
if __name__ == "__main__":
demo = create_gradio_interface()
demo.launch(server_name="0.0.0.0")