File size: 4,152 Bytes
45b9636
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os, json
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQAWithSourcesChain
from langchain_community.chat_models import ChatOllama
from langchain_community.llms import HuggingFaceHub
from langchain.callbacks.base import BaseCallbackHandler
from langchain_core.language_models.base import BaseLanguageModel
import logging
from langchain.globals import set_debug

# Enable verbose LangChain logging and persist it to disk for analysis.
# NOTE: this deliberately writes to a separate ``.log`` file rather than
# "langchain_debug.jsonl" — that file is truncated and filled with JSON event
# records by JSONLCallbackHandler, while these logging records are plain text.
# Sharing one file corrupted the JSONL stream for any consumer parsing it.
set_debug(True)
_LC_DEBUG_LOG = "langchain_debug.log"
_lc_logger = logging.getLogger("langchain")
# Guard against adding a duplicate handler when this module is re-imported
# (baseFilename is the handler's absolute path, hence the endswith check).
if not any(
    isinstance(h, logging.FileHandler) and getattr(h, "baseFilename", "").endswith(_LC_DEBUG_LOG)
    for h in _lc_logger.handlers
):
    _fh = logging.FileHandler(_LC_DEBUG_LOG, mode="a", encoding="utf-8")
    _fh.setFormatter(logging.Formatter("%(message)s"))
    _lc_logger.addHandler(_fh)
    _lc_logger.setLevel(logging.DEBUG)

EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"


def load_index(index_dir: str = "data"):
    """Load the FAISS vector store and segment metadata from *index_dir*.

    Args:
        index_dir: Directory containing the serialized FAISS index and the
            ``segments.json`` metadata file.

    Returns:
        Tuple ``(store, segments)``: the loaded FAISS vector store and the
        parsed contents of ``segments.json``.
    """
    embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
    # allow_dangerous_deserialization: the index is produced locally by this
    # project, so unpickling it is trusted input.
    store = FAISS.load_local(index_dir, embeddings, allow_dangerous_deserialization=True)
    # Explicit encoding: don't depend on the platform's locale default when
    # reading a JSON file.
    with open(os.path.join(index_dir, "segments.json"), encoding="utf-8") as f:
        segments = json.load(f)
    return store, segments


class JSONLCallbackHandler(BaseCallbackHandler):
    """Write simple LangChain events to a JSONL file so the UI can display them.

    Each callback appends one JSON object per line to ``self.path``.  The file
    is truncated on construction so every chain run starts with a fresh log.
    """

    def __init__(self, path: str = "langchain_debug.jsonl"):
        self.path = path
        # Clear previous logs so the UI only sees events from the current run.
        open(self.path, "w").close()

    def _write(self, record: dict) -> None:
        """Append *record*, stamped with the current time, as one JSON line."""
        import time
        record["ts"] = time.time()
        with open(self.path, "a", encoding="utf-8") as f:
            # default=str: callback payloads routinely contain objects that are
            # not JSON-serializable (messages, documents). Stringify them
            # instead of raising — an exception here would abort the chain run.
            f.write(json.dumps(record, default=str) + "\n")

    def on_chain_start(self, serialized, inputs, **kwargs):
        # ``serialized`` can be None in some LangChain versions; guard the .get.
        name = (serialized or {}).get("name")
        self._write({"event": "chain_start", "name": name, "inputs": inputs})

    def on_chain_end(self, outputs, **kwargs):
        self._write({"event": "chain_end", "outputs": outputs})

    def on_llm_start(self, serialized, prompts, **kwargs):
        self._write({"event": "llm_start", "prompts": prompts})

    def on_llm_end(self, response, **kwargs):
        self._write({"event": "llm_end", "response": str(response)})

    def on_retriever_end(self, documents, **kwargs):
        # Duck-type on ``page_content`` rather than isinstance(Document): the
        # retriever may yield Document variants from langchain_core, and all
        # we need for a preview is the text. Truncate to keep the log small.
        preview = [str(getattr(doc, "page_content", doc))[:200] for doc in documents]
        self._write({"event": "retriever_end", "documents": preview})


def get_model(model_name: str, hf_token: str = None, callbacks: list = None) -> BaseLanguageModel:
    """Build the language model backing the QA chain.

    Args:
        model_name: Name of the model to use.
        hf_token: Hugging Face API token (only needed for ``flan-t5-base``).
        callbacks: Callback handlers to attach to the model.

    Returns:
        A hosted ``HuggingFaceHub`` model when *model_name* is
        ``"flan-t5-base"``; otherwise a local ``ChatOllama`` model.

    Raises:
        ValueError: If ``flan-t5-base`` is requested without *hf_token*.
    """
    if model_name != "flan-t5-base":
        # Everything except flan-t5-base is served by a local Ollama model.
        return ChatOllama(model=model_name, callbacks=callbacks)

    if not hf_token:
        raise ValueError(
            "Hugging Face API token is required for flan-t5-base. "
            "Please provide your Hugging Face token in the UI or use a local model."
        )
    return HuggingFaceHub(
        repo_id="google/flan-t5-base",
        huggingfacehub_api_token=hf_token,
        model_kwargs={"temperature": 0.1, "max_length": 512},
        callbacks=callbacks,
    )


def build_chain(store, model_name: str = "phi3", hf_token: str = None):
    """Return a RetrievalQAWithSourcesChain over *store* using the given model.

    Args:
        store: Vector store with document embeddings.
        model_name: Name of the model to use.
        hf_token: Hugging Face API token (required for flan-t5-base).

    Returns:
        A ``RetrievalQAWithSourcesChain`` whose LLM and retriever log their
        events to ``langchain_debug.jsonl`` via ``JSONLCallbackHandler``.
    """
    callback = JSONLCallbackHandler()
    llm = get_model(model_name, hf_token, [callback])
    return RetrievalQAWithSourcesChain.from_chain_type(
        llm=llm,
        # ``k`` must be nested inside ``search_kwargs``: as_retriever forwards
        # its kwargs to the VectorStoreRetriever model, which has no top-level
        # ``k`` field, so the original ``k=4`` did not limit the result count.
        retriever=store.as_retriever(search_kwargs={"k": 4}, callbacks=[callback]),
        return_source_documents=True,
        verbose=True,
    )