# api-space/api.py
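"""Minimal retrieval-augmented question-answering API.

Embeds documents with a sentence-transformers model, retrieves matching
chunks from a local Qdrant store, and answers questions against them with an
extractive DistilBERT QA pipeline.
"""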
import os
import time

import torch
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_qdrant import QdrantVectorStore
from pydantic import BaseModel
from qdrant_client import QdrantClient
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
# Populated at startup; kept global so the request handlers can reach them.
model = None
tokenizer = None
embed_model = None
qdrant = None
qa_pipeline = None
class Item(BaseModel):
    query: str


app = FastAPI()
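# Expose the document folder over HTTP, presumably so search hits can link
# back to their source files (assumes ./TestFolder exists in the image).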
app.mount("/TestFolder", StaticFiles(directory="./TestFolder"), name="TestFolder")
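# Working directories: ./models is used as the HF cache below; ./cache and
# ./offload are created here but not referenced elsewhere in this file.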
os.makedirs("./cache", exist_ok=True)
os.makedirs("./offload", exist_ok=True)
os.makedirs("./models", exist_ok=True)

@app.on_event("startup")
async def startup_event():
    global model, tokenizer, embed_model, qdrant, qa_pipeline
    print("🚀 Loading models....")
    start_time = time.perf_counter()

    # Sentence-embedding model used for retrieval.
    sentence_embedding_model_path = "sentence-transformers/paraphrase-MiniLM-L6-v2"
    embed_model = HuggingFaceEmbeddings(
        model_name=sentence_embedding_model_path,
        cache_folder="./models",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": True},
    )

    # Local (embedded) Qdrant store; "MyCollection" is assumed to have been
    # created and populated ahead of time.
    try:
        qdrant_client = QdrantClient(path="qdrant/")
        qdrant = QdrantVectorStore(qdrant_client, "MyCollection", embed_model, distance="Dot")
    except Exception as e:
        print(f"❌ Error initializing Qdrant: {e}")

    # Extractive question-answering model.
    model_path = "distilbert-base-cased-distilled-squad"
    model = AutoModelForQuestionAnswering.from_pretrained(model_path, cache_dir="./models")
    tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir="./models")
    qa_pipeline = pipeline(
        "question-answering",
        model=model,
        tokenizer=tokenizer,
        device=0 if torch.cuda.is_available() else -1,
    )

    end_time = time.perf_counter()
    print(f"✅ Models loaded successfully in {end_time - start_time:.2f} seconds.")

@app.on_event("shutdown")
async def shutdown_event():
    global model, tokenizer, embed_model, qdrant, qa_pipeline
    print("🚪 Shutting down the API and releasing model memory.")
    del model, tokenizer, embed_model, qdrant, qa_pipeline

@app.get("/")
def read_root():
    return {"message": "Welcome to FastAPI"}

@app.post("/search")
def search(item: Item):
    print("Search endpoint")
    query = item.query
    search_result = qdrant.similarity_search(query=query, k=10)
    # Number results sequentially; "path" comes from the metadata stored at
    # indexing time.
    list_res = [
        {"id": i, "path": res.metadata.get("path"), "content": res.page_content}
        for i, res in enumerate(search_result)
    ]
    return list_res
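
# Example request (assuming the default uvicorn port; the returned paths and
# content depend entirely on what was indexed into the collection):
#
#     curl -X POST http://localhost:8000/search \
#          -H "Content-Type: application/json" \
#          -d '{"query": "what is in the test folder?"}'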

@app.post("/ask_localai")
async def ask_localai(item: Item):
    query = item.query

    # Retrieve the top matches and concatenate them into a single context.
    search_result = qdrant.similarity_search(query=query, k=3)
    if not search_result:
        return {"error": "No relevant results found for the query."}
    context = " ".join([res.page_content for res in search_result])
    if not context.strip():
        return {"error": "No relevant context found."}

    # The QA pipeline is extractive: it takes the question and context
    # directly and pulls the answer span out of the context, so no prompt
    # template is needed.
    try:
        qa_result = qa_pipeline(question=query, context=context)
        return {
            "question": query,
            "answer": qa_result["answer"],
        }
    except Exception as e:
        print(f"❌ Error during question answering: {e}")
        return {"error": "Failed to generate an answer."}

@app.get("/items/{item_id}")
def read_item(item_id: int, q: str | None = None):
    return {"item_id": item_id, "q": q}

@app.post("/items/")
def create_item(item: Item):
    # Item only defines `query`, so just echo the payload back.
    return {"item": item}