Commit ae692a1 by sravan
Parent(s): bd9af43
first working api's

Files changed:
- callbacks.py +1 -1
- chains.py +33 -11
- data_indexing.py +1 -1
- main.py +29 -9
- prompts.py +15 -4
- test.db +0 -0
callbacks.py
CHANGED
@@ -16,7 +16,7 @@ class LogResponseCallback(BaseCallbackHandler):
         # TODO: The function on_llm_end is going to be called when the LLM stops sending
         # the response. Use the crud.add_message function to capture that response.
         type = 'AI'
-        user_data = crud.get_or_create(db, self.user_request.username)
+        user_data = crud.get_or_create(self.db, self.user_request.username)
         user_id = user_data.user_id
         timestamp = datetime.now()
         message = outputs.generations[0][0].text  # answer from the prompt message
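For context, a minimal sketch of how the lines above could sit inside the full on_llm_end handler. The constructor, the method signature, and the schemas.MessageBase name and fields are assumptions made for illustration; only crud.get_or_create, crud.add_message, the type = 'AI' marker, and the outputs.generations access come from this commit.

from datetime import datetime
from langchain_core.callbacks import BaseCallbackHandler

import crud
import schemas


class LogResponseCallback(BaseCallbackHandler):
    def __init__(self, user_request, db):
        self.user_request = user_request
        self.db = db

    def on_llm_end(self, outputs, **kwargs):
        # Persist the model's final answer once the LLM stops streaming.
        user_data = crud.get_or_create(self.db, self.user_request.username)
        message = schemas.MessageBase(            # hypothetical schema name
            message=outputs.generations[0][0].text,
            type='AI',
            timestamp=datetime.now(),
            user_id=user_data.user_id,
        )
        crud.add_message(self.db, message, username=self.user_request.username)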
chains.py
CHANGED
@@ -1,5 +1,5 @@
 import os
-from langchain_huggingface import HuggingFaceEndpoint
+from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
 from langchain_core.runnables import RunnablePassthrough
 
 import schemas
@@ -17,19 +17,41 @@ from transformers import AutoTokenizer
 
 data_indexer = DataIndexer()
 
-tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3…
-
-llm = HuggingFaceEndpoint(
-    …
+tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
+
+# llm = HuggingFaceEndpoint(
+#     # repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
+#     repo_id="deepseek-ai/DeepSeek-R1-0528",
+#     huggingfacehub_api_token=os.environ['HF_TOKEN'],
+#     max_new_tokens=512,
+#     stop_sequences=["<|eot_id|>"],
+#     streaming=True,
+#     # task="conversational",
+#     task="text-generation",
+#     # provider='novita',
+#     # temperature=0.7,
+# )
+# llm = HuggingFaceEndpoint(
+#     repo_id="deepseek-ai/DeepSeek-R1-0528",
+#     max_new_tokens=512,
+#     temperature=0.5,
+#     huggingfacehub_api_token=os.environ['HF_TOKEN'],
+#     provider="auto",
+# )
+
+llm_endpoint = HuggingFaceEndpoint(
+    repo_id="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
+    task="text-generation",
+    max_new_tokens=100,
     streaming=True,
-    …
+    do_sample=False,
+    temperature=0.9,
+    repetition_penalty=1.03,
+    provider="auto",
 )
 
+llm = ChatHuggingFace(llm=llm_endpoint)
+
 simple_chain = (raw_prompt | llm).with_types(input_type=schemas.UserQuestion)
 
 # TODO: create formatted_chain by piping raw_prompt_formatted and the LLM endpoint.
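A minimal sketch of the formatted_chain that the TODO above asks for, assuming raw_prompt_formatted is the PromptTemplate built by prompts.format_prompt; only the piping and .with_types pattern is taken from the simple_chain line in this diff, and the sample question is illustrative.

formatted_chain = (raw_prompt_formatted | llm).with_types(input_type=schemas.UserQuestion)

# Hypothetical local check: stream message chunks from the ChatHuggingFace wrapper.
if __name__ == "__main__":
    for chunk in formatted_chain.stream({"question": "What is this Space about?"}):
        print(chunk.content, end="", flush=True)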
data_indexing.py
CHANGED
@@ -80,7 +80,7 @@ class DataIndexer:
             # values = None
 
             # TODO: create a list of unique identifiers for each element in the batch with the uuid package.
-            vector_ids = [uuid.uuid4() for _ in batch]
+            vector_ids = [str(uuid.uuid4()) for _ in batch]
 
             # TODO: create a list of dictionaries representing the metadata. Capture the text data
             # with the "text" key, and make sure to capture the rest of the doc.metadata.
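A sketch of the metadata step described by the TODO above, assuming each element of batch is a LangChain Document with .page_content and .metadata; the variable names besides vector_ids are illustrative.

metadata = [
    {"text": doc.page_content, **doc.metadata}  # raw text plus the document's own metadata
    for doc in batch
]
# The string ids pair with the embeddings and metadata at upsert time,
# e.g. list(zip(vector_ids, values, metadata)) for a Pinecone-style index.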
main.py
CHANGED
@@ -1,6 +1,6 @@
 from langchain_core.runnables import Runnable
 from langchain_core.callbacks import BaseCallbackHandler
-from fastapi import FastAPI, Request, Depends
+from fastapi import FastAPI, Request, Depends, HTTPException
 from sse_starlette.sse import EventSourceResponse
 from langserve.serialization import WellKnownLCSerializer
 from typing import List
@@ -36,13 +36,27 @@ def get_db():
         db.close()
 
 
+# async def generate_stream(input_data: schemas.BaseModel, runnable: Runnable, callbacks: List[BaseCallbackHandler]=[]):
+#     for output in runnable.stream(input_data.dict(), config={"callbacks": callbacks}):
+#         data = WellKnownLCSerializer().dumps(output).decode("utf-8")
+#         yield {'data': data, "event": "data"}
+#     yield {"event": "end"}
+
 async def generate_stream(input_data: schemas.BaseModel, runnable: Runnable, callbacks: List[BaseCallbackHandler]=[]):
-    for …
-    …
+    for chunk in runnable.stream(input_data.dict(), config={"callbacks": callbacks}):
+        # ChatHuggingFace returns message chunks with content attribute
+        if hasattr(chunk, 'content'):
+            content = chunk.content
+        else:
+            content = str(chunk)
+
+        if content:  # Only yield non-empty content
+            yield {'data': content, "event": "data"}
     yield {"event": "end"}
 
 
+
+
 @app.post("/simple/stream")
 async def simple_stream(request: Request):
     data = await request.json()
@@ -56,11 +70,13 @@ async def formatted_stream(request: Request):
     try:
         data = await request.json()
         user_question = schemas.UserQuestion(**data['input'])
-        …
+        output = EventSourceResponse(
             generate_stream(
                 input_data = user_question,
                 runnable = formatted_chain )
         )
+        # print(output.generations[0][0].text)
+        return output
     except ValueError as e:
         raise HTTPException(status_code=400, detail=str(e))
     except Exception as e:
@@ -103,7 +119,11 @@ async def history_stream(request: Request, db: Session = Depends(get_db)):
 
     _ = crud.add_message(db, add_message, username = user_request.username)
     # chat history contains: [{ message, type, timestamp}]
-    …
+
+    output = EventSourceResponse(generate_stream(history_input, history_chain))
+    LogResponseCallback.on_llm_end(outputs = output)
+
+    return output
     # raise NotImplemented
 
 
@@ -119,7 +139,7 @@ async def rag_stream(request: Request, db: Session = Depends(get_db)):
     data = await request.json()
     user_request = schemas.UserRequest(**data['input'])
     messages = crud.get_user_chat_history(db, user_request.username)
-    chat_history = …
+    chat_history = messages
 
     history_input = schemas.HistoryInput( chat_history = format_chat_history(chat_history), question=user_request.question)
     ## adding messgae to message database
@@ -149,14 +169,14 @@ async def filtered_rag_stream(request: Request, db: Session = Depends(get_db)):
     # - We create an instance of HistoryInput by using format_chat_history.
     # - We use the history input within the filtered rag chain.
     data = await request.json()
-    user_request = models.UserRequest(…
+    user_request = models.UserRequest(**dat['input'])
 
     messages = db.Query(
         Message.message,
         Message.type,
         Message.timestamp
     ).filter(Message.user_id == user_request.username)
-    chat_history = …
+    chat_history = messages
 
     history_input = schemas.HistoryInput(
         chat_history = format_chat_history(chat_history),
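For quick testing, a hypothetical client for the /simple/stream endpoint above. The host and port are assumptions, the {"input": {...}} request envelope is taken from the formatted_stream handler, and httpx is used only as an example SSE consumer.

import httpx

payload = {"input": {"question": "What does this Space do?"}}
with httpx.stream("POST", "http://localhost:8000/simple/stream", json=payload, timeout=None) as response:
    for line in response.iter_lines():
        # sse_starlette emits "event: ..." / "data: ..." pairs; print only the payloads.
        if line.startswith("data:"):
            print(line[len("data:"):].strip())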
prompts.py
CHANGED
@@ -1,7 +1,7 @@
 from langchain_core.prompts import PromptTemplate
 from typing import List
 import models
-
+from transformers import AutoTokenizer
 
 def format_prompt(prompt) -> PromptTemplate:
     # TODO: format the input prompt by using the model specific instruction template
@@ -12,10 +12,20 @@ def format_prompt(prompt) -> PromptTemplate:
     {prompt}<|eot_id|>
     <|start_header_id|>assistant<|end_header_id|>
     """
+    raw_template = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "{{prompt}}"},
+    ]
+    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
+    formatted_template = tokenizer.apply_chat_template(
+        raw_template,
+        tokenize=False,
+        add_generation_prompt=True
+    )
 
-    prompt_template = PromptTemplate(
+    prompt_template = PromptTemplate.from_template(
         # input_variables=["question"], the variables will be auto detected by langchain package
-        …
+        formatted_template
     )
     # TODO: return a langchain PromptTemplate
     return prompt_template
@@ -26,7 +36,7 @@ def format_chat_history(messages: List[models.Message]):
     # the list of Message into a text of chat history.
     chat_history = ""
     for msg in messages:
-        chat_history += msg
+        chat_history += '{}:{}'.format(msg.type, msg.message)
         chat_history += "\n---\n"
     # combined all messages from the list for sending it to the model prompt.
     return chat_history
@@ -48,6 +58,7 @@ def format_context(docs: List[str]):
 
 raw_prompt = "{question}"
 
+
 # TODO: Create the history_prompt prompt that will capture the question and the conversation history.
 # The history_prompt needs a {chat_history} placeholder and a {question} placeholder.
 history_prompt: str = """
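To see what format_prompt now builds, the snippet below prints the chat-template expansion on its own. The rendered markers are approximate and depend on the tokenizer version; note also that the doubled braces in "{{prompt}}" pass through the template verbatim and are later read by PromptTemplate as an escaped literal "{prompt}".

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
rendered = tokenizer.apply_chat_template(
    [{"role": "system", "content": "You are a helpful assistant."},
     {"role": "user", "content": "{{prompt}}"}],
    tokenize=False,
    add_generation_prompt=True,
)
print(rendered)
# Roughly:
# <|begin_of_text|><|start_header_id|>system<|end_header_id|>
#
# You are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>
#
# {{prompt}}<|eot_id|><|start_header_id|>assistant<|end_header_id|>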
test.db
CHANGED
Binary files a/test.db and b/test.db differ