import os
import logging
import traceback
from typing import Optional, List
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
import gc
import zipfile
import tempfile
import webbrowser

# Set Persistent Storage Path
PERSISTENT_PATH = os.getenv("PERSISTENT_PATH", "/data")
os.makedirs(PERSISTENT_PATH, exist_ok=True, mode=0o777)

# Set Hugging Face cache directory to persistent storage.
# This must happen BEFORE gradio / transformers / sentence_transformers are
# imported: huggingface_hub reads HF_HOME when it is first imported, so setting
# it afterwards does not redirect the libraries' cache.
os.environ["HF_HOME"] = os.path.join(PERSISTENT_PATH, ".huggingface")
os.makedirs(os.environ["HF_HOME"], exist_ok=True, mode=0o777)

import gradio as gr  # noqa: E402
import spaces  # noqa: E402
import torch  # noqa: E402
from torch.amp import autocast  # noqa: E402
from transformers import AutoModel, AutoTokenizer  # noqa: E402
from sentence_transformers import SentenceTransformer  # noqa: E402
import numpy as np  # noqa: E402
import requests  # noqa: E402
from charset_normalizer import from_bytes  # noqa: E402


# Custom Exception Class
class GPUQuotaExceededError(Exception):
    """Raised when a GPU operation fails due to GPU memory or quota limits."""


# Constants
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
CHUNK_SIZE = 500  # characters per text chunk
BATCH_SIZE = 32  # chunks per model.encode() call
# Written after every chunk in chunks.txt; process_files and semantic_search
# must agree on it for chunks to line up with embedding rows.
CHUNK_SEPARATOR = "\n===CHUNK_SEPARATOR===\n"

# Define Subdirectories
TEMP_DIR = os.path.join(PERSISTENT_PATH, "temp")
os.makedirs(TEMP_DIR, exist_ok=True, mode=0o777)
OUTPUTS_DIR = os.path.join(PERSISTENT_PATH, "outputs")
os.makedirs(OUTPUTS_DIR, exist_ok=True, mode=0o777)
NPY_CACHE = os.path.join(PERSISTENT_PATH, "npy_cache")
os.makedirs(NPY_CACHE, exist_ok=True, mode=0o777)
LOG_DIR = os.getenv("LOG_DIR", os.path.join(PERSISTENT_PATH, "logs"))
os.makedirs(LOG_DIR, exist_ok=True, mode=0o777)

# Set Hugging Face token
HF_TOKEN = os.getenv("HF_TOKEN")

