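"""FastAPI service exposing semantic search (Qdrant + sentence-transformers
embeddings) and extractive question answering (DistilBERT SQuAD) over local
documents served from ./TestFolder."""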
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from langchain_huggingface import HuggingFaceEmbeddings
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
from qdrant_client import QdrantClient
from langchain_qdrant import QdrantVectorStore
import os
import time
import torch
from pydantic import BaseModel

# Globals populated on startup and released on shutdown.
model = None
tokenizer = None
embed_model = None
qdrant = None
qa_pipeline = None


class Item(BaseModel):
    query: str


app = FastAPI()

# Serve the source documents so search results can link back to them.
app.mount("/TestFolder", StaticFiles(directory="./TestFolder"), name="TestFolder")

# Cache, offload, and model directories used by the model loaders.
os.makedirs("./cache", exist_ok=True)
os.makedirs("./offload", exist_ok=True)
os.makedirs("./models", exist_ok=True)


@app.on_event("startup")
async def startup_event():
    global model, tokenizer, embed_model, qdrant, qa_pipeline
    print("Loading models....")
    sentence_embedding_model_path = "sentence-transformers/paraphrase-MiniLM-L6-v2"
    start_time = time.perf_counter()

    # Sentence-transformers embeddings used for similarity search in Qdrant.
    embed_model = HuggingFaceEmbeddings(
        model_name=sentence_embedding_model_path,
        cache_folder="./models",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": True},
    )

    # Open the local, file-based Qdrant store and attach the existing
    # "MyCollection" collection for dense retrieval.
    try:
        qdrant_client = QdrantClient(path="qdrant/")
        qdrant = QdrantVectorStore(qdrant_client, "MyCollection", embed_model, distance="Dot")
    except Exception as e:
        print(f"Error initializing Qdrant: {e}")

    # Extractive question-answering model (DistilBERT fine-tuned on SQuAD).
    model_path = "distilbert-base-cased-distilled-squad"
    model = AutoModelForQuestionAnswering.from_pretrained(model_path, cache_dir="./models")
    tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir="./models")
    qa_pipeline = pipeline(
        "question-answering",
        model=model,
        tokenizer=tokenizer,
        device=0 if torch.cuda.is_available() else -1,  # GPU if available, else CPU
    )

    end_time = time.perf_counter()
    print(f"Models loaded successfully in {end_time - start_time:.2f} seconds.")


@app.on_event("shutdown")
async def shutdown_event():
    global model, tokenizer, embed_model, qdrant, qa_pipeline
    print("Shutting down the API and releasing model memory.")
    del model, tokenizer, embed_model, qdrant, qa_pipeline
@app.get("/")
def read_root():
return {"message": "Welcome to FastAPI"}
@app.post("/search")
def search(Item:Item):
print("Search endpoint")
query = Item.query
search_result = qdrant.similarity_search(
query=query, k=10
)
i = 0
list_res = []
for res in search_result:
list_res.append({"id":i,"path":res.metadata.get("path"),"content":res.page_content})
return list_res
@app.post("/ask_localai")
async def ask_localai(item: Item):
query = item.query
search_result = qdrant.similarity_search(query=query, k=3)
if not search_result:
return {"error": "No relevant results found for the query."}
context = " ".join([res.page_content for res in search_result])
if not context.strip():
return {"error": "No relevant context found."}
try:
prompt = (
f"Context: {context}\n\n"
f"Question: {query}\n"
f"Answer concisely and only based on the context provided. Do not repeat the context or the question.\n"
f"Answer:"
)
qa_result = qa_pipeline(question=query, context=context)
answer = qa_result["answer"]
return {
"question": query,
"answer": answer
}
except Exception as e:
return {"error": "Failed to generate an answer."}
@app.get("/items/{item_id}")
def read_item(item_id: int, q: str = None):
return {"item_id": item_id, "q": q}
@app.post("/items/")
def create_item(item: Item):
return {"item": item, "total_price": item.price + (item.tax or 0)}
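

# A minimal sketch of how to run this module, assuming it is saved as main.py
# and uvicorn is installed (an assumption; any ASGI server would do):
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
# Example request against the /ask_localai endpoint:
#
#   curl -X POST http://localhost:8000/ask_localai \
#        -H "Content-Type: application/json" \
#        -d '{"query": "What does the test folder describe?"}'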