import os
import io
import json
import pathlib
import shutil
from typing import List, Tuple, Dict, Optional

import gradio as gr
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from pypdf import PdfReader
import fitz  # PyMuPDF
from collections import defaultdict
from openai import OpenAI

# =========================
# LLM Endpoint
# =========================
API_KEY = os.environ.get("API_KEY")
if not API_KEY:
    raise RuntimeError("Missing API_KEY (set it in Hugging Face: Settings → Variables and secrets).")

client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=API_KEY)

# Friendly labels for the model dropdown.
MODEL_LABELS = {
    "GPT": "gpt-oss:20b",
    "Deepseek": "deepseek-r1",
    "Gemma": "gemma3:27b",
    "Qwen": "qwen3-235b",
}
# Internal keys -> OpenRouter model identifiers.
MODEL_MAPPING = {
    "gpt-oss:20b": "openai/gpt-oss-20b:free",
    "deepseek-r1": "deepseek/deepseek-r1:free",
    "gemma3:27b": "google/gemma-3-27b-it:free",
    "qwen3-235b": "qwen/qwen3-235b-a22b:free",
}
DEFAULT_MODEL_LABEL = "Deepseek"

GEN_TEMPERATURE = 0.2
GEN_TOP_P = 0.95
GEN_MAX_TOKENS = 1024

EMB_MODEL_NAME = "intfloat/multilingual-e5-base"


# =========================
# Vector store (FAISS index + JSON metadata)
# =========================
def choose_store_dir() -> Tuple[str, bool]:
    """Prefer the persistent /data volume (e.g. on Hugging Face Spaces); fall back to ./store."""
    data_root = "/data"
    if os.path.isdir(data_root) and os.access(data_root, os.W_OK):
        d = os.path.join(data_root, "rag_store")
        try:
            os.makedirs(d, exist_ok=True)
            testf = os.path.join(d, ".write_test")
            with open(testf, "w", encoding="utf-8") as f:
                f.write("ok")
            os.remove(testf)
            return d, True
        except Exception:
            pass
    d = os.path.join(os.getcwd(), "store")
    os.makedirs(d, exist_ok=True)
    return d, False


STORE_DIR, IS_PERSISTENT = choose_store_dir()
META_PATH = os.path.join(STORE_DIR, "meta.json")
INDEX_PATH = os.path.join(STORE_DIR, "faiss.index")
LEGACY_STORE_DIR = os.path.join(os.getcwd(), "store")


def migrate_legacy_if_any():
    """Copy an old ./store index into the persistent store if the latter is missing files."""
    try:
        if IS_PERSISTENT:
            legacy_meta = os.path.join(LEGACY_STORE_DIR, "meta.json")
            legacy_index = os.path.join(LEGACY_STORE_DIR, "faiss.index")
            if (not os.path.exists(META_PATH) or not os.path.exists(INDEX_PATH)) \
                    and os.path.isdir(LEGACY_STORE_DIR) \
                    and os.path.exists(legacy_meta) and os.path.exists(legacy_index):
                shutil.copyfile(legacy_meta, META_PATH)
                shutil.copyfile(legacy_index, INDEX_PATH)
    except Exception:
        pass


migrate_legacy_if_any()

_emb_model = None
_index: Optional[faiss.Index] = None
_meta: Dict[str, Dict] = {}

DEFAULT_TOP_K = 6
DEFAULT_POOL_K = 40
DEFAULT_PER_SOURCE_CAP = 2
DEFAULT_STRATEGY = "mmr"
DEFAULT_MMR_LAMBDA = 0.5


def get_emb_model():
    global _emb_model
    if _emb_model is None:
        _emb_model = SentenceTransformer(EMB_MODEL_NAME)
    return _emb_model


def _ensure_index(dim: int):
    global _index
    if _index is None:
        # Inner-product index; vectors are L2-normalized, so scores are cosine similarities.
        _index = faiss.IndexFlatIP(dim)


def _persist():
    faiss.write_index(_index, INDEX_PATH)
    with open(META_PATH, "w", encoding="utf-8") as f:
        json.dump(_meta, f, ensure_ascii=False)


def _load_if_any():
    global _index, _meta
    if os.path.exists(INDEX_PATH) and os.path.exists(META_PATH):
        _index = faiss.read_index(INDEX_PATH)
        with open(META_PATH, "r", encoding="utf-8") as f:
            _meta = json.load(f)


def _chunk_text(text: str, chunk_size: int = 800, overlap: int = 120) -> List[str]:
    """Split text into fixed-size windows that overlap, so no passage loses its surrounding context."""
    text = text.replace("\u0000", "")
    res, i, n = [], 0, len(text)
    while i < n:
        j = min(i + chunk_size, n)
        seg = text[i:j].strip()
        if seg:
            res.append(seg)
        i = max(0, j - overlap)
        if j >= n:
            break
    return res
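
# A quick sanity check of the windowing above (illustrative values only):
# with chunk_size=10 and overlap=3, windows start 7 characters apart, so each
# chunk repeats the last 3 characters of the previous one:
#   _chunk_text("abcdefghijklmnop", chunk_size=10, overlap=3)
#   -> ['abcdefghij', 'hijklmnop']
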
def _read_bytes(file) -> bytes:
    """Extract raw bytes from the several shapes gr.File may hand us (dict, path, or file object)."""
    if isinstance(file, dict):
        p = file.get("path") or file.get("name")
        if p and os.path.exists(p):
            with open(p, "rb") as f:
                return f.read()
        if "data" in file and isinstance(file["data"], (bytes, bytearray)):
            return bytes(file["data"])
    if isinstance(file, (str, pathlib.Path)):
        with open(file, "rb") as f:
            return f.read()
    if hasattr(file, "read"):
        try:
            if hasattr(file, "seek"):
                try:
                    file.seek(0)
                except Exception:
                    pass
            return file.read()
        finally:
            try:
                file.close()
            except Exception:
                pass
    raise ValueError("Unsupported file type from gr.File")


def _decode_best_effort(raw: bytes) -> str:
    # Try common Japanese/Chinese encodings in order; latin-1 never raises,
    # so it is the effective last resort before the errors="ignore" fallback.
    for enc in ["utf-8", "cp932", "shift_jis", "cp950", "big5", "gb18030", "latin-1"]:
        try:
            return raw.decode(enc)
        except Exception:
            continue
    return raw.decode("utf-8", errors="ignore")


def _read_pdf(file_bytes: bytes) -> str:
    # Prefer PyMuPDF for extraction quality; fall back to pypdf if it fails or returns nothing.
    try:
        with fitz.open(stream=file_bytes, filetype="pdf") as doc:
            if doc.is_encrypted:
                try:
                    doc.authenticate("")
                except Exception:
                    pass
            texts = [(page.get_text("text") or "") for page in doc]
            txt = "\n".join(texts)
            if txt.strip():
                return txt
    except Exception:
        pass
    try:
        reader = PdfReader(io.BytesIO(file_bytes))
        pages = []
        for p in reader.pages:
            try:
                pages.append(p.extract_text() or "")
            except Exception:
                pages.append("")
        return "\n".join(pages)
    except Exception:
        return ""


def _read_any(file) -> str:
    if isinstance(file, dict):
        name = (file.get("orig_name") or file.get("name") or file.get("path") or "upload").lower()
    else:
        name = getattr(file, "name", None) or (str(file) if isinstance(file, (str, pathlib.Path)) else "upload")
        name = name.lower()
    raw = _read_bytes(file)
    if name.endswith(".pdf"):
        return _read_pdf(raw).replace("\u0000", "")
    return _decode_best_effort(raw).replace("\u0000", "")


DOCS_DIR = os.path.join(os.getcwd(), "docs")


def get_docs_files() -> List[str]:
    if not os.path.isdir(DOCS_DIR):
        return []
    files = []
    for fname in os.listdir(DOCS_DIR):
        if fname.lower().endswith((".pdf", ".txt")):
            files.append(os.path.join(DOCS_DIR, fname))
    return files


def build_corpus_from_docs():
    """Rebuild the FAISS index from every PDF/TXT in ./docs and persist it."""
    global _index, _meta
    files = get_docs_files()
    if not files:
        return "No files found in docs folder."
    emb_model = get_emb_model()
    chunks, sources, failed = [], [], []
    _index = None
    _meta = {}
    for f in files:
        fname = os.path.basename(f)
        try:
            text = _read_any(f) or ""
            parts = _chunk_text(text)
            if not parts:
                failed.append(fname)
                continue
            chunks.extend(parts)
            sources.extend([fname] * len(parts))
        except Exception:
            failed.append(fname)
    if not chunks:
        return "No text extracted from docs."
    # E5 models expect the "passage: " prefix on documents ("query: " on queries).
    passages = [f"passage: {c}" for c in chunks]
    vec = emb_model.encode(passages, batch_size=64, convert_to_numpy=True, normalize_embeddings=True)
    _ensure_index(vec.shape[1])
    _index.add(vec)
    for i, (src, c) in enumerate(zip(sources, chunks)):
        _meta[str(i)] = {"source": src, "text": c}
    _persist()
    msg = f"Indexed {len(chunks)} chunks from {len(files)} files."
    if failed:
        msg += f" Failed files: {', '.join(failed)}"
    return msg
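
# Usage sketch (hypothetical session; assumes a ./docs folder with PDF/TXT files):
#   print(build_corpus_from_docs())   # e.g. "Indexed 120 chunks from 4 files."
# The index and metadata persist to STORE_DIR and are reloaded on restart.
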
# =========================
# Retrieval
# =========================
def _encode_query_vec(query: str) -> np.ndarray:
    return get_emb_model().encode([f"query: {query}"], convert_to_numpy=True, normalize_embeddings=True)


def retrieve_candidates(qvec: np.ndarray, pool_k: int = 40) -> List[Tuple[str, float]]:
    if _index is None or _index.ntotal == 0:
        return []
    pool_k = min(pool_k, _index.ntotal)
    D, I = _index.search(qvec, pool_k)
    return [(str(idx), float(score)) for idx, score in zip(I[0], D[0]) if idx != -1]


def select_diverse_by_source(cands: List[Tuple[str, float]], top_k: int = 6,
                             per_source_cap: int = 2) -> List[Tuple[str, float]]:
    """Round-robin across source files, taking at most per_source_cap chunks from each."""
    if not cands:
        return []
    by_src: Dict[str, List[Tuple[str, float]]] = defaultdict(list)
    for cid, s in cands:
        m = _meta.get(cid)
        if not m:
            continue
        by_src[m["source"]].append((cid, s))
    for src in by_src:
        by_src[src] = by_src[src][:per_source_cap]
    picked, src_items, ptrs = [], [(s, it) for s, it in by_src.items()], {s: 0 for s in by_src}
    while len(picked) < top_k:
        advanced = False
        for src, items in src_items:
            i = ptrs[src]
            if i < len(items):
                picked.append(items[i])
                ptrs[src] = i + 1
                advanced = True
                if len(picked) >= top_k:
                    break
        if not advanced:
            break
    # Backfill from the raw candidate list if the capped round-robin came up short.
    if len(picked) < top_k:
        seen = {cid for cid, _ in picked}
        for cid, s in cands:
            if cid not in seen:
                picked.append((cid, s))
                seen.add(cid)
                if len(picked) >= top_k:
                    break
    return picked[:top_k]


def _encode_chunks_text(cids: List[str]) -> np.ndarray:
    texts = [f"passage: {(_meta.get(cid) or {}).get('text', '')}" for cid in cids]
    return get_emb_model().encode(texts, convert_to_numpy=True, normalize_embeddings=True)


def select_diverse_mmr(cands: List[Tuple[str, float]], qvec: np.ndarray, top_k: int = 6,
                       mmr_lambda: float = 0.5) -> List[Tuple[str, float]]:
    """Maximal Marginal Relevance: score(d) = λ·sim(q, d) - (1 - λ)·max_{s in S} sim(d, s),
    trading query relevance against redundancy with the already-selected set S."""
    if not cands:
        return []
    cids = [cid for cid, _ in cands]
    cvecs = _encode_chunks_text(cids)
    sim_to_q = (cvecs @ qvec.T).reshape(-1)
    selected, remaining = [], set(range(len(cids)))
    while len(selected) < min(top_k, len(cids)):
        if not selected:
            # Seed with the chunk most similar to the query.
            i = int(np.argmax(sim_to_q))
            selected.append(i)
            remaining.remove(i)
            continue
        rem = list(remaining)
        S = cvecs[selected]
        sim_to_S = cvecs[rem] @ S.T
        max_sim_to_S = sim_to_S.max(axis=1) if sim_to_S.size > 0 else np.zeros((len(rem),), dtype=np.float32)
        sim_q_rem = sim_to_q[rem]
        mmr_scores = mmr_lambda * sim_q_rem - (1.0 - mmr_lambda) * max_sim_to_S
        j = rem[int(np.argmax(mmr_scores))]
        selected.append(j)
        remaining.remove(j)
    return [(cids[i], float(sim_to_q[i])) for i in selected][:top_k]


def retrieve_diverse(query: str, top_k: int = 6, pool_k: int = 40, per_source_cap: int = 2,
                     strategy: str = "mmr", mmr_lambda: float = 0.5) -> List[Tuple[str, float]]:
    qvec = _encode_query_vec(query)
    cands = retrieve_candidates(qvec, pool_k=pool_k)
    if strategy == "mmr":
        return select_diverse_mmr(cands, qvec, top_k=top_k, mmr_lambda=mmr_lambda)
    return select_diverse_by_source(cands, top_k=top_k, per_source_cap=per_source_cap)


def _format_ctx(hits: List[Tuple[str, float]]) -> str:
    if not hits:
        return ""
    lines = []
    for cid, _ in hits:
        m = _meta.get(cid)
        if not m:
            continue
        source_clean = m.get("source", "")
        text_clean = (m.get("text", "") or "").replace("\n", " ")
        lines.append(f"[{cid}] ({source_clean}) " + text_clean)
    return "\n".join(lines[:10])
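
# Retrieval sketch (hypothetical query; assumes the index has been built):
#   hits = retrieve_diverse("Who is Gopala?", top_k=6, strategy="mmr")
#   print(_format_ctx(hits))   # lines like "[12] (somefile.pdf) ..." for the prompt
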
def chat_fn(message, history, model_label):
    # Map the friendly dropdown label to the OpenRouter model identifier.
    model_key = MODEL_LABELS.get(model_label, MODEL_LABELS[DEFAULT_MODEL_LABEL])
    model_name = MODEL_MAPPING.get(model_key, MODEL_MAPPING[MODEL_LABELS[DEFAULT_MODEL_LABEL]])
    # Lazily build the index from ./docs on first use.
    if _index is None or _index.ntotal == 0:
        status = build_corpus_from_docs()
        if not (_index and _index.ntotal > 0):
            yield f"**Index Status:** {status}\n\nPlease ensure you have a 'docs' folder with PDF/TXT files and try again."
            return
    hits = retrieve_diverse(
        message,
        top_k=DEFAULT_TOP_K,
        pool_k=DEFAULT_POOL_K,
        per_source_cap=DEFAULT_PER_SOURCE_CAP,
        strategy=DEFAULT_STRATEGY,
        mmr_lambda=DEFAULT_MMR_LAMBDA,
    )
    ctx = _format_ctx(hits) if hits else "(Current index is empty or no matching chunks found)"
    sys_blocks = [
        "You are a rigorous enterprise research assistant. Your answers must be based on the retrieved content, with evidence and source numbers cited. If retrieval is insufficient, clearly state the gaps.",
        f"Below are the available reference chunks (with numbers and sources). When answering, please cite the numbers, e.g., [3].\n\n{ctx}",
    ]
    messages = [{"role": "system", "content": "\n\n".join(sys_blocks)}]
    for u, a in history:
        messages.append({"role": "user", "content": u})
        messages.append({"role": "assistant", "content": a})
    messages.append({"role": "user", "content": message})
    try:
        response = client.chat.completions.create(
            model=model_name,
            messages=messages,
            temperature=GEN_TEMPERATURE,
            top_p=GEN_TOP_P,
            max_tokens=GEN_MAX_TOKENS,
            stream=True,
        )
        partial_message = ""
        for chunk in response:
            if not chunk.choices:  # some providers emit keep-alive chunks with no choices
                continue
            if hasattr(chunk.choices[0], "delta") and chunk.choices[0].delta.content is not None:
                partial_message += chunk.choices[0].delta.content
                yield partial_message
            elif hasattr(chunk.choices[0], "message") and chunk.choices[0].message.content is not None:
                partial_message += chunk.choices[0].message.content
                yield partial_message
    except Exception as e:
        yield f"[Exception] {repr(e)}"


# =========================
# UI
# =========================
with gr.Blocks(theme=gr.themes.Default(primary_hue="slate")) as maris:
    gr.Markdown("")
    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(MODEL_LABELS.keys()),
            value=DEFAULT_MODEL_LABEL,
            label="Select Model:",
            interactive=True,
        )
    with gr.Row():
        query_box = gr.Textbox(
            label="Try: Who is Gopala",
            placeholder="Enter your question here...",
            scale=8,
        )
        send_btn = gr.Button("Send", scale=1)
    with gr.Row():
        chatbot = gr.Chatbot(label="Maris")
    state = gr.State([])

    def chat_wrapper(user_message, history, model_label):
        # Drain the streaming generator, keep only the final answer, then
        # append the (question, answer) pair once.
        history = history or []
        result = ""
        for chunk in chat_fn(user_message, history, model_label):
            result = chunk
        history.append((user_message, result))
        return history, history

    send_btn.click(
        chat_wrapper,
        inputs=[query_box, state, model_dropdown],
        outputs=[chatbot, state],
    )
    query_box.submit(
        chat_wrapper,
        inputs=[query_box, state, model_dropdown],
        outputs=[chatbot, state],
    )

try:
    _load_if_any()
except Exception:
    pass

if __name__ == "__main__":
    maris.launch()
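
# Local run sketch (assumes this file is saved as app.py, a hypothetical
# filename, and that an OpenRouter key is exported):
#   API_KEY=sk-or-... python app.py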