fpadron committed on
Commit
d2b4768
·
1 Parent(s): 11f746b
Files changed (1)
  1. api.py +47 -51
api.py CHANGED
@@ -10,40 +10,36 @@ os.makedirs(os.environ["TRANSFORMERS_CACHE"], exist_ok=True)
 from fastapi import FastAPI
 from fastapi.staticfiles import StaticFiles
 from langchain_huggingface import HuggingFaceEmbeddings
-from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig, AutoModelForQuestionAnswering
+from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
+import torch
 from langchain_community.llms import HuggingFacePipeline
 from qdrant_client import QdrantClient
 from langchain_qdrant import QdrantVectorStore
 from pydantic import BaseModel
 from langchain.chains import RetrievalQA
-from langchain.schema import Document
 import time
-import torch
 
+# Global variables
 model = None
 tokenizer = None
-dolly_pipeline_hf = None
+qa_pipeline = None
 embed_model = None
 qdrant = None
-model_name_hf = None
-text_generation_pipeline = None
-qa_pipeline = None
 
 class Item(BaseModel):
     query: str
 
 app = FastAPI()
-# app.mount("/TestFolder", StaticFiles(directory="./TestFolder"), name="TestFolder")
 
 @app.on_event("startup")
 async def startup_event():
-    global model, tokenizer, dolly_pipeline_hf, embed_model, qdrant, model_name_hf, text_generation_pipeline, qa_pipeline
-
-    print("🚀 Loading model....")
+    global model, tokenizer, qa_pipeline, embed_model, qdrant
 
-    sentence_embedding_model_path = "sentence-transformers/paraphrase-MiniLM-L6-v2"
+    print("🚀 Loading models....")
     start_time = time.perf_counter()
 
+    # Load embedding model
+    sentence_embedding_model_path = "sentence-transformers/paraphrase-MiniLM-L6-v2"
     embed_model = HuggingFaceEmbeddings(
         model_name=sentence_embedding_model_path,
         model_kwargs={"device": "cpu"},
@@ -51,15 +47,18 @@ async def startup_event():
         cache_folder=hf_cache_dir,
     )
 
+    # Initialize Qdrant
     try:
         qdrant_client = QdrantClient(path="qdrant/")
         qdrant = QdrantVectorStore(qdrant_client, "MyCollection", embed_model, distance="Dot")
     except Exception as e:
         print(f"❌ Error initializing Qdrant: {e}")
 
+    # Load QA model
     model_path = "distilbert-base-cased-distilled-squad"
     model = AutoModelForQuestionAnswering.from_pretrained(model_path, cache_dir=hf_cache_dir)
     tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir=hf_cache_dir)
+
     qa_pipeline = pipeline(
         "question-answering",
         model=model,
@@ -68,68 +67,65 @@ async def startup_event():
     )
 
     end_time = time.perf_counter()
-    print(f"✅ Dolly model loaded successfully in {end_time - start_time:.2f} seconds.")
+    print(f"✅ Models loaded successfully in {end_time - start_time:.2f} seconds.")
 
-app.on_event("shutdown")
+@app.on_event("shutdown")
 async def shutdown_event():
-    global model, tokenizer, dolly_pipeline_hf
+    global model, tokenizer, qa_pipeline, embed_model, qdrant
     print("🚪 Shutting down the API and releasing model memory.")
-    del model, tokenizer, dolly_pipeline_hf, embed_model, qdrant, model_name_hf, text_generation_pipeline, qa_pipeline
-
+    del model, tokenizer, qa_pipeline, embed_model, qdrant
 
 @app.get("/")
 def read_root():
     return {"message": "Welcome to FastAPI"}
 
 @app.post("/search")
-def search(Item:Item):
+def search(item: Item):
     print("Search endpoint")
-    query = Item.query
+    query = item.query
 
     search_result = qdrant.similarity_search(
         query=query, k=10
     )
-    i = 0
-    list_res = []
-    for res in search_result:
-        list_res.append({"id":i,"path":res.metadata.get("path"),"content":res.page_content})
-
+
+    list_res = [
+        {"id": i, "path": res.metadata.get("path"), "content": res.page_content}
+        for i, res in enumerate(search_result)
+    ]
+
     return list_res
 
 @app.post("/ask_localai")
 async def ask_localai(item: Item):
     query = item.query
 
-    search_result = qdrant.similarity_search(query=query, k=3)
-    if not search_result:
-        return {"error": "No relevant results found for the query."}
-
-    context = " ".join([res.page_content for res in search_result])
-    if not context.strip():
-        return {"error": "No relevant context found."}
-
     try:
-        prompt = (
-            f"Context: {context}\n\n"
-            f"Question: {query}\n"
-            f"Answer concisely and only based on the context provided. Do not repeat the context or the question.\n"
-            f"Answer:"
+        # First, get relevant documents
+        docs = qdrant.similarity_search(query, k=3)
+
+        # Combine the documents into a single context
+        context = " ".join([doc.page_content for doc in docs])
+
+        # Use the QA pipeline directly
+        answer = qa_pipeline(
+            question=query,
+            context=context,
+            max_length=512,
+            max_answer_length=50,
+            handle_long_sequences=True
        )
-        qa_result = qa_pipeline(question=query, context=context)
-        answer = qa_result["answer"]
-
+
         return {
             "question": query,
-            "answer": answer
+            "answer": answer["answer"],
+            "confidence": answer["score"],
+            "source_documents": [
+                {
+                    "content": doc.page_content[:1000],
+                    "metadata": doc.metadata
+                } for doc in docs
+            ]
         }
+
     except Exception as e:
-        return {"error": "Failed to generate an answer."}
-
-@app.get("/items/{item_id}")
-def read_item(item_id: int, q: str = None):
-    return {"item_id": item_id, "q": q}
-
-@app.post("/items/")
-def create_item(item: Item):
-    return {"item": item, "total_price": item.price + (item.tax or 0)}
+        return {"error": str(e)}