|
import os, json |
|
from langchain_community.vectorstores import FAISS |
|
from langchain_community.embeddings import HuggingFaceEmbeddings |
|
from langchain.chains import RetrievalQAWithSourcesChain |
|
from langchain_community.chat_models import ChatOllama |
|
from langchain_community.llms import HuggingFaceHub |
|
from langchain.callbacks.base import BaseCallbackHandler |
|
from langchain_core.language_models.base import BaseLanguageModel |
|
import logging |
|
from langchain.globals import set_debug |
|
|
|
|
|
# Switch on LangChain's global debug mode for the whole process.
set_debug(True)

# Route records from the "langchain" logger into a local debug file.
_lc_logger = logging.getLogger("langchain")

# Only attach the file handler once, so re-running this module-level code
# (e.g. on a reload) does not stack duplicate handlers on the same logger.
_debug_file_attached = any(
    isinstance(handler, logging.FileHandler)
    and getattr(handler, "baseFilename", "").endswith("langchain_debug.jsonl")
    for handler in _lc_logger.handlers
)
if not _debug_file_attached:
    _fh = logging.FileHandler("langchain_debug.jsonl", mode="a", encoding="utf-8")
    # Emit the bare message only: no timestamp / level prefix.
    _fh.setFormatter(logging.Formatter("%(message)s"))
    _lc_logger.addHandler(_fh)
_lc_logger.setLevel(logging.DEBUG)

# Hugging Face model name handed to HuggingFaceEmbeddings in load_index().
# NOTE(review): presumably this must match the model the index was built
# with -- confirm against the indexing script.
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
|
|
|
|
|
def load_index(index_dir: str = "data"):
    """Load the FAISS vector store and the segment list stored in ``index_dir``.

    Args:
        index_dir: Directory containing an index loadable by
            ``FAISS.load_local`` plus a ``segments.json`` file.

    Returns:
        A ``(store, segments)`` tuple: the loaded FAISS store and the parsed
        content of ``segments.json``.

    Raises:
        FileNotFoundError: If ``segments.json`` does not exist in ``index_dir``.
        json.JSONDecodeError: If ``segments.json`` is not valid JSON.
    """
    embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
    # SECURITY: allow_dangerous_deserialization=True permits unpickling of the
    # stored index data. Only point index_dir at data this application
    # produced itself; never at untrusted input.
    store = FAISS.load_local(index_dir, embeddings, allow_dangerous_deserialization=True)
    # Read as UTF-8 explicitly so decoding does not depend on the platform's
    # locale default encoding (which is not UTF-8 everywhere).
    with open(os.path.join(index_dir, "segments.json"), encoding="utf-8") as f:
        segments = json.load(f)
    return store, segments
