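"""Text Embedding Generator: a Gradio app for Hugging Face Spaces.

Generates sentence embeddings with sentence-transformers, saves them to
persistent storage under /data as .npy or .json files, and batch-processes
uploaded text files. A minimal local-run sketch, assuming this file is saved
as app.py and gradio, spaces, torch, sentence-transformers, and
charset-normalizer are installed:

    python app.py   # serves the UI on http://0.0.0.0:7860
"""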
import json
import logging
import os
import traceback

import gradio as gr
import numpy as np
import requests
import spaces
import torch
from charset_normalizer import from_bytes
from sentence_transformers import SentenceTransformer

# Raised when the Spaces GPU quota is exhausted (currently reserved for future use)
class GPUQuotaExceededError(Exception):
    pass

# Constants
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
CHUNK_SIZE = 500  # reserved for future chunked encoding (currently unused)
BATCH_SIZE = 32   # reserved for future batched encoding (currently unused)
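
# A hypothetical sketch of how CHUNK_SIZE and BATCH_SIZE could be used once
# chunked encoding is implemented: split long documents into fixed-size
# character chunks and encode them in batches.
#
#     chunks = [text[i:i + CHUNK_SIZE] for i in range(0, len(text), CHUNK_SIZE)]
#     vectors = model.encode(chunks, batch_size=BATCH_SIZE)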

# Persistent storage root: on Hugging Face Spaces, /data survives restarts
# when persistent storage is enabled.
PERSISTENT_PATH = os.getenv("PERSISTENT_PATH", "/data")
os.makedirs(PERSISTENT_PATH, exist_ok=True, mode=0o777)

# Subdirectories under persistent storage
TEMP_DIR = os.path.join(PERSISTENT_PATH, "temp")
OUTPUTS_DIR = os.path.join(PERSISTENT_PATH, "outputs")    # .json exports
NPY_CACHE = os.path.join(PERSISTENT_PATH, "npy_cache")    # .npy exports
LOG_DIR = os.getenv("LOG_DIR", os.path.join(PERSISTENT_PATH, "logs"))
for _dir in (TEMP_DIR, OUTPUTS_DIR, NPY_CACHE, LOG_DIR):
    os.makedirs(_dir, exist_ok=True, mode=0o777)

# Point the Hugging Face cache at persistent storage so model downloads
# survive restarts. Note: huggingface_hub may read HF_HOME at import time,
# so setting it in the Space's environment variables is the more reliable
# option.
os.environ["HF_HOME"] = os.path.join(PERSISTENT_PATH, ".huggingface")
os.makedirs(os.environ["HF_HOME"], exist_ok=True, mode=0o777)

# Hugging Face token for gated or private models (set via environment variable)
HF_TOKEN = os.getenv("HF_TOKEN")

# Log to a file under persistent storage
logging.basicConfig(
    filename=os.path.join(LOG_DIR, "app.log"),
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Lazily initialized model and the most recent initialization error (if any)
model = None
model_initialization_error = ""

def initialize_model():
    """
    Initialize the sentence-transformer model with an explicit cache path and
    error handling.

    Returns:
        tuple[bool, str]: (success, error_message); error_message is empty on
        success.
    """
    global model, model_initialization_error
    try:
        if model is None:
            model_cache = os.path.join(PERSISTENT_PATH, "models")
            os.makedirs(model_cache, exist_ok=True, mode=0o777)
            # `token` supersedes the deprecated `use_auth_token` argument in
            # recent sentence-transformers releases.
            model = SentenceTransformer(
                EMBEDDING_MODEL_NAME, cache_folder=model_cache, token=HF_TOKEN
            )
            logger.info(f"Initialized model: {EMBEDDING_MODEL_NAME}")
            model_initialization_error = ""
        return True, ""
    except requests.exceptions.RequestException as e:
        # Network failures during model download
        error_msg = f"Connection error during model download: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        model_initialization_error = error_msg
        return False, error_msg
    except Exception as e:
        error_msg = f"Model initialization failed: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        model_initialization_error = error_msg
        return False, error_msg


@spaces.GPU
def generate_embedding(text, focus):
    """Encode a single text and return (embedding_json, error_message)."""
    global model, model_initialization_error
    if model is None:
        success, error_message = initialize_model()
        if not success:
            return "", error_message  # Surface the initialization error in the UI

    try:
        # `focus` is collected from the UI but currently unused.
        if torch.cuda.is_available():
            with torch.amp.autocast("cuda"):
                embedding_vector = model.encode([text])[0].tolist()
        else:
            embedding_vector = model.encode([text])[0].tolist()
        # Serialize to JSON so the vector can be displayed and passed between
        # Gradio components as a plain string.
        return json.dumps(embedding_vector), ""
    except Exception as e:
        error_msg = f"Error generating embedding: {str(e)}"
        logger.error(error_msg)
        return "", error_msg

# No @spaces.GPU here: saving is pure file I/O and does not need GPU quota.
def save_embedding(embedding_json, name):
    """Persist an embedding (the JSON string from the UI) as a .npy file."""
    try:
        embedding = json.loads(embedding_json)
        filepath = os.path.join(NPY_CACHE, f"{name}.npy")
        np.save(filepath, np.array(embedding))
        return f"Embedding saved to: {filepath}"
    except Exception as e:
        error_msg = f"Error saving embedding: {str(e)}"
        logger.error(error_msg)
        return error_msg
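
# Reading a saved embedding back is a one-liner (sketch; the filename is
# illustrative):
#
#     vec = np.load(os.path.join(NPY_CACHE, "my_embedding.npy"))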

def convert_to_json(embedding_json, name):
    """Write the embedding JSON string to a .json file in OUTPUTS_DIR."""
    try:
        filepath = os.path.join(OUTPUTS_DIR, f"{name}.json")
        with open(filepath, "w") as f:
            f.write(embedding_json)  # Already a JSON string; write it verbatim
        return f"Embedding saved as JSON to: {filepath}"
    except Exception as e:
        error_msg = f"Error converting to JSON: {str(e)}"
        logger.error(error_msg)
        return error_msg

@spaces.GPU
def process_files(files, focus):
    """Encode each uploaded file and return (embeddings_json, status_message)."""
    global model, model_initialization_error
    if model is None:
        success, error_message = initialize_model()
        if not success:
            return "", error_message  # Surface the initialization error in the UI

    try:
        all_embeddings = []
        file_statuses = []  # Per-file success/error messages
        for file in files:
            try:
                with open(file.name, "rb") as f:
                    raw = f.read()
                # model.encode expects strings, not bytes: detect the encoding
                # with charset-normalizer and decode.
                best_match = from_bytes(raw).best()
                if best_match is None:
                    raise ValueError("could not detect text encoding")
                text = str(best_match)
                if torch.cuda.is_available():
                    with torch.amp.autocast("cuda"):
                        embedding = model.encode([text])[0].tolist()
                else:
                    embedding = model.encode([text])[0].tolist()
                all_embeddings.append(embedding)
                file_statuses.append(f"File '{file.name}' processed successfully.")
            except Exception as file_e:
                error_msg = f"Error processing file '{file.name}': {str(file_e)}"
                logger.error(error_msg)
                file_statuses.append(error_msg)

        status_message = "\n".join(file_statuses)
        # JSON string for UI display; may be long when many files are processed.
        return json.dumps(all_embeddings), status_message
    except Exception as e:
        error_msg = f"Error in process_files: {str(e)}"
        logger.error(error_msg)
        return "", error_msg


def create_gradio_interface():
    """Build the Gradio Blocks UI and wire the event handlers."""
    with gr.Blocks() as demo:
        gr.Markdown("## Text Embedding Generator")

        initialization_status_box = gr.Textbox(label="Initialization Status", value=model_initialization_error, visible=False) # Hidden box to hold init error

        with gr.Row():
            text_input = gr.Textbox(label="Enter Text")
            focus_input = gr.Textbox(label="Main Focus of Embedding (e.g., company structure, staff positions, etc.)")

        with gr.Row():
            file_input = gr.File(label="Upload Files", file_count="multiple")

        generate_button = gr.Button("Generate Embedding")
        # Holds the JSON-serialized vector; also feeds the save/convert handlers.
        embedding_output = gr.Textbox(label="Embedding Vector (JSON)", lines=5)
        status_box = gr.Textbox(label="Status/Messages")

        with gr.Accordion("Save and Download Options", open=False):
            save_name_input = gr.Textbox(label="Save Embedding As (Name without extension)")
            with gr.Row():
                save_button = gr.Button("Save as .npy")
                convert_button = gr.Button("Save as .json")
            with gr.Row():
                save_status = gr.Textbox(label="Save Status")
                convert_status = gr.Textbox(label="Convert Status")
            download_button = gr.Button("Download JSON")
            download_output = gr.File(label="Download JSON File")

        process_button = gr.Button("Process Files")
        process_output = gr.Textbox(label="Processed Files (Embeddings JSON, limited display)", lines=3)
        process_status = gr.Textbox(label="File Processing Status")

        # On load, surface any model-initialization error recorded at startup
        # into the (hidden) initialization status box.
        demo.load(
            lambda: ("", model_initialization_error),
            outputs=[status_box, initialization_status_box]
        )


        generate_button.click(
            generate_embedding,
            inputs=[text_input, focus_input],
            outputs=[embedding_output, status_box]
        )

        save_button.click(
            save_embedding,
            inputs=[embedding_output, save_name_input],  # JSON string from generate_embedding
            outputs=[save_status]
        )

        convert_button.click(
            convert_to_json,
            inputs=[embedding_output, save_name_input],  # JSON string from generate_embedding
            outputs=[convert_status]
        )

        # Resolve the saved JSON path; None if the name is empty or unsaved.
        def resolve_json_path(name):
            path = os.path.join(OUTPUTS_DIR, f"{name}.json") if name else None
            return path if path and os.path.exists(path) else None

        download_button.click(
            resolve_json_path,
            inputs=[save_name_input],
            outputs=[download_output]
        )

        process_button.click(
            process_files,
            inputs=[file_input, focus_input],
            outputs=[process_output, process_status]
        )

    return demo
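
# A minimal sketch of driving the app programmatically, assuming the
# gradio_client package is installed and the default auto-generated API
# names (derived from the handler function names) are in effect:
#
#     from gradio_client import Client
#     client = Client("http://localhost:7860")
#     embedding_json, status = client.predict(
#         "some text", "company structure", api_name="/generate_embedding"
#     )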

if __name__ == "__main__":
    # Initialize the model eagerly at startup; if this fails, the app still
    # launches and the handlers retry initialization lazily.
    initialization_success, initialization_error_message = initialize_model()
    if not initialization_success:
        print(f"App startup failed due to model initialization error:\n{initialization_error_message}")

    demo = create_gradio_interface()
    # allowed_paths lets Gradio serve the saved files from persistent storage.
    demo.launch(server_name="0.0.0.0", allowed_paths=[PERSISTENT_PATH])