import os
import io
import json
import pathlib
import shutil
from collections import defaultdict
from typing import List, Tuple, Dict, Optional

import gradio as gr
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from pypdf import PdfReader
import fitz  # PyMuPDF
from openai import OpenAI
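# Assumed third-party dependencies (not pinned in this file); a requirements.txt
# for this Space would likely list: gradio, numpy, faiss-cpu,
# sentence-transformers, pypdf, pymupdf, openai.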
# =========================
# LLM Endpoint
# =========================
API_KEY = os.environ.get("API_KEY")
if not API_KEY:
    raise RuntimeError("Missing API_KEY (set it in Hugging Face: Settings → Variables and secrets).")

client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=API_KEY)
# Friendly labels for the model dropdown
MODEL_LABELS = {
    "GPT": "gpt-oss:20b",
    "Deepseek": "deepseek-r1",
    "Gemma": "gemma3:27b",
    "Qwen": "qwen3-235b",
}

# Internal model keys -> OpenRouter model IDs
MODEL_MAPPING = {
    "gpt-oss:20b": "openai/gpt-oss-20b:free",
    "deepseek-r1": "deepseek/deepseek-r1:free",
    "gemma3:27b": "google/gemma-3-27b-it:free",
    "qwen3-235b": "qwen/qwen3-235b-a22b:free",
}

DEFAULT_MODEL_LABEL = "Deepseek"
GEN_TEMPERATURE = 0.2
GEN_TOP_P = 0.95
GEN_MAX_TOKENS = 1024

EMB_MODEL_NAME = "intfloat/multilingual-e5-base"
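
# =========================
# Storage
# =========================
# On Hugging Face Spaces, persistent storage (when enabled) is mounted at
# /data; anything written elsewhere is lost on restart. The write test below
# confirms the mount is actually usable before committing to it.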
def choose_store_dir() -> Tuple[str, bool]:
    data_root = "/data"
    if os.path.isdir(data_root) and os.access(data_root, os.W_OK):
        d = os.path.join(data_root, "rag_store")
        try:
            os.makedirs(d, exist_ok=True)
            testf = os.path.join(d, ".write_test")
            with open(testf, "w", encoding="utf-8") as f:
                f.write("ok")
            os.remove(testf)
            return d, True
        except Exception:
            pass
    d = os.path.join(os.getcwd(), "store")
    os.makedirs(d, exist_ok=True)
    return d, False


STORE_DIR, IS_PERSISTENT = choose_store_dir()
META_PATH = os.path.join(STORE_DIR, "meta.json")
INDEX_PATH = os.path.join(STORE_DIR, "faiss.index")
LEGACY_STORE_DIR = os.path.join(os.getcwd(), "store")
def migrate_legacy_if_any():
    try:
        if IS_PERSISTENT:
            legacy_meta = os.path.join(LEGACY_STORE_DIR, "meta.json")
            legacy_index = os.path.join(LEGACY_STORE_DIR, "faiss.index")
            if (not os.path.exists(META_PATH) or not os.path.exists(INDEX_PATH)) \
                    and os.path.isdir(LEGACY_STORE_DIR) \
                    and os.path.exists(legacy_meta) and os.path.exists(legacy_index):
                shutil.copyfile(legacy_meta, META_PATH)
                shutil.copyfile(legacy_index, INDEX_PATH)
    except Exception:
        pass


migrate_legacy_if_any()
_emb_model = None
_index: Optional[faiss.Index] = None
_meta: Dict[str, Dict] = {}

DEFAULT_TOP_K = 6
DEFAULT_POOL_K = 40
DEFAULT_PER_SOURCE_CAP = 2
DEFAULT_STRATEGY = "mmr"
DEFAULT_MMR_LAMBDA = 0.5
def get_emb_model():
    global _emb_model
    if _emb_model is None:
        _emb_model = SentenceTransformer(EMB_MODEL_NAME)
    return _emb_model


def _ensure_index(dim: int):
    global _index
    if _index is None:
        _index = faiss.IndexFlatIP(dim)


def _persist():
    faiss.write_index(_index, INDEX_PATH)
    with open(META_PATH, "w", encoding="utf-8") as f:
        json.dump(_meta, f, ensure_ascii=False)


def _load_if_any():
    global _index, _meta
    if os.path.exists(INDEX_PATH) and os.path.exists(META_PATH):
        _index = faiss.read_index(INDEX_PATH)
        with open(META_PATH, "r", encoding="utf-8") as f:
            _meta = json.load(f)
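
# =========================
# Chunking
# =========================
# Fixed-size character windows with overlap, so sentences cut at a chunk
# boundary still appear intact in the neighbouring chunk.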
def _chunk_text(text: str, chunk_size: int = 800, overlap: int = 120) -> List[str]:
    text = text.replace("\u0000", "")
    res, i, n = [], 0, len(text)
    while i < n:
        j = min(i + chunk_size, n)
        seg = text[i:j].strip()
        if seg:
            res.append(seg)
        i = max(0, j - overlap)
        if j >= n:
            break
    return res
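
# Illustration (assumed input, not from the app): with the defaults above, a
# 2000-character string yields three chunks covering [0:800], [680:1480] and
# [1360:2000]; each window starts 120 characters before the previous one ended.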
def _read_bytes(file) -> bytes:
    if isinstance(file, dict):
        p = file.get("path") or file.get("name")
        if p and os.path.exists(p):
            with open(p, "rb") as f:
                return f.read()
        if "data" in file and isinstance(file["data"], (bytes, bytearray)):
            return bytes(file["data"])
    if isinstance(file, (str, pathlib.Path)):
        with open(file, "rb") as f:
            return f.read()
    if hasattr(file, "read"):
        try:
            if hasattr(file, "seek"):
                try:
                    file.seek(0)
                except Exception:
                    pass
            return file.read()
        finally:
            try:
                file.close()
            except Exception:
                pass
    raise ValueError("Unsupported file type from gr.File")
def _decode_best_effort(raw: bytes) -> str:
    for enc in ["utf-8", "cp932", "shift_jis", "cp950", "big5", "gb18030", "latin-1"]:
        try:
            return raw.decode(enc)
        except Exception:
            continue
    return raw.decode("utf-8", errors="ignore")
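
# =========================
# PDF extraction
# =========================
# PyMuPDF (fitz) is tried first because its text extraction is generally more
# robust; pypdf serves as a fallback when PyMuPDF fails or returns no text.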
def _read_pdf(file_bytes: bytes) -> str:
    try:
        with fitz.open(stream=file_bytes, filetype="pdf") as doc:
            if doc.is_encrypted:
                try:
                    doc.authenticate("")  # try an empty password
                except Exception:
                    pass
            texts = [(page.get_text("text") or "") for page in doc]
            txt = "\n".join(texts)
            if txt.strip():
                return txt
    except Exception:
        pass
    try:
        reader = PdfReader(io.BytesIO(file_bytes))
        pages = []
        for p in reader.pages:
            try:
                pages.append(p.extract_text() or "")
            except Exception:
                pages.append("")
        return "\n".join(pages)
    except Exception:
        return ""


def _read_any(file) -> str:
    if isinstance(file, dict):
        name = (file.get("orig_name") or file.get("name") or file.get("path") or "upload").lower()
    else:
        name = getattr(file, "name", None) or (str(file) if isinstance(file, (str, pathlib.Path)) else "upload")
        name = name.lower()
    raw = _read_bytes(file)
    if name.endswith(".pdf"):
        return _read_pdf(raw).replace("\u0000", "")
    return _decode_best_effort(raw).replace("\u0000", "")
DOCS_DIR = os.path.join(os.getcwd(), "docs")


def get_docs_files() -> List[str]:
    if not os.path.isdir(DOCS_DIR):
        return []
    files = []
    for fname in os.listdir(DOCS_DIR):
        if fname.lower().endswith((".pdf", ".txt")):
            files.append(os.path.join(DOCS_DIR, fname))
    return files
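
# =========================
# Indexing
# =========================
# E5-family embedding models expect asymmetric prefixes: documents are encoded
# as "passage: ..." and queries as "query: ...". With
# normalize_embeddings=True the vectors are unit length, so the inner product
# computed by IndexFlatIP is exactly cosine similarity.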
def build_corpus_from_docs():
    global _index, _meta
    files = get_docs_files()
    if not files:
        return "No files found in docs folder."
    emb_model = get_emb_model()
    chunks, sources, failed = [], [], []
    _index = None
    _meta = {}
    for f in files:
        fname = os.path.basename(f)
        try:
            text = _read_any(f) or ""
            parts = _chunk_text(text)
            if not parts:
                failed.append(fname)
                continue
            chunks.extend(parts)
            sources.extend([fname] * len(parts))
        except Exception:
            failed.append(fname)
    if not chunks:
        return "No text extracted from docs."
    passages = [f"passage: {c}" for c in chunks]
    vec = emb_model.encode(passages, batch_size=64, convert_to_numpy=True, normalize_embeddings=True)
    _ensure_index(vec.shape[1])
    _index.add(vec)
    for i, (src, c) in enumerate(zip(sources, chunks)):
        _meta[str(i)] = {"source": src, "text": c}
    _persist()
    msg = f"Indexed {len(chunks)} chunks from {len(files)} files."
    if failed:
        msg += f" Failed files: {', '.join(failed)}"
    return msg
def _encode_query_vec(query: str) -> np.ndarray:
    return get_emb_model().encode([f"query: {query}"], convert_to_numpy=True, normalize_embeddings=True)


def retrieve_candidates(qvec: np.ndarray, pool_k: int = 40) -> List[Tuple[str, float]]:
    if _index is None or _index.ntotal == 0:
        return []
    pool_k = min(pool_k, _index.ntotal)
    D, I = _index.search(qvec, pool_k)
    return [(str(idx), float(score)) for idx, score in zip(I[0], D[0]) if idx != -1]
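
# =========================
# Diversity selection
# =========================
# Strategy 1: cap each source file at per_source_cap chunks, then round-robin
# across sources so no single document dominates the context window. If the
# round-robin yields fewer than top_k chunks, the remainder is backfilled from
# the raw candidate ranking.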
def select_diverse_by_source(cands: List[Tuple[str, float]], top_k: int = 6, per_source_cap: int = 2) -> List[Tuple[str, float]]:
    if not cands:
        return []
    by_src: Dict[str, List[Tuple[str, float]]] = defaultdict(list)
    for cid, s in cands:
        m = _meta.get(cid)
        if not m:
            continue
        by_src[m["source"]].append((cid, s))
    for src in by_src:
        by_src[src] = by_src[src][:per_source_cap]
    picked = []
    src_items = list(by_src.items())
    ptrs = {s: 0 for s in by_src}
    while len(picked) < top_k:
        advanced = False
        for src, items in src_items:
            i = ptrs[src]
            if i < len(items):
                picked.append(items[i])
                ptrs[src] = i + 1
                advanced = True
                if len(picked) >= top_k:
                    break
        if not advanced:
            break
    if len(picked) < top_k:
        seen = {cid for cid, _ in picked}
        for cid, s in cands:
            if cid not in seen:
                picked.append((cid, s))
                seen.add(cid)
                if len(picked) >= top_k:
                    break
    return picked[:top_k]
def _encode_chunks_text(cids: List[str]) -> np.ndarray:
    texts = [f"passage: {(_meta.get(cid) or {}).get('text', '')}" for cid in cids]
    return get_emb_model().encode(texts, convert_to_numpy=True, normalize_embeddings=True)
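
# Strategy 2: Maximal Marginal Relevance (MMR). At each step the next chunk d
# maximises
#     mmr(d) = lambda * sim(q, d) - (1 - lambda) * max_{s in S} sim(d, s)
# where S is the set already selected; lambda=1 ranks purely by query
# relevance, lambda=0 purely by novelty relative to what is already picked.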
def select_diverse_mmr(cands: List[Tuple[str, float]], qvec: np.ndarray, top_k: int = 6, mmr_lambda: float = 0.5) -> List[Tuple[str, float]]:
    if not cands:
        return []
    cids = [cid for cid, _ in cands]
    cvecs = _encode_chunks_text(cids)
    sim_to_q = (cvecs @ qvec.T).reshape(-1)
    selected, remaining = [], set(range(len(cids)))
    while len(selected) < min(top_k, len(cids)):
        if not selected:
            i = int(np.argmax(sim_to_q))
            selected.append(i)
            remaining.remove(i)
            continue
        rem = list(remaining)  # freeze one ordering for this iteration
        S = cvecs[selected]
        sim_to_S = cvecs[rem] @ S.T
        max_sim_to_S = sim_to_S.max(axis=1) if sim_to_S.size > 0 else np.zeros((len(rem),), dtype=np.float32)
        sim_q_rem = sim_to_q[rem]
        mmr_scores = mmr_lambda * sim_q_rem - (1.0 - mmr_lambda) * max_sim_to_S
        j = rem[int(np.argmax(mmr_scores))]
        selected.append(j)
        remaining.remove(j)
    return [(cids[i], float(sim_to_q[i])) for i in selected][:top_k]
def retrieve_diverse(query: str,
                     top_k: int = 6,
                     pool_k: int = 40,
                     per_source_cap: int = 2,
                     strategy: str = "mmr",
                     mmr_lambda: float = 0.5) -> List[Tuple[str, float]]:
    qvec = _encode_query_vec(query)
    cands = retrieve_candidates(qvec, pool_k=pool_k)
    if strategy == "mmr":
        return select_diverse_mmr(cands, qvec, top_k=top_k, mmr_lambda=mmr_lambda)
    return select_diverse_by_source(cands, top_k=top_k, per_source_cap=per_source_cap)


def _format_ctx(hits: List[Tuple[str, float]]) -> str:
    if not hits:
        return ""
    lines = []
    for cid, _ in hits:
        m = _meta.get(cid)
        if not m:
            continue
        source_clean = m.get("source", "")
        text_clean = (m.get("text", "") or "").replace("\n", " ")
        lines.append(f"[{cid}] ({source_clean}) " + text_clean)
    return "\n".join(lines[:10])
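
# =========================
# Chat
# =========================
# Per turn: build the index lazily if needed, retrieve diverse context chunks,
# inject them into the system prompt with their IDs so the model can cite them
# as [n], then stream the completion from OpenRouter.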
def chat_fn(message, history, model_label):
    # Map the friendly dropdown label to an OpenRouter model ID
    model_key = MODEL_LABELS.get(model_label, MODEL_LABELS[DEFAULT_MODEL_LABEL])
    model_name = MODEL_MAPPING.get(model_key, MODEL_MAPPING[MODEL_LABELS[DEFAULT_MODEL_LABEL]])
    if _index is None or _index.ntotal == 0:
        status = build_corpus_from_docs()
        if _index is None or _index.ntotal == 0:
            yield f"**Index Status:** {status}\n\nPlease ensure you have a 'docs' folder with PDF/TXT files and try again."
            return
    hits = retrieve_diverse(
        message,
        top_k=DEFAULT_TOP_K,
        pool_k=DEFAULT_POOL_K,
        per_source_cap=DEFAULT_PER_SOURCE_CAP,
        strategy=DEFAULT_STRATEGY,
        mmr_lambda=DEFAULT_MMR_LAMBDA,
    )
    ctx = _format_ctx(hits) if hits else "(Current index is empty or no matching chunks found)"
    sys_blocks = [
        "You are a rigorous enterprise research assistant. Your answers must be based on retrieved content, with evidence and source numbers cited. If retrieval is insufficient, clearly explain what is missing.",
        f"Below are the available reference chunks (with numbers and sources). When answering, cite the numbers, e.g. [3].\n\n{ctx}",
    ]
    messages = [{"role": "system", "content": "\n\n".join(sys_blocks)}]
    for u, a in history:
        messages.append({"role": "user", "content": u})
        messages.append({"role": "assistant", "content": a})
    messages.append({"role": "user", "content": message})
    try:
        response = client.chat.completions.create(
            model=model_name,
            messages=messages,
            temperature=GEN_TEMPERATURE,
            top_p=GEN_TOP_P,
            max_tokens=GEN_MAX_TOKENS,
            stream=True,
        )
        partial_message = ""
        for chunk in response:
            if hasattr(chunk.choices[0], "delta") and chunk.choices[0].delta.content is not None:
                partial_message += chunk.choices[0].delta.content
                yield partial_message
            elif hasattr(chunk.choices[0], "message") and chunk.choices[0].message.content is not None:
                partial_message += chunk.choices[0].message.content
                yield partial_message
    except Exception as e:
        yield f"[Exception] {repr(e)}"
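
# =========================
# Gradio UI
# =========================
# A dropdown for the model, a textbox plus Send button for queries, and a
# Chatbot panel; conversation history is kept in gr.State as
# (user, assistant) tuples.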
with gr.Blocks(theme=gr.themes.Default(primary_hue="slate")) as maris:
    gr.Markdown("")
    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(MODEL_LABELS.keys()),
            value=DEFAULT_MODEL_LABEL,
            label="Select Model:",
            interactive=True,
        )
    with gr.Row():
        query_box = gr.Textbox(
            label="Try: Who is Gopala",
            placeholder="Enter your question here...",
            scale=8,
        )
        send_btn = gr.Button("Send", scale=1)
    with gr.Row():
        chatbot = gr.Chatbot(label="Maris")
    state = gr.State([])
    def chat_wrapper(user_message, history, model_label):
        # Drain the streaming generator and keep only the final message;
        # intermediate chunks are not streamed to the UI by this wrapper.
        history = history or []
        gen = chat_fn(user_message, history, model_label)
        result = ""
        for chunk in gen:
            result = chunk
        history.append((user_message, result))
        return history, history
    send_btn.click(
        chat_wrapper,
        inputs=[query_box, state, model_dropdown],
        outputs=[chatbot, state],
    )
    query_box.submit(
        chat_wrapper,
        inputs=[query_box, state, model_dropdown],
        outputs=[chatbot, state],
    )
# Load a previously persisted index, if any, at startup.
try:
    _load_if_any()
except Exception:
    pass

if __name__ == "__main__":
    maris.launch()