from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from langchain_huggingface import HuggingFaceEmbeddings
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
from qdrant_client import QdrantClient
from langchain_qdrant import QdrantVectorStore
import os
from pydantic import BaseModel
import time
import torch

# Globals populated once at startup and shared across request handlers
model = None
tokenizer = None
embed_model = None
qdrant = None
qa_pipeline = None

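# Request body schema shared by the search and QA endpoints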
class Item(BaseModel):
    query: str

app = FastAPI()

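# Static file hosting and local working directories for the model cache/offload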
app.mount("/TestFolder", StaticFiles(directory="./TestFolder"), name="TestFolder")
os.makedirs("./cache", exist_ok=True)
os.makedirs("./offload", exist_ok=True)
os.makedirs("./models", exist_ok=True)

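# Load the embedding model, Qdrant vector store, and extractive QA pipeline once at startup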
@app.on_event("startup")
async def startup_event():
    global model, tokenizer, embed_model, qdrant, qa_pipeline

    print("πŸš€ Loading models...")
    
    sentence_embedding_model_path = "sentence-transformers/paraphrase-MiniLM-L6-v2"
    start_time = time.perf_counter()
    
    embed_model = HuggingFaceEmbeddings(
        model_name=sentence_embedding_model_path,
        cache_folder="./models",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": True},
    )

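    # Open the on-disk Qdrant store and wrap it in a LangChain vector store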
    try:
        qdrant_client = QdrantClient(path="qdrant/")
        qdrant = QdrantVectorStore(qdrant_client, "MyCollection", embed_model, distance="Dot")
    except Exception as e:
        print(f"❌ Error initializing Qdrant: {e}")

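    # Extractive question-answering model (DistilBERT fine-tuned on SQuAD) plus its tokenizer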
    model_path = "distilbert-base-cased-distilled-squad"
    model = AutoModelForQuestionAnswering.from_pretrained(model_path, cache_dir="./models")
    tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir="./models")
    qa_pipeline = pipeline(
        "question-answering",
        model=model,
        tokenizer=tokenizer,
        device=0 if torch.cuda.is_available() else -1
    )

    end_time = time.perf_counter()
    print(f"βœ… Models loaded successfully in {end_time - start_time:.2f} seconds.")

@app.on_event("shutdown")
async def shutdown_event():
    global model, tokenizer, embed_model, qdrant, qa_pipeline
    print("πŸšͺ Shutting down the API and releasing model memory.")
    del model, tokenizer, embed_model, qdrant, qa_pipeline


@app.get("/")
def read_root():
    return {"message": "Welcome to FastAPI"}

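# Vector search endpoint: return the top matching document chunks for a query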
@app.post("/search")
def search(item: Item):
    print("Search endpoint")
    query = item.query

    # Retrieve the 10 most similar chunks from the vector store
    search_result = qdrant.similarity_search(query=query, k=10)

    list_res = []
    for i, res in enumerate(search_result):
        list_res.append({
            "id": i,
            "path": res.metadata.get("path"),
            "content": res.page_content,
        })

    return list_res

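# Retrieval-augmented QA endpoint: fetch the closest chunks and extract an answer span from them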
@app.post("/ask_localai")
async def ask_localai(item: Item):
    query = item.query
    
    search_result = qdrant.similarity_search(query=query, k=3)
    if not search_result:
        return {"error": "No relevant results found for the query."}

    context = " ".join([res.page_content for res in search_result])
    if not context.strip():
        return {"error": "No relevant context found."}

    try:
        # Extractive QA: the model selects an answer span directly from the retrieved context
        qa_result = qa_pipeline(question=query, context=context)
        answer = qa_result["answer"]

        return {
            "question": query,
            "answer": answer
        }
    except Exception as e:
        print(f"❌ QA pipeline error: {e}")
        return {"error": "Failed to generate an answer."}

@app.get("/items/{item_id}")
def read_item(item_id: int, q: str = None):
    return {"item_id": item_id, "q": q}

@app.post("/items/")
def create_item(item: Item):
    # Item only defines `query`, so echo the received payload back
    return {"item": item}
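
# Minimal local-run sketch; the host, port, and example request below are assumptions
# for illustration and are not part of the original service configuration.
if __name__ == "__main__":
    import uvicorn

    # Launch the API with uvicorn (assumes uvicorn is installed alongside fastapi)
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request once the server is running (hypothetical query text):
#   curl -X POST http://localhost:8000/ask_localai \
#        -H "Content-Type: application/json" \
#        -d '{"query": "What does the document describe?"}'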