|
|
|
|
|
class JSONLCallbackHandler(BaseCallbackHandler):
    """Write simple LangChain events to a JSONL file so UI can display them.

    Every event is appended as one JSON object per line. Each record carries
    an ``event`` name, event-specific fields, and a ``ts`` field holding the
    ``time.time()`` value at write time.

    NOTE(review): the default path is the same file name the module-level
    logging setup appends to; creating a handler truncates that file.
    """

    def __init__(self, path: str = "langchain_debug.jsonl"):
        """Remember ``path`` and truncate the file so the log starts empty.

        Args:
            path: Location of the JSONL file to (re)create and append to.
        """
        super().__init__()
        self.path = path
        # Opening in "w" mode creates the file or empties an existing one.
        with open(self.path, "w", encoding="utf-8"):
            pass

    def _write(self, record):
        """Stamp ``record`` with the current time and append it as one JSON line.

        Args:
            record: Dict describing the event; it is mutated in place to add
                the ``ts`` key.
        """
        import json
        import time

        record["ts"] = time.time()
        # Chain inputs/outputs may hold objects json cannot encode (for
        # example Document instances when source documents are returned).
        # default=str stringifies those instead of raising TypeError, which
        # would otherwise drop the whole event.
        line = json.dumps(record, default=str)
        with open(self.path, "a", encoding="utf-8") as f:
            f.write(line + "\n")

    def on_chain_start(self, serialized, inputs, **kwargs):
        """Log a ``chain_start`` event with the chain name and its inputs."""
        # LangChain may pass serialized=None for some runnables; fall back to
        # an empty mapping so the name is logged as null rather than crashing.
        name = (serialized or {}).get("name")
        self._write({"event": "chain_start", "name": name, "inputs": inputs})

    def on_chain_end(self, outputs, **kwargs):
        """Log a ``chain_end`` event with the chain outputs."""
        self._write({"event": "chain_end", "outputs": outputs})

    def on_llm_start(self, serialized, prompts, **kwargs):
        """Log an ``llm_start`` event with the prompts sent to the model."""
        self._write({"event": "llm_start", "prompts": prompts})

    def on_llm_end(self, response, **kwargs):
        """Log an ``llm_end`` event with the string form of the response."""
        self._write({"event": "llm_end", "response": str(response)})

    def on_retriever_end(self, documents, **kwargs):
        """Log a ``retriever_end`` event with a short preview of each document."""
        from langchain.docstore.document import Document

        # Keep the log compact: only the first 200 characters of each
        # Document's text; anything else is logged via str().
        preview = [
            doc.page_content[:200] if isinstance(doc, Document) else str(doc)
            for doc in documents
        ]
        self._write({"event": "retriever_end", "documents": preview})
|
|
|
|
|
def get_model(model_name: str, hf_token: str = None, callbacks: list = None) -> BaseLanguageModel:
    """Build the language model selected by ``model_name``.

    ``"flan-t5-base"`` is created through ``HuggingFaceHub`` and needs an API
    token; every other name is passed unchanged to ``ChatOllama``.

    Args:
        model_name: Name of the model to use.
        hf_token: Hugging Face API token (required for flan-t5-base).
        callbacks: List of callbacks handed to the model constructor.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: If ``model_name`` is ``"flan-t5-base"`` and ``hf_token``
            is missing or empty.
    """
    # Any name other than the hosted flan-t5 model goes to ChatOllama.
    if model_name != "flan-t5-base":
        return ChatOllama(model=model_name, callbacks=callbacks)

    # The hosted model cannot be created without a token.
    if hf_token:
        generation_settings = {"temperature": 0.1, "max_length": 512}
        return HuggingFaceHub(
            repo_id="google/flan-t5-base",
            huggingfacehub_api_token=hf_token,
            model_kwargs=generation_settings,
            callbacks=callbacks,
        )

    raise ValueError(
        "Hugging Face API token is required for flan-t5-base. "
        "Please provide your Hugging Face token in the UI or use a local model."
    )
|
|
|
|
|
def build_chain(store, model_name: str = "phi3", hf_token: str = None):
    """Return a RetrievalQA-with-sources chain using the specified model.

    A fresh ``JSONLCallbackHandler`` is created on every call (which resets
    its log file) and is attached to both the model and the retriever.

    Args:
        store: Vector store with document embeddings; must provide
            ``as_retriever``.
        model_name: Name of the model to use.
        hf_token: Hugging Face API token (required for flan-t5-base).

    Returns:
        A ``RetrievalQAWithSourcesChain`` that also returns its source
        documents and runs in verbose mode.

    Raises:
        ValueError: Propagated from ``get_model`` when flan-t5-base is
            requested without a token.
    """
    callback = JSONLCallbackHandler()
    llm = get_model(model_name, hf_token, [callback])
    # The number of documents to fetch is a search parameter and has to be
    # passed through search_kwargs; a bare ``k=4`` keyword is not a search
    # argument of as_retriever().
    retriever = store.as_retriever(search_kwargs={"k": 4}, callbacks=[callback])
    return RetrievalQAWithSourcesChain.from_chain_type(
        llm=llm,
        retriever=retriever,
        return_source_documents=True,
        verbose=True,
    )
|
|