Polarisailabs committed on
Commit 1003bbb · verified · 1 Parent(s): c95acdf

Upload 3 files

Files changed (3)
  1. README.md +13 -0
  2. app.py +453 -0
  3. requirements.txt +7 -0
README.md ADDED
@@ -0,0 +1,13 @@
---
title: MarisTest
emoji: 🏢
colorFrom: red
colorTo: gray
sdk: gradio
sdk_version: 5.44.1
app_file: app.py
pinned: false
license: mit
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,453 @@
import os
import io
import json
import pathlib
import shutil
from typing import List, Tuple, Dict
import gradio as gr
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from pypdf import PdfReader
import fitz  # PyMuPDF
from collections import defaultdict
from openai import OpenAI

# =========================
# LLM Endpoint
# =========================

API_KEY = os.environ.get("OPENROUTER_API_KEY")
if not API_KEY:
    raise RuntimeError("Missing OPENROUTER_API_KEY (set it in Hugging Face: Settings → Variables and secrets).")

client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=API_KEY)

# Model mapping: display name -> actual model name
MODEL_MAPPING = {
    "gpt-oss:120b": "openai/gpt-oss-120b:free",
    "gpt-oss:20b": "openai/gpt-oss-20b:free",
    "deepseek-r1": "deepseek/deepseek-r1:free",
    "gemma3:27b": "google/gemma-3-27b-it:free",
    "qwen3-235b": "qwen/qwen3-235b-a22b:free",
}

DEFAULT_MODEL_LABEL = "deepseek-r1"
GEN_TEMPERATURE = 0.2
GEN_TOP_P = 0.95
GEN_MAX_TOKENS = 1024

# =========================
# Vector Retrieval/Persistence (prefer /data; otherwise ./store)
# =========================

EMB_MODEL_NAME = "intfloat/multilingual-e5-base"  # Supports both Chinese and English

def choose_store_dir() -> Tuple[str, bool]:
    """Return (store_dir, is_persistent). Prefer /data/rag_store if writable."""
    data_root = "/data"
    if os.path.isdir(data_root) and os.access(data_root, os.W_OK):
        d = os.path.join(data_root, "rag_store")
        try:
            os.makedirs(d, exist_ok=True)
            testf = os.path.join(d, ".write_test")
            with open(testf, "w", encoding="utf-8") as f:
                f.write("ok")
            os.remove(testf)
            return d, True
        except Exception:
            pass
    d = os.path.join(os.getcwd(), "store")
    os.makedirs(d, exist_ok=True)
    return d, False
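# Note: on Hugging Face Spaces, persistent storage (when enabled for the Space) is
# mounted at /data, which is why choose_store_dir probes that path; the ./store
# fallback lives on ephemeral disk and is wiped whenever the Space restarts.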

STORE_DIR, IS_PERSISTENT = choose_store_dir()
META_PATH = os.path.join(STORE_DIR, "meta.json")
INDEX_PATH = os.path.join(STORE_DIR, "faiss.index")

# One-time migration from the old ./store to /data/rag_store (if the new location is still empty)
LEGACY_STORE_DIR = os.path.join(os.getcwd(), "store")

def migrate_legacy_if_any():
    try:
        if IS_PERSISTENT:
            legacy_meta = os.path.join(LEGACY_STORE_DIR, "meta.json")
            legacy_index = os.path.join(LEGACY_STORE_DIR, "faiss.index")
            if (not os.path.exists(META_PATH) or not os.path.exists(INDEX_PATH)) \
                    and os.path.isdir(LEGACY_STORE_DIR) \
                    and os.path.exists(legacy_meta) and os.path.exists(legacy_index):
                shutil.copyfile(legacy_meta, META_PATH)
                shutil.copyfile(legacy_index, INDEX_PATH)
    except Exception:
        pass

migrate_legacy_if_any()

_emb_model = None
_index: faiss.Index = None
_meta: Dict[str, Dict] = {}

# ========== Adjustable Parameter Defaults ==========
DEFAULT_TOP_K = 6
DEFAULT_POOL_K = 40
DEFAULT_PER_SOURCE_CAP = 2
DEFAULT_STRATEGY = "mmr"  # "mmr" or "round_robin"
DEFAULT_MMR_LAMBDA = 0.5

# ---------- Basic Tools ----------

def get_emb_model():
    global _emb_model
    if _emb_model is None:
        _emb_model = SentenceTransformer(EMB_MODEL_NAME)
    return _emb_model

def _ensure_index(dim: int):
    global _index
    if _index is None:
        _index = faiss.IndexFlatIP(dim)  # Normalize vectors first → inner product = cosine

def _persist():
    faiss.write_index(_index, INDEX_PATH)
    with open(META_PATH, "w", encoding="utf-8") as f:
        json.dump(_meta, f, ensure_ascii=False)

def _load_if_any():
    global _index, _meta
    if os.path.exists(INDEX_PATH) and os.path.exists(META_PATH):
        _index = faiss.read_index(INDEX_PATH)
        with open(META_PATH, "r", encoding="utf-8") as f:
            _meta = json.load(f)

def _chunk_text(text: str, chunk_size: int = 800, overlap: int = 120) -> List[str]:
    text = text.replace("\u0000", "")
    res, i, n = [], 0, len(text)
    while i < n:
        j = min(i + chunk_size, n)
        seg = text[i:j].strip()
        if seg:
            res.append(seg)
        i = max(0, j - overlap)
        if j >= n:
            break
    return res
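# Illustration (default parameters): with chunk_size=800 and overlap=120, window starts
# advance by 680 characters, i.e. [0, 800), [680, 1480), [1360, 2160), ...; roughly 15%
# of each chunk is shared with its neighbor, so sentences are less likely to be split.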

# ---------- Robust File Reading ----------

def _read_bytes(file) -> bytes:
    if isinstance(file, dict):
        p = file.get("path") or file.get("name")
        if p and os.path.exists(p):
            with open(p, "rb") as f:
                return f.read()
        if "data" in file and isinstance(file["data"], (bytes, bytearray)):
            return bytes(file["data"])
    if isinstance(file, (str, pathlib.Path)):
        with open(file, "rb") as f:
            return f.read()
    if hasattr(file, "read"):
        try:
            if hasattr(file, "seek"):
                try:
                    file.seek(0)
                except Exception:
                    pass
            return file.read()
        finally:
            try:
                file.close()
            except Exception:
                pass
    raise ValueError("Unsupported file type from gr.File")

def _decode_best_effort(raw: bytes) -> str:
    for enc in ["utf-8", "cp932", "shift_jis", "cp950", "big5", "gb18030", "latin-1"]:
        try:
            return raw.decode(enc)
        except Exception:
            continue
    return raw.decode("utf-8", errors="ignore")

def _read_pdf(file_bytes: bytes) -> str:
    # 1) PyMuPDF
    try:
        with fitz.open(stream=file_bytes, filetype="pdf") as doc:
            if doc.is_encrypted:
                try:
                    doc.authenticate("")  # Try an empty password
                except Exception:
                    pass
            texts = [(page.get_text("text") or "") for page in doc]
            txt = "\n".join(texts)
            if txt.strip():
                return txt
    except Exception:
        pass
    # 2) Fallback: pypdf
    try:
        reader = PdfReader(io.BytesIO(file_bytes))
        pages = []
        for p in reader.pages:
            try:
                pages.append(p.extract_text() or "")
            except Exception:
                pages.append("")
        return "\n".join(pages)
    except Exception:
        return ""

def _read_any(file) -> str:
    if isinstance(file, dict):
        name = (file.get("orig_name") or file.get("name") or file.get("path") or "upload").lower()
    else:
        name = getattr(file, "name", None) or (str(file) if isinstance(file, (str, pathlib.Path)) else "upload")
        name = name.lower()
    raw = _read_bytes(file)
    if name.endswith(".pdf"):
        return _read_pdf(raw).replace("\u0000", "")
    return _decode_best_effort(raw).replace("\u0000", "")

# ---------- Build Corpus ----------

def build_corpus(files) -> str:
    if not files:
        return "No files selected."
    emb_model = get_emb_model()
    chunks, sources, failed = [], [], []
    total_chars = 0
    for f in files:
        if isinstance(f, dict):
            fname = f.get("orig_name") or f.get("name") or f.get("path") or "uploaded"
        else:
            fname = getattr(f, "name", None) or (os.path.basename(f) if isinstance(f, (str, pathlib.Path)) else "uploaded")
        try:
            text = _read_any(f) or ""
            total_chars += len(text)
            parts = _chunk_text(text)
            if not parts:
                failed.append(fname)
                continue
            chunks.extend(parts)
            sources.extend([fname] * len(parts))
        except Exception as e:
            failed.append(f"{fname} (err: {e})")

    if not chunks:
        tier = "Persistent (/data)" if IS_PERSISTENT else "Ephemeral (./store)"
        return f"No text extracted (please check file type/encoding; read {total_chars} characters this time).\nCurrent storage path: {STORE_DIR} [{tier}]"

    passages = [f"passage: {c}" for c in chunks]  # e5 models expect "passage: " / "query: " prefixes
    vec = emb_model.encode(passages, batch_size=64, convert_to_numpy=True, normalize_embeddings=True)
    _ensure_index(vec.shape[1])
    _index.add(vec)
    base = len(_meta)
    for i, (src, c) in enumerate(zip(sources, chunks)):
        _meta[str(base + i)] = {"source": src, "text": c}
    _persist()

    msg = f"Indexing complete: added {len(chunks)} chunks; current total chunks in corpus ≈ {_index.ntotal}."
    if failed:
        preview = ", ".join(failed[:5])
        more = "" if len(failed) <= 5 else f" (and {len(failed) - 5} more not shown)"
        msg += f"\nNote: {len(failed)} files failed to extract text or were empty: {preview}{more}"
    tier = "Persistent (/data)" if IS_PERSISTENT else "Ephemeral (./store)"
    msg += f"\nCurrent storage path: {STORE_DIR} [{tier}]"
    return msg

# ---------- Retrieval (Candidate Pool) ----------

def _encode_query_vec(query: str) -> np.ndarray:
    return get_emb_model().encode([f"query: {query}"], convert_to_numpy=True, normalize_embeddings=True)

def retrieve_candidates(qvec: np.ndarray, pool_k: int = DEFAULT_POOL_K) -> List[Tuple[str, float]]:
    if _index is None or _index.ntotal == 0:
        return []
    pool_k = min(pool_k, _index.ntotal)
    D, I = _index.search(qvec, pool_k)
    return [(str(idx), float(score)) for idx, score in zip(I[0], D[0]) if idx != -1]
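# Note: IndexFlatIP.search returns scores (D) and row indices (I); because every vector
# was encoded with normalize_embeddings=True, the inner products are cosine similarities.
# FAISS pads missing results with index -1, hence the defensive idx != -1 filter even
# though pool_k is already clamped to ntotal above.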

# ---------- Diversification Strategies ----------

def select_diverse_by_source(cands: List[Tuple[str, float]], top_k: int = DEFAULT_TOP_K, per_source_cap: int = DEFAULT_PER_SOURCE_CAP) -> List[Tuple[str, float]]:
    if not cands:
        return []
    by_src: Dict[str, List[Tuple[str, float]]] = defaultdict(list)
    for cid, s in cands:
        m = _meta.get(cid)
        if not m:
            continue
        by_src[m["source"]].append((cid, s))
    for src in by_src:
        by_src[src] = by_src[src][:per_source_cap]
    picked, src_items, ptrs = [], [(s, it) for s, it in by_src.items()], {s: 0 for s in by_src}
    # Round-robin across sources until top_k chunks are picked or every source is exhausted
    while len(picked) < top_k:
        advanced = False
        for src, items in src_items:
            i = ptrs[src]
            if i < len(items):
                picked.append(items[i])
                ptrs[src] = i + 1
                advanced = True
            if len(picked) >= top_k:
                break
        if not advanced:
            break
    # Backfill from the raw candidate list if the per-source cap left us short
    if len(picked) < top_k:
        seen = {cid for cid, _ in picked}
        for cid, s in cands:
            if cid not in seen:
                picked.append((cid, s))
                seen.add(cid)
                if len(picked) >= top_k:
                    break
    return picked[:top_k]

def _encode_chunks_text(cids: List[str]) -> np.ndarray:
    texts = [f"passage: {(_meta.get(cid) or {}).get('text', '')}" for cid in cids]
    return get_emb_model().encode(texts, convert_to_numpy=True, normalize_embeddings=True)

def select_diverse_mmr(cands: List[Tuple[str, float]], qvec: np.ndarray, top_k: int = DEFAULT_TOP_K, mmr_lambda: float = DEFAULT_MMR_LAMBDA) -> List[Tuple[str, float]]:
    if not cands:
        return []
    cids = [cid for cid, _ in cands]
    cvecs = _encode_chunks_text(cids)
    sim_to_q = (cvecs @ qvec.T).reshape(-1)
    selected, remaining = [], set(range(len(cids)))
    while len(selected) < min(top_k, len(cids)):
        if not selected:
            i = int(np.argmax(sim_to_q))
            selected.append(i)
            remaining.remove(i)
            continue
        S = cvecs[selected]
        rem = list(remaining)  # Materialize once so the indexings below stay aligned
        sim_to_S = cvecs[rem] @ S.T
        max_sim_to_S = sim_to_S.max(axis=1) if sim_to_S.size > 0 else np.zeros((len(rem),), dtype=np.float32)
        sim_q_rem = sim_to_q[rem]
        mmr_scores = mmr_lambda * sim_q_rem - (1.0 - mmr_lambda) * max_sim_to_S
        j = rem[int(np.argmax(mmr_scores))]
        selected.append(j)
        remaining.remove(j)
    return [(cids[i], float(sim_to_q[i])) for i in selected][:top_k]
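# For reference, this is standard Maximal Marginal Relevance (Carbonell & Goldstein, 1998):
# each step picks the remaining chunk maximizing
#     mmr_lambda * sim(query, chunk) - (1 - mmr_lambda) * max_{s in selected} sim(chunk, s)
# so mmr_lambda = 1.0 degenerates to plain similarity ranking, while lower values
# trade relevance for diversity among the selected chunks.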

def retrieve_diverse(query: str,
                     top_k: int = DEFAULT_TOP_K,
                     pool_k: int = DEFAULT_POOL_K,
                     per_source_cap: int = DEFAULT_PER_SOURCE_CAP,
                     strategy: str = DEFAULT_STRATEGY,
                     mmr_lambda: float = DEFAULT_MMR_LAMBDA) -> List[Tuple[str, float]]:
    qvec = _encode_query_vec(query)
    cands = retrieve_candidates(qvec, pool_k=pool_k)
    if strategy == "mmr":
        return select_diverse_mmr(cands, qvec, top_k=top_k, mmr_lambda=mmr_lambda)
    return select_diverse_by_source(cands, top_k=top_k, per_source_cap=per_source_cap)
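# Illustrative call (hypothetical ids/scores, shown only to document the return shape):
#   retrieve_diverse("2024 office vacancy trends", top_k=4, strategy="mmr")
#   -> [("17", 0.84), ("3", 0.79), ("42", 0.77), ("8", 0.74)]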

# ---------- Format Chunks ----------

def _format_ctx(hits: List[Tuple[str, float]]) -> str:
    if not hits:
        return ""
    lines = []
    for cid, _ in hits:
        m = _meta.get(cid)
        if not m:
            continue
        source_clean = m.get("source", "")
        text_clean = (m.get("text", "") or "").replace("\n", " ")
        lines.append(f"[{cid}] ({source_clean}) " + text_clean)
    return "\n".join(lines[:10])

# =========================
# LLM Conversation (streaming)
# =========================

def chat_fn(message, history, model_label, strategy, top_k, pool_k, per_source_cap, mmr_lambda):
    hits = retrieve_diverse(
        message,
        top_k=int(top_k),
        pool_k=int(pool_k),
        per_source_cap=int(per_source_cap),
        strategy=str(strategy),
        mmr_lambda=float(mmr_lambda),
    )

    ctx = _format_ctx(hits) if hits else "(Current index is empty or no matching chunks found)"

    sys_blocks = [
        "You are a rigorous real estate market research assistant. Your answers must be based on retrieved content with evidence and source numbers cited. If retrieval is insufficient, please clearly explain the shortcomings.",
        f"Below are the available reference chunks (with numbers and sources). When answering, please cite the numbers, e.g., [3].\n\n{ctx}",
    ]

    messages = [{"role": "system", "content": "\n\n".join(sys_blocks)}]

    for u, a in history:
        messages.append({"role": "user", "content": u})
        messages.append({"role": "assistant", "content": a})

    messages.append({"role": "user", "content": message})

    # Convert label to actual model name
    model_name = MODEL_MAPPING.get(model_label, MODEL_MAPPING[DEFAULT_MODEL_LABEL])

    try:
        # Streaming response
        response = client.chat.completions.create(
            model=model_name,
            messages=messages,
            temperature=GEN_TEMPERATURE,
            top_p=GEN_TOP_P,
            max_tokens=GEN_MAX_TOKENS,
            stream=True,  # Enable streaming
        )

        # Yield chunks as they arrive
        partial_message = ""
        for chunk in response:
            if chunk.choices[0].delta.content is not None:
                partial_message += chunk.choices[0].delta.content
                yield partial_message

    except Exception as e:
        yield f"[Exception] {repr(e)}"

# =========================
# Gradio Interface (Slider/Dropdown takes effect immediately)
# =========================

with gr.Blocks(theme=gr.themes.Default(primary_hue="slate")) as maris:
    tier = "Persistent `/data`" if IS_PERSISTENT else "Ephemeral `./store` (may be lost on restart)"
    with gr.Row():
        with gr.Column(scale=1):
            files = gr.File(file_count="multiple", file_types=[".pdf", ".txt"], label="Upload Corpus (multiple files allowed)")
            out = gr.Textbox(label="Index Status", interactive=False)
            gr.Button("Build/Update Index", variant="primary").click(build_corpus, inputs=files, outputs=out)

            with gr.Accordion("Parameters:", open=False):
                strategy = gr.Dropdown(choices=["mmr", "round_robin"], value=DEFAULT_STRATEGY, label="Diversification Strategy")
                top_k = gr.Slider(1, 12, value=DEFAULT_TOP_K, step=1, label="Number of chunks for model (top_k)")
                pool_k = gr.Slider(10, 200, value=DEFAULT_POOL_K, step=5, label="Candidate pool size (pool_k)")
                per_source_cap = gr.Slider(1, 5, value=DEFAULT_PER_SOURCE_CAP, step=1, label="Per-source limit (for round_robin)")
                mmr_lambda = gr.Slider(0.0, 1.0, value=DEFAULT_MMR_LAMBDA, step=0.05, label="MMR λ (higher = closer to query)")

        with gr.Column(scale=2):
            # Add model selection dropdown with custom labels
            model_dropdown = gr.Dropdown(
                choices=list(MODEL_MAPPING.keys()),
                value=DEFAULT_MODEL_LABEL,
                label="Select Model",
                interactive=True,
            )

            gr.ChatInterface(
                fn=chat_fn,  # Streaming function
                additional_inputs=[model_dropdown, strategy, top_k, pool_k, per_source_cap, mmr_lambda],
            )

# Cold start: load existing index
try:
    _load_if_any()
except Exception:
    pass

if __name__ == "__main__":
    maris.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
gradio>=4.0.0
openai>=1.40.0
faiss-cpu>=1.7.4
sentence-transformers>=2.7.0
pypdf>=4.2.0
pymupdf>=1.24.9
numpy>=1.26.0