import subprocess
import sys

# Pin llama-cpp-python at startup; invoking pip via the running interpreter keeps
# the install in the same environment as this script
subprocess.run([sys.executable, "-m", "pip", "install", "llama-cpp-python==0.3.15"], check=True)

import gradio as gr
import hopsworks
from sentence_transformers import SentenceTransformer
from llama_cpp import Llama
import faiss
import numpy as np
import os
import json
import yaml
from dotenv import load_dotenv

# Load environment variables & validate required settings
load_dotenv()

HOPSWORKS_API_KEY = os.getenv("HOPSWORKS_API_KEY")
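# load_dotenv() above lets a local .env file supply this, e.g.:
#   HOPSWORKS_API_KEY=<your Hopsworks API key>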

if not HOPSWORKS_API_KEY:
    raise ValueError("HOPSWORKS_API_KEY not found in environment variables.")

# Load models configuration
with open("models_config.json", "r") as f:
    models_config = json.load(f)
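
# The keys accessed throughout this file imply models_config.json is shaped
# roughly like the sketch below (illustrative only; names and values here are
# assumptions, not the actual file contents):
# {
#   "repositories": [
#     {
#       "name": "Example Repo",
#       "repo_id": "someorg/some-model-GGUF",
#       "models": [
#         {
#           "name": "Example Model (Q4)",
#           "filename": "model.Q4_K_M.gguf",
#           "description": "Short blurb shown in the UI."
#         }
#       ]
#     }
#   ]
# }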

# Load RAG prompt configuration
with open("prompts/rag_prompt.yml", "r") as f:
    prompt_config = yaml.safe_load(f)
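
# The keys read from prompt_config below imply prompts/rag_prompt.yml looks
# roughly like this (illustrative sketch; the concrete values are assumptions):
#
# template: |
#   Answer the question using only the context below.
#
#   Context:
#   {context}
#
#   Question: {question}
#   Answer:
# rag:
#   num_retrieved_chunks: 3
#   context_separator: "\n\n---\n\n"
# generation:
#   max_tokens: 512
#   temperature: 0.7
#   stop_sequences: ["Question:"]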

# Currently loaded LLM; stays None until the user loads one via load_model()
llm = None

print("Initializing embeddings and connecting to Hopsworks...")

try:
    embeddings = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

    project = hopsworks.login(api_key_value=HOPSWORKS_API_KEY)
    fs = project.get_feature_store()
    book_fg = fs.get_feature_group("book_embeddings", version=1)

    df = book_fg.read()
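    # The feature group is expected to expose a 'text' column (chunk text) and
    # an 'embedding' column (one vector per chunk); both are read below.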
    
    if df.empty:
        raise ValueError("Feature group 'book_embeddings' is empty.")

    texts = df['text'].tolist()
    raw_embeddings = [emb if isinstance(emb, list) else emb.tolist() for emb in df['embedding']]
    embedding_vectors = np.array(raw_embeddings, dtype='float32')

    dimension = embedding_vectors.shape[1]
    index = faiss.IndexFlatIP(dimension)
    
    # L2-normalize so the inner-product index (IndexFlatIP) ranks by cosine similarity
    faiss.normalize_L2(embedding_vectors)
    index.add(embedding_vectors)

    print("Embeddings and FAISS index initialized.")

except Exception as e:
    print(f"Critical Error during initialization: {e}")
    index = None  # downstream code treats a None index as "retrieval unavailable"

# Function to load a model dynamically
def load_model(repo_name, model_name, progress=gr.Progress()):
    global llm
    try:
        progress(0, desc="Initializing...")

        # Find the repository
        repo = next((r for r in models_config["repositories"] if r["name"] == repo_name), None)
        if not repo:
            return f"Error: Repository '{repo_name}' not found in config."

        # Find the model within the repository
        model = next((m for m in repo["models"] if m["name"] == model_name), None)
        if not model:
            return f"Error: Model '{model_name}' not found in repository."

        print(f"Loading model: {model['name']}...")
        print(f"Repo: {repo['repo_id']}, File: {model['filename']}")

        progress(0.3, desc=f"Downloading/Loading {model['name']}...")

        llm = Llama.from_pretrained(
            repo_id=repo["repo_id"],
            filename=model["filename"],
            n_ctx=2048,        # context window, in tokens
            n_threads=4,       # CPU threads used for inference
            n_gpu_layers=-1,   # offload all layers to GPU when one is available
            verbose=False
        )

        progress(1.0, desc="Complete!")
        return f"✅ Model '{model_name}' loaded successfully!"

    except Exception as e:
        llm = None
        return f"❌ Error loading model: {str(e)}"

def retrieve_context(query, k=None):
    if index is None:
        return "Error: Search index not initialized."

    # Use k from prompt config if not specified
    if k is None:
        k = prompt_config["rag"]["num_retrieved_chunks"]

    query_embedding = embeddings.encode(query).astype('float32').reshape(1, -1)
    faiss.normalize_L2(query_embedding)

    distances, indices = index.search(query_embedding, k)
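    # Because the index is IndexFlatIP over L2-normalized vectors, the
    # "distances" returned above are cosine similarities (higher = closer).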

    retrieved_texts = []
    for i in indices[0]:
        # FAISS pads the result with -1 when the index holds fewer than k vectors
        if 0 <= i < len(texts):
            retrieved_texts.append(texts[i])

    # Use separator from prompt config
    separator = prompt_config["rag"]["context_separator"]

    print(f"Retrieved {len(retrieved_texts)} context chunks for the query.")
    print("Similarities:", distances)
    return separator.join(retrieved_texts)

def respond(message, history):
    """
    Generator function for streaming response.
    gr.ChatInterface passes 'message' and 'history' automatically.
    """
    if llm is None:
        yield "⚠️ No model is loaded yet. Select a repository and model above, then click 'Load Model'."
        return
    if index is None:
        yield "System Error: the search index failed to initialize. Check the console logs."
        return

    # Retrieve context using config settings
    context = retrieve_context(message)

    # Build prompt from template
    prompt = prompt_config["template"].format(
        context=context,
        question=message
    )

    # Get generation parameters from config
    gen_params = prompt_config["generation"]

    output = llm(
        prompt,
        max_tokens=gen_params["max_tokens"],
        temperature=gen_params["temperature"],
        stop=gen_params["stop_sequences"],
        stream=True
    )

    partial_message = ""
    for chunk in output:
        text_chunk = chunk["choices"][0]["text"]
        partial_message += text_chunk
        yield partial_message

with gr.Blocks(title="Hopsworks RAG ChatBot") as demo:
    gr.Markdown("<h1 style='text-align: center; color: #1EB382'>Hopsworks ChatBot</h1>")

    # Model Selection Section
    with gr.Row():
        repo_dropdown = gr.Dropdown(
            choices=[r["name"] for r in models_config["repositories"]],
            label="Select Repository",
            value=models_config["repositories"][0]["name"],
            scale=2
        )
        model_dropdown = gr.Dropdown(
            choices=[m["name"] for m in models_config["repositories"][0]["models"]],
            label="Select Model",
            value=models_config["repositories"][0]["models"][0]["name"],
            scale=2
        )
        load_button = gr.Button("Load Model", variant="primary", scale=1)

    status_box = gr.Textbox(
        label="Status",
        value="⚠️ Please select a repository and model, then click 'Load Model'",
        interactive=False
    )

    # Model info display
    model_info = gr.Markdown("")

    # Chat Interface
    chat_interface = gr.ChatInterface(
        fn=respond,
        chatbot=gr.Chatbot(height=400),
        textbox=gr.Textbox(placeholder="Ask a question about your documents...", container=False, scale=7),
        examples=["What is the main topic of the documents?", "Summarize the key points."],
        cache_examples=False,
    )

    # Function to update model dropdown when repository changes
    def update_model_choices(repo_name):
        repo = next((r for r in models_config["repositories"] if r["name"] == repo_name), None)
        if repo and repo["models"]:
            model_choices = [m["name"] for m in repo["models"]]
            return gr.Dropdown(choices=model_choices, value=model_choices[0])
        return gr.Dropdown(choices=[], value=None)

    # Function to update model info display
    def update_model_info(repo_name, model_name):
        repo = next((r for r in models_config["repositories"] if r["name"] == repo_name), None)
        if not repo:
            return ""

        model = next((m for m in repo["models"] if m["name"] == model_name), None)
        if model:
            return f"**{model['name']}**\n\n{model['description']}\n\n Repository: `{repo['repo_id']}`\n\n File: `{model['filename']}`"
        return ""

    # Event handlers. Chain the info refresh after the dropdown update so it
    # reads the newly selected model rather than the stale value.
    repo_dropdown.change(
        update_model_choices, inputs=[repo_dropdown], outputs=[model_dropdown]
    ).then(
        update_model_info, inputs=[repo_dropdown, model_dropdown], outputs=[model_info]
    )
    model_dropdown.change(update_model_info, inputs=[repo_dropdown, model_dropdown], outputs=[model_info])
    load_button.click(load_model, inputs=[repo_dropdown, model_dropdown], outputs=[status_box])

    # Load default model info on startup
    demo.load(
        lambda: update_model_info(
            models_config["repositories"][0]["name"],
            models_config["repositories"][0]["models"][0]["name"]
        ),
        outputs=[model_info]
    )

if __name__ == "__main__":
    # share=True additionally exposes a temporary public Gradio URL
    demo.launch(share=True)