test
api.py
CHANGED
@@ -10,40 +10,36 @@ os.makedirs(os.environ["TRANSFORMERS_CACHE"], exist_ok=True)
 from fastapi import FastAPI
 from fastapi.staticfiles import StaticFiles
 from langchain_huggingface import HuggingFaceEmbeddings
-from transformers import
+from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
+import torch
 from langchain_community.llms import HuggingFacePipeline
 from qdrant_client import QdrantClient
 from langchain_qdrant import QdrantVectorStore
 from pydantic import BaseModel
 from langchain.chains import RetrievalQA
-from langchain.schema import Document
 import time
-import torch
 
+# Global variables
 model = None
 tokenizer = None
-
+qa_pipeline = None
 embed_model = None
 qdrant = None
-model_name_hf = None
-text_generation_pipeline = None
-qa_pipeline = None
 
 class Item(BaseModel):
     query: str
 
 app = FastAPI()
-# app.mount("/TestFolder", StaticFiles(directory="./TestFolder"), name="TestFolder")
 
 @app.on_event("startup")
 async def startup_event():
-    global model, tokenizer,
-
-    print("🚀 Loading model....")
+    global model, tokenizer, qa_pipeline, embed_model, qdrant
 
-
+    print("🚀 Loading models....")
     start_time = time.perf_counter()
 
+    # Load embedding model
+    sentence_embedding_model_path = "sentence-transformers/paraphrase-MiniLM-L6-v2"
     embed_model = HuggingFaceEmbeddings(
         model_name=sentence_embedding_model_path,
         model_kwargs={"device": "cpu"},
@@ -51,15 +47,18 @@ async def startup_event():
         cache_folder=hf_cache_dir,
     )
 
+    # Initialize Qdrant
     try:
         qdrant_client = QdrantClient(path="qdrant/")
         qdrant = QdrantVectorStore(qdrant_client, "MyCollection", embed_model, distance="Dot")
     except Exception as e:
         print(f"❌ Error initializing Qdrant: {e}")
 
+    # Load QA model
     model_path = "distilbert-base-cased-distilled-squad"
     model = AutoModelForQuestionAnswering.from_pretrained(model_path, cache_dir=hf_cache_dir)
     tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir=hf_cache_dir)
+
     qa_pipeline = pipeline(
         "question-answering",
         model=model,
@@ -68,68 +67,65 @@ async def startup_event():
     )
 
     end_time = time.perf_counter()
-    print(f"✅
+    print(f"✅ Models loaded successfully in {end_time - start_time:.2f} seconds.")
 
-app.on_event("shutdown")
+@app.on_event("shutdown")
 async def shutdown_event():
-    global model, tokenizer,
+    global model, tokenizer, qa_pipeline, embed_model, qdrant
     print("🚪 Shutting down the API and releasing model memory.")
-    del model, tokenizer,
-
+    del model, tokenizer, qa_pipeline, embed_model, qdrant
 
 @app.get("/")
 def read_root():
     return {"message": "Welcome to FastAPI"}
 
 @app.post("/search")
-def search(
+def search(item: Item):
     print("Search endpoint")
-    query =
+    query = item.query
 
     search_result = qdrant.similarity_search(
         query=query, k=10
     )
-
-    list_res = [
-
-
-
-
+
+    list_res = [
+        {"id": i, "path": res.metadata.get("path"), "content": res.page_content}
+        for i, res in enumerate(search_result)
+    ]
+
     return list_res
 
 @app.post("/ask_localai")
 async def ask_localai(item: Item):
     query = item.query
 
-    search_result = qdrant.similarity_search(query=query, k=3)
-    if not search_result:
-        return {"error": "No relevant results found for the query."}
-
-    context = " ".join([res.page_content for res in search_result])
-    if not context.strip():
-        return {"error": "No relevant context found."}
-
     try:
-
-
-
-
-
+        # First, get relevant documents
+        docs = qdrant.similarity_search(query, k=3)
+
+        # Combine the documents into a single context
+        context = " ".join([doc.page_content for doc in docs])
+
+        # Use the QA pipeline directly
+        answer = qa_pipeline(
+            question=query,
+            context=context,
+            max_length=512,
+            max_answer_length=50,
+            handle_long_sequences=True
        )
-
-        answer = qa_result["answer"]
-
+
         return {
             "question": query,
-            "answer": answer
+            "answer": answer["answer"],
+            "confidence": answer["score"],
+            "source_documents": [
+                {
+                    "content": doc.page_content[:1000],
+                    "metadata": doc.metadata
+                } for doc in docs
+            ]
         }
+
     except Exception as e:
-        return {"error":
-
-@app.get("/items/{item_id}")
-def read_item(item_id: int, q: str = None):
-    return {"item_id": item_id, "q": q}
-
-@app.post("/items/")
-def create_item(item: Item):
-    return {"item": item, "total_price": item.price + (item.tax or 0)}
+        return {"error": str(e)}