import os
import gradio as gr
import logging
import traceback
import spaces
from typing import Optional, List
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
import gc
import torch
from torch.amp import autocast
from transformers import AutoModel, AutoTokenizer
from sentence_transformers import SentenceTransformer
import numpy as np
import requests
from charset_normalizer import from_bytes
import zipfile
import tempfile
import webbrowser
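
# Suggested requirements for this app, inferred from the imports above (package
# names only, no version pins; this list is not taken from the original repo):
#   gradio, spaces, torch, transformers, sentence-transformers,
#   numpy, requests, charset-normalizer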

# Custom Exception Class
class GPUQuotaExceededError(Exception):
    """Raised when the GPU quota or memory budget is exhausted."""

# Constants
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
CHUNK_SIZE = 500
BATCH_SIZE = 32

# Set Persistent Storage Path
PERSISTENT_PATH = os.getenv("PERSISTENT_PATH", "/data")
os.makedirs(PERSISTENT_PATH, exist_ok=True, mode=0o777)

# Define Subdirectories
TEMP_DIR = os.path.join(PERSISTENT_PATH, "temp")
os.makedirs(TEMP_DIR, exist_ok=True, mode=0o777)
OUTPUTS_DIR = os.path.join(PERSISTENT_PATH, "outputs")
os.makedirs(OUTPUTS_DIR, exist_ok=True, mode=0o777)
NPY_CACHE = os.path.join(PERSISTENT_PATH, "npy_cache")
os.makedirs(NPY_CACHE, exist_ok=True, mode=0o777)
LOG_DIR = os.getenv("LOG_DIR", os.path.join(PERSISTENT_PATH, "logs"))
os.makedirs(LOG_DIR, exist_ok=True, mode=0o777)
# Set Hugging Face cache directory to persistent storage.
# (Note: transformers/sentence_transformers are already imported above, so this
# mainly affects later imports and subprocesses; the model itself is cached
# explicitly via cache_folder in initialize_model below.)
os.environ["HF_HOME"] = os.path.join(PERSISTENT_PATH, ".huggingface")
os.makedirs(os.environ["HF_HOME"], exist_ok=True, mode=0o777)

# Set Hugging Face token
HF_TOKEN = os.getenv("HF_TOKEN")

# Logging Setup
logging.basicConfig(
    filename=os.path.join(LOG_DIR, "app.log"),
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Model initialization
model = None


def initialize_model():
    """
    Initialize the sentence transformer model.

    Returns:
        bool: Whether the model was successfully initialized.
    """
    global model
    try:
        if model is None:
            model_cache = os.path.join(PERSISTENT_PATH, "models")
            os.makedirs(model_cache, exist_ok=True, mode=0o777)
            # Use the HF_TOKEN to load the model
            model = SentenceTransformer(EMBEDDING_MODEL_NAME, cache_folder=model_cache, token=HF_TOKEN)
            logger.info(f"Initialized model: {EMBEDDING_MODEL_NAME}")
        return True
    except requests.exceptions.RequestException as e:
        logger.error(f"Connection error during model download: {str(e)}\n{traceback.format_exc()}")
        return False
    except Exception as e:
        logger.error(f"Model initialization failed: {str(e)}\n{traceback.format_exc()}")
        return False


def handle_gpu_operation(func):
    """Run a GPU-bound callable under autocast, translating CUDA/quota errors."""
    try:
        start_time = datetime.now()
        # Updated autocast usage as per deprecation notice
        with autocast(device_type='cuda', dtype=torch.float16):
            result = func()
        end_time = datetime.now()
        duration = (end_time - start_time).total_seconds()
        logger.info(f"GPU operation completed in {duration:.2f}s")
        return result
    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            torch.cuda.empty_cache()
            logger.error(f"GPU memory error: {str(e)}")
            raise GPUQuotaExceededError("GPU memory limit exceeded. Please try with a smaller batch.")
        else:
            logger.error(f"GPU runtime error: {str(e)}")
            raise
    except Exception as e:
        if "quota exceeded" in str(e).lower():
            logger.error(f"GPU quota exceeded: {str(e)}")
            raise GPUQuotaExceededError("GPU quota exceeded. Please wait a few minutes before trying again.")
        else:
            logger.error(f"Unexpected GPU error: {str(e)}")
            raise


def get_model():
    global model
    if model is None:
        if torch.cuda.is_available():
            initialize_model()
        else:
            logger.warning("Attempted to initialize model outside GPU context, deferring.")
            return None
    return model
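

# Note: `spaces` is imported above but never used directly. On a ZeroGPU Space,
# the GPU-bound entry points are typically wrapped with the `spaces.GPU`
# decorator; a minimal sketch (hypothetical helper, not wired into the app
# below) might look like:
#
# @spaces.GPU
# def gpu_encode(batch):
#     return get_model().encode(batch)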


def process_files(files):
    if not files:
        return "Please upload one or more .txt files.", "", ""
    try:
        if not initialize_model():
            return "Failed to initialize the model. Please try again.", "", ""
        valid_files = [f for f in files if f.name.lower().endswith('.txt')]
        if not valid_files:
            return "No .txt files found. Please upload valid .txt files.", "", ""
        all_chunks = []
        processed_files = 0
        for file in valid_files:
            try:
                with open(file.name, 'rb') as f:
                    content = f.read()
                # Detect the encoding, falling back to UTF-8 if detection fails
                best_match = from_bytes(content).best()
                detected_encoding = best_match.encoding if best_match else "utf-8"
                decoded_content = content.decode(detected_encoding, errors='ignore')
                # Split content into fixed-size character chunks
                chunks = [decoded_content[i:i + CHUNK_SIZE] for i in range(0, len(decoded_content), CHUNK_SIZE)]
                all_chunks.extend(chunks)
                processed_files += 1
                logger.info(f"Processed file: {file.name}")
            except Exception as e:
                logger.error(f"Error processing file {file.name}: {str(e)}")
        if not all_chunks:
            return "No valid content found in the uploaded files.", "", ""
        # Generate embeddings in batches
        all_embeddings = []
        for i in range(0, len(all_chunks), BATCH_SIZE):
            batch = all_chunks[i:i + BATCH_SIZE]
            if model:
                embeddings = handle_gpu_operation(lambda: model.encode(batch))
                all_embeddings.extend(embeddings)
            else:
                return "Model not initialized. Please check model initialization.", "", ""
        # Save results to OUTPUTS_DIR
        embeddings_path = os.path.join(OUTPUTS_DIR, "embeddings.npy")
        np.save(embeddings_path, np.array(all_embeddings))
        chunks_path = os.path.join(OUTPUTS_DIR, "chunks.txt")
        with open(chunks_path, "w", encoding="utf-8") as f:
            for chunk in all_chunks:
                f.write(chunk + "\n===CHUNK_SEPARATOR===\n")
        return (
            f"Successfully processed {processed_files} files. Generated {len(all_embeddings)} embeddings from {len(all_chunks)} chunks.",
            "",
            ""
        )
    except Exception as e:
        logger.error(f"Processing failed: {str(e)}")
        return f"Error processing files: {str(e)}", "", ""


def semantic_search(query, top_k=5):
    global model
    if model is None:
        return "Model not initialized. Please process files first."
    try:
        # Load saved embeddings and chunks from OUTPUTS_DIR
        embeddings_file = os.path.join(OUTPUTS_DIR, "embeddings.npy")
        chunks_file = os.path.join(OUTPUTS_DIR, "chunks.txt")
        stored_embeddings = np.load(embeddings_file)
        with open(chunks_file, "r", encoding="utf-8") as f:
            chunks = f.read().split("\n===CHUNK_SEPARATOR===\n")
        chunks = [c for c in chunks if c.strip()]
        # Get query embedding
        query_embedding = model.encode([query])[0]
        # Calculate cosine similarities
        similarities = np.dot(stored_embeddings, query_embedding) / (
            np.linalg.norm(stored_embeddings, axis=1) * np.linalg.norm(query_embedding)
        )
        # Get top results
        top_indices = np.argsort(similarities)[-top_k:][::-1]
        results = []
        for idx in top_indices:
            results.append(f"""
Similarity: {similarities[idx]:.3f}
Content: {chunks[idx]}
-------------------
""")
        return "\n".join(results)
    except Exception as e:
        logger.error(f"Search error: {str(e)}")
        return f"Search error occurred: {str(e)}"


def search_and_format(query, num_results):
    if not query.strip():
        return "Please enter a search query"
    # The slider value may arrive as a float, so cast before slicing by it
    return semantic_search(query, top_k=int(num_results))


def browse_outputs():
    try:
        # Open the outputs directory in a web browser (may work on some systems)
        webbrowser.open(f"file://{OUTPUTS_DIR}")
        return "Opened outputs directory."
    except Exception as e:
        logger.error(f"Error opening file browser: {str(e)}")
        return "Error opening file browser."


def download_results():
    required_files = ["embeddings.npy", "chunks.txt"]
    missing = [f for f in required_files if not os.path.exists(os.path.join(OUTPUTS_DIR, f))]
    if missing:
        logger.error(f"Missing files: {missing}")
        return None
    try:
        zip_path = os.path.join(OUTPUTS_DIR, "results.zip")
        with zipfile.ZipFile(zip_path, 'w') as zipf:
            for file in required_files:
                file_path = os.path.join(OUTPUTS_DIR, file)
                zipf.write(file_path, file)
        return zip_path
    except Exception as e:
        logger.error(f"Error creating download archive: {str(e)}")
        return None


def create_gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("## Text Chunk Embeddings Generator")
        error_box = gr.Textbox(visible=False, label="Status/Error Messages")
        with gr.Row():
            file_input = gr.File(
                label="Upload Text Files",
                file_count="multiple",
                file_types=[".txt"]
            )
        process_button = gr.Button("Generate Embeddings")
        output_text = gr.Textbox(label="Status")
        process_button.click(
            fn=process_files,
            inputs=[file_input],
            outputs=[output_text, error_box, error_box]
        )
        with gr.Tab("Search"):
            query_input = gr.Textbox(
                label="Enter your search query",
                placeholder="Enter text to search through your documents..."
            )
            top_k_slider = gr.Slider(
                minimum=1,
                maximum=20,
                value=5,
                step=1,
                label="Number of results to return"
            )
            search_button = gr.Button("Search")
            results_output = gr.Textbox(
                label="Search Results",
                lines=10,
                show_copy_button=True
            )
            search_button.click(
                fn=search_and_format,
                inputs=[query_input, top_k_slider],
                outputs=results_output
            )
            download_button = gr.Button("Download Results")
            download_button.click(
                fn=download_results,
                outputs=[gr.File(label="Download Results")]
            )
        with gr.Tab("Outputs"):
            browse_button = gr.Button("Browse Outputs")
            browse_button.click(
                fn=browse_outputs,
                outputs=[gr.Textbox(label="Browse Status")]
            )
    return demo


if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.launch(server_name="0.0.0.0")
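
# Note: Gradio's launch() defaults to port 7860, which is also the port a
# Hugging Face Space expects, so no explicit server_port is set above.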