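"""Gradio app that generates sentence-transformer embeddings from uploaded .txt files
and runs semantic search over them. Intended for deployment on Hugging Face Spaces
with persistent storage and GPU acceleration via the `spaces` package."""
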
import os
import gradio as gr
import logging
import traceback
import spaces
from datetime import datetime
import torch
from torch.amp import autocast
from sentence_transformers import SentenceTransformer
import numpy as np
import requests
from charset_normalizer import from_bytes
import zipfile
import webbrowser

# Custom Exception Class
class GPUQuotaExceededError(Exception):
    pass

# Constants
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
CHUNK_SIZE = 500
BATCH_SIZE = 32

# Set Persistent Storage Path
PERSISTENT_PATH = os.getenv("PERSISTENT_PATH", "/data")
os.makedirs(PERSISTENT_PATH, exist_ok=True, mode=0o777)

# Define Subdirectories
TEMP_DIR = os.path.join(PERSISTENT_PATH, "temp")
os.makedirs(TEMP_DIR, exist_ok=True, mode=0o777)

OUTPUTS_DIR = os.path.join(PERSISTENT_PATH, "outputs")
os.makedirs(OUTPUTS_DIR, exist_ok=True, mode=0o777)

NPY_CACHE = os.path.join(PERSISTENT_PATH, "npy_cache")
os.makedirs(NPY_CACHE, exist_ok=True, mode=0o777)

LOG_DIR = os.getenv("LOG_DIR", os.path.join(PERSISTENT_PATH, "logs"))
os.makedirs(LOG_DIR, exist_ok=True, mode=0o777)

# Set Hugging Face cache directory to persistent storage
os.environ["HF_HOME"] = os.path.join(PERSISTENT_PATH, ".huggingface")
os.makedirs(os.environ["HF_HOME"], exist_ok=True, mode=0o777)

# Set Hugging Face token
HF_TOKEN = os.getenv("HF_TOKEN")

# Logging Setup
logging.basicConfig(
    filename=os.path.join(LOG_DIR, "app.log"),
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Model initialization
model = None

def initialize_model():
    """
    Initialize the sentence transformer model.

    Returns:
        bool: Whether the model was successfully initialized.
    """
    global model
    try:
        if model is None:
            model_cache = os.path.join(PERSISTENT_PATH, "models")
            os.makedirs(model_cache, exist_ok=True, mode=0o777)
            # Use the HF_TOKEN to load the model
            model = SentenceTransformer(EMBEDDING_MODEL_NAME, cache_folder=model_cache, use_auth_token=HF_TOKEN)
            logger.info(f"Initialized model: {EMBEDDING_MODEL_NAME}")
        return True
    except requests.exceptions.RequestException as e:
        logger.error(f"Connection error during model download: {str(e)}\n{traceback.format_exc()}")
        return False
    except Exception as e:
        logger.error(f"Model initialization failed: {str(e)}\n{traceback.format_exc()}")
        return False

@spaces.GPU
def handle_gpu_operation(func):
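    """Run `func` under CUDA float16 autocast, log the duration, and translate
    CUDA out-of-memory and quota errors into GPUQuotaExceededError."""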
    try:
        start_time = datetime.now()
        # Run the wrapped operation under CUDA float16 autocast (mixed precision)
        with autocast(device_type='cuda', dtype=torch.float16):
            result = func()
        end_time = datetime.now()
        duration = (end_time - start_time).total_seconds()
        logger.info(f"GPU operation completed in {duration:.2f}s")
        return result
    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            torch.cuda.empty_cache()
            logger.error(f"GPU memory error: {str(e)}")
            raise GPUQuotaExceededError("GPU memory limit exceeded. Please try with a smaller batch.")
        else:
            logger.error(f"GPU runtime error: {str(e)}")
            raise
    except Exception as e:
        if "quota exceeded" in str(e).lower():
            logger.error(f"GPU quota exceeded: {str(e)}")
            raise GPUQuotaExceededError("GPU quota exceeded. Please wait a few minutes before trying again.")
        else:
            logger.error(f"Unexpected GPU error: {str(e)}")
            raise

def get_model():
    global model
    if model is None:
        if torch.cuda.is_available():
            initialize_model()
        else:
            logger.warning("Attempted to initialize model outside GPU context, deferring.")
            return None
    return model

@spaces.GPU
def process_files(files):
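    """Read uploaded .txt files, split them into CHUNK_SIZE-character chunks,
    embed the chunks in batches, and save embeddings and chunks to OUTPUTS_DIR."""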
    if not files:
        return "Please upload one or more.txt files.", "", ""

    try:
        if not initialize_model():
            return "Failed to initialize the model. Please try again.", "", ""

        valid_files = [f for f in files if f.name.lower().endswith('.txt')]
        if not valid_files:
            return "No.txt files found. Please upload valid.txt files.", "", ""

        all_chunks = []
        processed_files = 0

        for file in valid_files:
            try:
                with open(file.name, 'rb') as f:
                    content = f.read()
                    best_match = from_bytes(content).best()
                    detected_encoding = best_match.encoding if best_match else "utf-8"
                    decoded_content = content.decode(detected_encoding, errors='ignore')

                # Split content into chunks
                chunks = [decoded_content[i:i+CHUNK_SIZE] for i in range(0, len(decoded_content), CHUNK_SIZE)]
                all_chunks.extend(chunks)
                processed_files += 1
                logger.info(f"Processed file: {file.name}")
            except Exception as e:
                logger.error(f"Error processing file {file.name}: {str(e)}")

        if not all_chunks:
            return "No valid content found in the uploaded files.", "", ""

        # Generate embeddings in batches
        all_embeddings = []
        for i in range(0, len(all_chunks), BATCH_SIZE):
            batch = all_chunks[i:i+BATCH_SIZE]
            if model:
                embeddings = handle_gpu_operation(lambda: model.encode(batch))
                all_embeddings.extend(embeddings)
            else:
                return "Model not initialized. Please check model initialization.", "", ""

        # Save results to OUTPUTS_DIR
        embeddings_path = os.path.join(OUTPUTS_DIR, "embeddings.npy")
        np.save(embeddings_path, np.array(all_embeddings))
        chunks_path = os.path.join(OUTPUTS_DIR, "chunks.txt")
        with open(chunks_path, "w", encoding="utf-8") as f:
            for chunk in all_chunks:
                f.write(chunk + "\n===CHUNK_SEPARATOR===\n")

        return (
            f"Successfully processed {processed_files} files. Generated {len(all_embeddings)} embeddings from {len(all_chunks)} chunks.",
            "",
            ""
        )

    except Exception as e:
        logger.error(f"Processing failed: {str(e)}")
        return f"Error processing files: {str(e)}", "", ""

@spaces.GPU
def semantic_search(query, top_k=5):
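    """Embed the query and return the top_k most similar stored chunks by cosine similarity."""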
    global model
    if model is None:
        return "Model not initialized. Please process files first."

    try:
        # Load saved embeddings and chunks from OUTPUTS_DIR
        embeddings_file = os.path.join(OUTPUTS_DIR, "embeddings.npy")
        chunks_file = os.path.join(OUTPUTS_DIR, "chunks.txt")
        stored_embeddings = np.load(embeddings_file)
        with open(chunks_file, "r", encoding="utf-8") as f:
            chunks = f.read().split("\n===CHUNK_SEPARATOR===\n")
            chunks = [c for c in chunks if c.strip()]

        # Get query embedding
        query_embedding = model.encode([query])[0]

        # Cosine similarity between the query embedding and each stored embedding
        similarities = np.dot(stored_embeddings, query_embedding) / (
            np.linalg.norm(stored_embeddings, axis=1) * np.linalg.norm(query_embedding)
        )

        # Get top results
        top_indices = np.argsort(similarities)[-top_k:][::-1]
        results = []
        for idx in top_indices:
            results.append(f"""
Similarity: {similarities[idx]:.3f}
Content: {chunks[idx]}
-------------------
""")
        return "\n".join(results)
    except Exception as e:
        logger.error(f"Search error: {str(e)}")
        return f"Search error occurred: {str(e)}"

def search_and_format(query, num_results):
    if not query.strip():
        return "Please enter a search query"
    return semantic_search(query, top_k=num_results)

def browse_outputs():
    try:
        # Attempt to open the outputs directory in the default browser (only effective when the app runs locally)
        webbrowser.open(f"file://{OUTPUTS_DIR}")
        return "Opened outputs directory."
    except Exception as e:
        logger.error(f"Error opening file browser: {str(e)}")
        return "Error opening file browser."

def download_results():
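    """Bundle embeddings.npy and chunks.txt from OUTPUTS_DIR into results.zip and return its path, or None if files are missing."""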
    required_files = ["embeddings.npy", "chunks.txt"]
    missing = [f for f in required_files if not os.path.exists(os.path.join(OUTPUTS_DIR, f))]
    if missing:
        logger.error(f"Missing files: {missing}")
        return None
    try:
        zip_path = os.path.join(OUTPUTS_DIR, "results.zip")
        with zipfile.ZipFile(zip_path, 'w') as zipf:
            for file in required_files:
                file_path = os.path.join(OUTPUTS_DIR, file)
                zipf.write(file_path, file)
        return zip_path
    except Exception as e:
        logger.error(f"Error creating download archive: {str(e)}")
        return None

def create_gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("## Text Chunk Embeddings Generator")

        error_box = gr.Textbox(visible=False, label="Status/Error Messages")

        with gr.Row():
            file_input = gr.File(
                label="Upload Text Files",
                file_count="multiple",
                file_types=[".txt"]
            )

        process_button = gr.Button("Generate Embeddings")
        output_text = gr.Textbox(label="Status")

        process_button.click(
            fn=process_files,
            inputs=[file_input],
            outputs=[output_text, error_box, error_box]
        )

        with gr.Tab("Search"):
            query_input = gr.Textbox(
                label="Enter your search query",
                placeholder="Enter text to search through your documents..."
            )
            top_k_slider = gr.Slider(
                minimum=1,
                maximum=20,
                value=5,
                step=1,
                label="Number of results to return"
            )
            search_button = gr.Button("Search")
            results_output = gr.Textbox(
                label="Search Results",
                lines=10,
                show_copy_button=True
            )
            search_button.click(
                fn=search_and_format,
                inputs=[query_input, top_k_slider],
                outputs=results_output
            )

            download_button = gr.Button("Download Results")
            download_button.click(
                fn=download_results,
                outputs=[gr.File(label="Download Results")]
            )

        with gr.Tab("Outputs"):
            browse_button = gr.Button("Browse Outputs")
            browse_button.click(
                fn=browse_outputs,
                outputs=[gr.Textbox(label="Browse Status")]
            )

    return demo

if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.launch(server_name="0.0.0.0")