# ClipQuery / qa_engine.py
import json
import logging
import os
import time

from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQAWithSourcesChain
from langchain_community.chat_models import ChatOllama
from langchain_community.llms import HuggingFaceHub
from langchain.callbacks.base import BaseCallbackHandler
from langchain_core.language_models.base import BaseLanguageModel
from langchain_core.documents import Document
from langchain.globals import set_debug
# Enable verbose LangChain logging and write raw JSON lines to disk for analysis.
set_debug(True)
_lc_logger = logging.getLogger("langchain")
if not any(
    isinstance(h, logging.FileHandler)
    and getattr(h, "baseFilename", "").endswith("langchain_debug.jsonl")
    for h in _lc_logger.handlers
):
    _fh = logging.FileHandler("langchain_debug.jsonl", mode="a", encoding="utf-8")
    _fh.setFormatter(logging.Formatter("%(message)s"))
    _lc_logger.addHandler(_fh)
_lc_logger.setLevel(logging.DEBUG)

EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"


def load_index(index_dir: str = "data"):
    """Load the FAISS index and the segment metadata stored alongside it."""
    embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
    store = FAISS.load_local(index_dir, embeddings, allow_dangerous_deserialization=True)
    with open(os.path.join(index_dir, "segments.json"), encoding="utf-8") as f:
        segments = json.load(f)
    return store, segments
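
# Usage sketch for load_index (hedged: assumes an ingestion step elsewhere in
# the project has already written the FAISS index and segments.json to ./data):
#
#   store, segments = load_index("data")
#   hits = store.similarity_search("what was said about pricing?", k=2)
#   for doc in hits:
#       print(doc.page_content[:80])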


class JSONLCallbackHandler(BaseCallbackHandler):
    """Write simple LangChain events to a JSONL file so the UI can display them."""

    def __init__(self, path: str = "langchain_debug.jsonl"):
        self.path = path
        # Truncate any logs left over from a previous run.
        open(self.path, "w").close()

    def _write(self, record):
        record["ts"] = time.time()
        with open(self.path, "a", encoding="utf-8") as f:
            # default=str keeps the write from failing on payloads that are not
            # JSON-serializable, e.g. Document objects inside chain outputs.
            f.write(json.dumps(record, default=str) + "\n")

    def on_chain_start(self, serialized, inputs, **kwargs):
        # `serialized` can be None for some chain types, so guard the lookup.
        self._write({"event": "chain_start", "name": (serialized or {}).get("name"), "inputs": inputs})

    def on_chain_end(self, outputs, **kwargs):
        self._write({"event": "chain_end", "outputs": outputs})

    def on_llm_start(self, serialized, prompts, **kwargs):
        self._write({"event": "llm_start", "prompts": prompts})

    def on_llm_end(self, response, **kwargs):
        self._write({"event": "llm_end", "response": str(response)})

    def on_retriever_end(self, documents, **kwargs):
        # Log only a 200-character preview of each retrieved document.
        preview = [doc.page_content[:200] if isinstance(doc, Document) else str(doc) for doc in documents]
        self._write({"event": "retriever_end", "documents": preview})


def get_model(model_name: str, hf_token: str = None, callbacks: list = None) -> BaseLanguageModel:
    """Return a model instance based on the model name.

    Args:
        model_name: Name of the model to use.
        hf_token: Hugging Face API token (required for flan-t5-base).
        callbacks: List of callbacks to attach to the model.
    """
    if model_name == "flan-t5-base":
        if not hf_token:
            raise ValueError(
                "Hugging Face API token is required for flan-t5-base. "
                "Please provide your Hugging Face token in the UI or use a local model."
            )
        return HuggingFaceHub(
            repo_id="google/flan-t5-base",
            huggingfacehub_api_token=hf_token,
            model_kwargs={"temperature": 0.1, "max_length": 512},
            callbacks=callbacks,
        )
    # Any other name is treated as a model served by a local Ollama instance.
    return ChatOllama(model=model_name, callbacks=callbacks)
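
# Example calls (hedged: "phi3" must already be pulled into a running Ollama
# server, and "hf_..." is a placeholder token):
#
#   llm = get_model("phi3")                             # local Ollama chat model
#   llm = get_model("flan-t5-base", hf_token="hf_...")  # hosted HF model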


def build_chain(store, model_name: str = "phi3", hf_token: str = None):
    """Return a RetrievalQAWithSourcesChain that uses the specified model.

    Args:
        store: Vector store with document embeddings.
        model_name: Name of the model to use.
        hf_token: Hugging Face API token (required for flan-t5-base).
    """
    callback = JSONLCallbackHandler()
    llm = get_model(model_name, hf_token, [callback])
    return RetrievalQAWithSourcesChain.from_chain_type(
        llm=llm,
        # `k` must be passed via search_kwargs; as_retriever() silently ignores
        # an unknown bare `k=4` keyword argument.
        retriever=store.as_retriever(search_kwargs={"k": 4}),
        return_source_documents=True,
        verbose=True,
        # Attaching the handler to the chain lets its events propagate to the
        # retriever and LLM child runs as well.
        callbacks=[callback],
    )
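

if __name__ == "__main__":
    # Minimal end-to-end sketch, assuming ./data holds a prebuilt index and a
    # local Ollama server has the phi3 model available; the question is only
    # an illustration.
    store, segments = load_index("data")
    chain = build_chain(store, model_name="phi3")
    result = chain.invoke({"question": "What topics are covered in the clip?"})
    print(result["answer"])
    for doc in result.get("source_documents", []):
        print("-", doc.page_content[:100])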