# NOTE: stray "Spaces: Sleeping" lines removed here — residue from the
# Hugging Face Spaces status page, not part of the source.
import os
import pickle
import time

import faiss
import numpy as np
import streamlit as st
from datasets import load_dataset
from dotenv import load_dotenv
from langchain.retrievers import EnsembleRetriever
from langchain.schema import Document
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from sentence_transformers import CrossEncoder
load_dotenv()  # read OPENAI_API_KEY (and friends) from a local .env file
def get_vector_store():
    """Load a FAISS vectorstore from pre-computed embeddings on disk.

    Reads ``src/medical_embeddings.npy`` (the embedding matrix) and
    ``src/medical_texts.pkl`` (the parallel list of text chunks), builds
    an exact L2 FAISS index over the stored vectors, and wraps it in a
    LangChain FAISS vectorstore backed by an in-memory docstore.

    Returns:
        tuple: ``(vectorstore, documents)`` on success, or
        ``(None, None)`` when the pre-computed files are missing or
        loading fails, so the caller can fall back to building
        embeddings from scratch.
    """
    embeddings_path = "src/medical_embeddings.npy"
    texts_path = "src/medical_texts.pkl"
    try:
        # Fail fast with a clear message if either artifact is absent.
        for path in (embeddings_path, texts_path):
            if not os.path.exists(path):
                raise FileNotFoundError(f"{os.path.basename(path)} not found")

        print("Loading pre-computed embeddings...")
        embeddings_array = np.load(embeddings_path)
        with open(texts_path, "rb") as f:
            texts = pickle.load(f)
        print(f"Loaded {len(embeddings_array)} pre-computed embeddings")

        # Build an exact (brute-force) L2 index from the stored vectors.
        dimension = embeddings_array.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings_array.astype("float32"))  # type: ignore

        # Embedding function is used only for NEW queries at search time;
        # the stored vectors were pre-computed with the same model offline.
        embeddings_function = HuggingFaceEmbeddings(
            model_name="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
        )

        # Wrap each text in a Document and register it in an in-memory
        # docstore keyed by its positional id (stringified index).
        documents = [
            Document(page_content=text, metadata={"doc_id": i, "type": "medical_qa"})
            for i, text in enumerate(texts)
        ]
        docstore = InMemoryDocstore({str(i): doc for i, doc in enumerate(documents)})
        index_to_docstore_id = {i: str(i) for i in range(len(documents))}

        vectorstore = FAISS(
            embedding_function=embeddings_function,
            index=index,
            docstore=docstore,
            index_to_docstore_id=index_to_docstore_id,
        )
        return vectorstore, documents
    except FileNotFoundError as e:
        print(f"Pre-computed files not found: {e}")
        print("Falling back to creating embeddings...")
        return None, None
    except Exception as e:
        # Broad catch is deliberate: any load failure should trigger the
        # caller's fallback path rather than crash the Streamlit app.
        print(f"Error loading pre-computed embeddings: {e}")
        print("Falling back to creating embeddings...")
        return None, None
def load_medical_system():
    """Build the medical RAG pipeline: retrievers, LLM, and reranker.

    Loads the pre-computed FAISS vectorstore, combines a BM25 (lexical)
    retriever with the dense vector retriever in a weighted ensemble,
    creates the OpenAI chat model, and loads a cross-encoder for
    reranking. Stops the Streamlit app with an error if the vectorstore
    or the OpenAI API key is unavailable.

    Returns:
        tuple: ``(documents, ensemble_retriever, llm, reranker)``
    """
    with st.spinner("Loading medical knowledge base..."):
        # Time the vectorstore load so the UI can report it.
        start = time.time()
        vectorstore, documents = get_vector_store()
        elapsed = time.time() - start

        if vectorstore is None or documents is None:
            st.error(
                "Could not load the vectorstore. Please ensure the "
                "embeddings and text files exist."
            )
            st.stop()
        st.success(f"Loaded existing vectorstore in {elapsed:.2f} seconds")

        # Hybrid retrieval: lexical BM25 plus dense vector search,
        # weighted toward the dense retriever (0.3 / 0.7).
        bm25_retriever = BM25Retriever.from_documents(documents)
        vector_retriever = vectorstore.as_retriever(search_kwargs={"k": 2})
        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, vector_retriever],
            weights=[0.3, 0.7],
        )

        # Resolve the OpenAI key from the environment first, then from
        # Streamlit secrets — matching the guidance in the error message.
        openai_key = os.getenv("OPENAI_API_KEY")
        if not openai_key:
            try:
                openai_key = st.secrets["OPENAI_API_KEY"]
            except Exception:
                # No secrets file / key absent — handled below.
                openai_key = None
        if not openai_key:
            st.error(
                "OpenAI API key not found! Please set it in your environment "
                "variables or .streamlit/secrets.toml"
            )
            st.stop()
        llm = ChatOpenAI(temperature=0, api_key=openai_key)  # type: ignore

        # Cross-encoder reranker scores (query, passage) pairs jointly.
        reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
        return documents, ensemble_retriever, llm, reranker