# Logging Setup
logging.basicConfig(
    filename=os.path.join(LOG_DIR, "app.log"),
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Model initialization: populated lazily by initialize_model().
model = None


def initialize_model():
    """
    Initialize the sentence transformer model.

    Loads EMBEDDING_MODEL_NAME into the module-level ``model`` on first use,
    caching downloads under ``PERSISTENT_PATH/models``. Subsequent calls
    reuse the already-loaded model.

    Returns:
        bool: Whether the model was successfully initialized (True also when
        it was already loaded).
    """
    global model
    try:
        if model is None:
            model_cache = os.path.join(PERSISTENT_PATH, "models")
            os.makedirs(model_cache, exist_ok=True, mode=0o777)
            # Use the HF_TOKEN to load the model.
            # NOTE(review): newer sentence-transformers releases deprecate
            # ``use_auth_token`` in favour of ``token`` — confirm the installed
            # version before switching.
            model = SentenceTransformer(
                EMBEDDING_MODEL_NAME,
                cache_folder=model_cache,
                use_auth_token=HF_TOKEN,
            )
            logger.info(f"Initialized model: {EMBEDDING_MODEL_NAME}")
        return True
    except requests.exceptions.RequestException as e:
        logger.error(f"Connection error during model download: {str(e)}\n{traceback.format_exc()}")
        return False
    except Exception as e:
        logger.error(f"Model initialization failed: {str(e)}\n{traceback.format_exc()}")
        return False


@spaces.GPU
def handle_gpu_operation(func):
    """
    Run ``func()`` under CUDA float16 autocast and log how long it took.

    Args:
        func: Zero-argument callable performing the GPU work.

    Returns:
        Whatever ``func()`` returns.

    Raises:
        GPUQuotaExceededError: On a "CUDA out of memory" RuntimeError (after
            emptying the CUDA cache) or an error whose message mentions
            "quota exceeded".
        RuntimeError / Exception: Any other error is logged and re-raised.
    """
    try:
        start_time = datetime.now()
        # Updated autocast usage as per deprecation notice
        with autocast(device_type='cuda', dtype=torch.float16):
            result = func()
        duration = (datetime.now() - start_time).total_seconds()
        logger.info(f"GPU operation completed in {duration:.2f}s")
        return result
    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            torch.cuda.empty_cache()
            logger.error(f"GPU memory error: {str(e)}")
            raise GPUQuotaExceededError("GPU memory limit exceeded. Please try with a smaller batch.") from e
        logger.error(f"GPU runtime error: {str(e)}")
        raise
    except Exception as e:
        if "quota exceeded" in str(e).lower():
            logger.error(f"GPU quota exceeded: {str(e)}")
            raise GPUQuotaExceededError("GPU quota exceeded. Please wait a few minutes before trying again.") from e
        logger.error(f"Unexpected GPU error: {str(e)}")
        raise


def get_model():
    """
    Return the loaded model, initializing it only when CUDA is available.

    Returns:
        The SentenceTransformer instance, or None when it is not loaded and
        no CUDA device is available (initialization is deferred).
    """
    global model
    if model is None:
        if torch.cuda.is_available():
            initialize_model()
        else:
            logger.warning("Attempted to initialize model outside GPU context, deferring.")
            return None
    return model


@spaces.GPU
def process_files(files):
    """
    Chunk uploaded .txt files, embed every chunk, and save the results.

    Writes ``embeddings.npy`` (one row per chunk) and ``chunks.txt`` (each
    chunk followed by CHUNK_SEPARATOR) into OUTPUTS_DIR, overwriting any
    previous run.

    Args:
        files: Uploaded file objects; each must expose a ``.name`` path.

    Returns:
        tuple[str, str, str]: (status message, "", ""), matching the three
        output components wired up in the UI.
    """
    if not files:
        return "Please upload one or more .txt files.", "", ""

    try:
        if not initialize_model():
            return "Failed to initialize the model. Please try again.", "", ""

        valid_files = [f for f in files if f.name.lower().endswith('.txt')]
        if not valid_files:
            return "No .txt files found. Please upload valid .txt files.", "", ""

        all_chunks = []
        processed_files = 0
        for file in valid_files:
            try:
                with open(file.name, 'rb') as f:
                    content = f.read()
                # best() returns None when no encoding could be identified;
                # skip such files explicitly instead of failing on `.encoding`.
                best_match = from_bytes(content).best()
                if best_match is None:
                    logger.warning(f"Could not detect encoding for file {file.name}; skipping")
                    continue
                decoded_content = content.decode(best_match.encoding, errors='ignore')
                # Split content into fixed-size character chunks
                chunks = [
                    decoded_content[i:i + CHUNK_SIZE]
                    for i in range(0, len(decoded_content), CHUNK_SIZE)
                ]
                all_chunks.extend(chunks)
                processed_files += 1
                logger.info(f"Processed file: {file.name}")
            except Exception as e:
                logger.exception(f"Error processing file {file.name}: {str(e)}")

        if not all_chunks:
            return "No valid content found in the uploaded files.", "", ""

        # Explicit None check: SentenceTransformer is a torch Sequential, so its
        # truthiness depends on len(), not on whether it is loaded.
        if model is None:
            return "Model not initialized. Please check model initialization.", "", ""

        # Generate embeddings in batches
        all_embeddings = []
        for i in range(0, len(all_chunks), BATCH_SIZE):
            batch = all_chunks[i:i + BATCH_SIZE]
            # Bind the batch explicitly so the callable does not depend on the loop variable.
            embeddings = handle_gpu_operation(lambda b=batch: model.encode(b))
            all_embeddings.extend(embeddings)

        # Save results to OUTPUTS_DIR
        embeddings_path = os.path.join(OUTPUTS_DIR, "embeddings.npy")
        np.save(embeddings_path, np.array(all_embeddings))

        chunks_path = os.path.join(OUTPUTS_DIR, "chunks.txt")
        with open(chunks_path, "w", encoding="utf-8") as f:
            f.writelines(chunk + CHUNK_SEPARATOR for chunk in all_chunks)

        return (
            f"Successfully processed {processed_files} files. Generated {len(all_embeddings)} embeddings from {len(all_chunks)} chunks.",
            "",
            "",
        )
    except Exception as e:
        logger.exception(f"Processing failed: {str(e)}")
        return f"Error processing files: {str(e)}", "", ""


@spaces.GPU
def semantic_search(query, top_k=5):
    """
    Return the ``top_k`` stored chunks most cosine-similar to ``query``.

    Reads ``embeddings.npy`` and ``chunks.txt`` written by process_files from
    OUTPUTS_DIR; row i of the embeddings corresponds to chunk i.

    Args:
        query: Search text.
        top_k: Number of results to return; coerced to int because UI
            sliders may deliver floats.

    Returns:
        str: Formatted results (best first), or a human-readable error message.
    """
    if model is None:
        return "Model not initialized. Please process files first."

    try:
        top_k = int(top_k)

        # Load saved embeddings and chunks from OUTPUTS_DIR
        embeddings_file = os.path.join(OUTPUTS_DIR, "embeddings.npy")
        chunks_file = os.path.join(OUTPUTS_DIR, "chunks.txt")
        if not (os.path.exists(embeddings_file) and os.path.exists(chunks_file)):
            return "No stored embeddings found. Please process files first."

        stored_embeddings = np.load(embeddings_file)
        with open(chunks_file, "r", encoding="utf-8") as f:
            chunks = f.read().split(CHUNK_SEPARATOR)
        # Every chunk is followed by the separator, so splitting leaves exactly
        # one trailing empty element. Drop only that one: discarding every
        # blank chunk would shift indices and misalign chunks with embeddings.
        if chunks and chunks[-1] == "":
            chunks.pop()

        if len(chunks) != len(stored_embeddings):
            logger.error(
                f"Chunk/embedding count mismatch: {len(chunks)} chunks vs {len(stored_embeddings)} embeddings"
            )
            return "Stored chunks and embeddings are out of sync. Please regenerate embeddings."
        if not chunks:
            return "No stored content to search. Please process files first."

        # Get query embedding
        query_embedding = model.encode([query])[0]

        # Cosine similarity; zero-norm vectors score 0 instead of producing NaN.
        dot_products = np.dot(stored_embeddings, query_embedding)
        norms = np.linalg.norm(stored_embeddings, axis=1) * np.linalg.norm(query_embedding)
        similarities = np.divide(
            dot_products, norms, out=np.zeros_like(dot_products), where=norms > 0
        )

        # Highest similarity first
        top_indices = np.argsort(similarities)[::-1][:top_k]
        results = [
            f"\nSimilarity: {similarities[idx]:.3f}\nContent: {chunks[idx]}\n-------------------\n"
            for idx in top_indices
        ]
        return "\n".join(results)
    except Exception as e:
        logger.exception(f"Search error: {str(e)}")
        return f"Search error occurred: {str(e)}"


def search_and_format(query, num_results):
    """Validate the query text, then delegate to semantic_search."""
    if query is None or not query.strip():
        return "Please enter a search query"
    return semantic_search(query, top_k=num_results)


def browse_outputs():
    """
    Try to open OUTPUTS_DIR in a web browser on the machine running the app.

    Returns:
        str: Status message for the UI.
    """
    try:
        # resolve() makes the path absolute, which as_uri() requires; as_uri()
        # also percent-encodes characters such as spaces.
        opened = webbrowser.open(Path(OUTPUTS_DIR).resolve().as_uri())
    except Exception as e:
        logger.error(f"Error opening file browser: {str(e)}")
        return "Error opening file browser."
    # webbrowser.open returns False when no browser could be launched
    # (e.g. on a headless server).
    if not opened:
        logger.warning("No browser available to open the outputs directory.")
        return f"Could not open a browser. Outputs are stored in: {OUTPUTS_DIR}"
    return "Opened outputs directory."
def download_results():
    """
    Package the generated artifacts into a single zip for download.

    Returns:
        str | None: Path of ``results.zip`` inside OUTPUTS_DIR, or None if a
        required artifact is absent or the archive could not be written
        (both cases are logged).
    """
    artifact_names = ["embeddings.npy", "chunks.txt"]
    artifact_paths = {name: os.path.join(OUTPUTS_DIR, name) for name in artifact_names}

    absent = [name for name, path in artifact_paths.items() if not os.path.exists(path)]
    if absent:
        logger.error(f"Missing files: {absent}")
        return None

    archive_path = os.path.join(OUTPUTS_DIR, "results.zip")
    try:
        with zipfile.ZipFile(archive_path, 'w') as archive:
            for name, path in artifact_paths.items():
                # Store each file under its bare name (no directory prefix).
                archive.write(path, name)
    except Exception as exc:
        logger.error(f"Error creating download archive: {str(exc)}")
        return None
    return archive_path


def _build_search_tab():
    """Lay out the "Search" tab: query box, result count, results, and download."""
    with gr.Tab("Search"):
        query_box = gr.Textbox(
            label="Enter your search query",
            placeholder="Enter text to search through your documents...",
        )
        result_count = gr.Slider(
            minimum=1,
            maximum=20,
            value=5,
            step=1,
            label="Number of results to return",
        )
        run_search = gr.Button(" Search")
        search_results = gr.Textbox(
            label="Search Results",
            lines=10,
            show_copy_button=True,
        )
        run_search.click(
            fn=search_and_format,
            inputs=[query_box, result_count],
            outputs=search_results,
        )

        fetch_zip = gr.Button(" Download Results")
        zip_file = gr.File(label="Download Results")
        fetch_zip.click(fn=download_results, outputs=[zip_file])


def _build_outputs_tab():
    """Lay out the "Outputs" tab with a button that tries to open OUTPUTS_DIR."""
    with gr.Tab("Outputs"):
        open_outputs = gr.Button(" Browse Outputs")
        browse_status = gr.Textbox(label="Browse Status")
        open_outputs.click(fn=browse_outputs, outputs=[browse_status])


def create_gradio_interface():
    """
    Assemble the Gradio Blocks UI for uploading, embedding, and searching text.

    Returns:
        gr.Blocks: The (not yet launched) application.
    """
    with gr.Blocks() as demo:
        gr.Markdown("## Text Chunk Embeddings Generator")
        # Hidden sink for the second and third values returned by process_files.
        error_box = gr.Textbox(visible=False, label="Status/Error Messages")

        with gr.Row():
            uploads = gr.File(
                label="Upload Text Files",
                file_count="multiple",
                file_types=[".txt"],
            )
        embed_button = gr.Button("Generate Embeddings")
        status_box = gr.Textbox(label="Status")
        embed_button.click(
            fn=process_files,
            inputs=[uploads],
            outputs=[status_box, error_box, error_box],
        )

        _build_search_tab()
        _build_outputs_tab()

    return demo


if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.launch(server_name="0.0.0.0")