import os
# Set a writable directory for Hugging Face cache and environment variables
hf_cache_dir = "/tmp/huggingface_cache"
os.environ["HF_HOME"] = hf_cache_dir
os.environ["TRANSFORMERS_CACHE"] = os.path.join(hf_cache_dir, "transformers")
os.makedirs(hf_cache_dir, exist_ok=True)
os.makedirs(os.environ["TRANSFORMERS_CACHE"], exist_ok=True)
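# NOTE: HF_HOME/TRANSFORMERS_CACHE are read when transformers and
# huggingface_hub are first imported, so they must be set before the
# imports below.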
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from langchain_huggingface import HuggingFaceEmbeddings
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
import torch
from qdrant_client import QdrantClient
from langchain_qdrant import QdrantVectorStore
from pydantic import BaseModel
import time
# Global variables
model = None
tokenizer = None
qa_pipeline = None
embed_model = None
qdrant = None
class Item(BaseModel):
    query: str


app = FastAPI()

# Mount static files from TestFolder
app.mount("/files", StaticFiles(directory="TestFolder"), name="files")
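# NOTE: StaticFiles checks the directory at construction time (check_dir
# defaults to True), so TestFolder must exist alongside this file or the
# app will fail to start.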
@app.on_event("startup")
async def startup_event():
    global model, tokenizer, qa_pipeline, embed_model, qdrant
    print("🚀 Loading models...")
    start_time = time.perf_counter()

    # Load embedding model
    sentence_embedding_model_path = "sentence-transformers/paraphrase-MiniLM-L6-v2"
    embed_model = HuggingFaceEmbeddings(
        model_name=sentence_embedding_model_path,
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": True},
        cache_folder=hf_cache_dir,
    )
    # Initialize Qdrant
    try:
        qdrant_client = QdrantClient(path="qdrant/")
        qdrant = QdrantVectorStore(
            client=qdrant_client,
            collection_name="MyCollection",
            embedding=embed_model,
            distance="Dot",
        )
    except Exception as e:
        print(f"❌ Error initializing Qdrant: {e}")
    # Load QA model
    model_path = "distilbert-base-cased-distilled-squad"
    model = AutoModelForQuestionAnswering.from_pretrained(model_path, cache_dir=hf_cache_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir=hf_cache_dir)
    qa_pipeline = pipeline(
        "question-answering",
        model=model,
        tokenizer=tokenizer,
        device=0 if torch.cuda.is_available() else -1,
    )

    end_time = time.perf_counter()
    print(f"✅ Models loaded successfully in {end_time - start_time:.2f} seconds.")
@app.on_event("shutdown")
async def shutdown_event():
    global model, tokenizer, qa_pipeline, embed_model, qdrant
    print("🚪 Shutting down the API and releasing model memory.")
    del model, tokenizer, qa_pipeline, embed_model, qdrant
@app.get("/")
def read_root():
    return {"message": "Welcome to FastAPI"}
@app.post("/search")
def search(item: Item):
    print("Search endpoint")
    query = item.query
    search_result = qdrant.similarity_search(query=query, k=10)
    list_res = [
        {"id": i, "path": res.metadata.get("path"), "content": res.page_content}
        for i, res in enumerate(search_result)
    ]
    return list_res
@app.post("/ask_localai")
async def ask_localai(item: Item):
    query = item.query
    try:
        # First, get relevant documents
        docs = qdrant.similarity_search(query, k=3)
        # Combine the documents into a single context
        context = " ".join([doc.page_content for doc in docs])
        # Use the QA pipeline directly; doc_stride makes the pipeline split
        # contexts longer than max_seq_len into overlapping chunks, so long
        # inputs are handled.
        answer = qa_pipeline(
            question=query,
            context=context,
            max_seq_len=512,
            max_answer_len=50,
            doc_stride=128,
        )
        return {
            "question": query,
            "answer": answer["answer"],
            "confidence": answer["score"],
            "source_documents": [
                {
                    "content": doc.page_content[:1000],
                    "metadata": doc.metadata,
                }
                for doc in docs
            ],
        }
    except Exception as e:
        return {"error": str(e)}