sravan committed
Commit ae692a1 · Parent: bd9af43

first working api's

Files changed (6):
  1. callbacks.py +1 -1
  2. chains.py +33 -11
  3. data_indexing.py +1 -1
  4. main.py +29 -9
  5. prompts.py +15 -4
  6. test.db +0 -0
callbacks.py CHANGED
@@ -16,7 +16,7 @@ class LogResponseCallback(BaseCallbackHandler):
         # TODO: The function on_llm_end is going to be called when the LLM stops sending
         # the response. Use the crud.add_message function to capture that response.
         type = 'AI'
-        user_data = crud.get_or_create(db, self.user_request.username)
+        user_data = crud.get_or_create(self.db, self.user_request.username)
         user_id = user_data.user_id
         timestamp = datetime.now()
         message = outputs.generations[0][0].text  # answer from the prompt message
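
The fix above simply routes the database session through self.db. For context, a minimal sketch of how the corrected on_llm_end hook could persist the AI reply once the LLM finishes; crud.get_or_create and crud.add_message are repo-local helpers, so their exact signatures here are assumptions rather than the repo's definitive API.

from datetime import datetime

import crud  # repo-local module, assumed importable
from langchain_core.callbacks import BaseCallbackHandler


class LogResponseCallback(BaseCallbackHandler):
    def __init__(self, user_request, db):
        self.user_request = user_request
        self.db = db  # keep the session on the handler so on_llm_end can reach it

    def on_llm_end(self, outputs, **kwargs):
        # outputs is a LangChain LLMResult; the generated text sits in generations[0][0].text
        user_data = crud.get_or_create(self.db, self.user_request.username)
        crud.add_message(
            self.db,
            {
                "message": outputs.generations[0][0].text,
                "type": "AI",
                "timestamp": datetime.now(),
                "user_id": user_data.user_id,
            },
            username=self.user_request.username,  # assumed keyword, mirroring the call in main.py
        )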
chains.py CHANGED
@@ -1,5 +1,5 @@
 import os
-from langchain_huggingface import HuggingFaceEndpoint
+from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
 from langchain_core.runnables import RunnablePassthrough
 
 import schemas
@@ -17,19 +17,41 @@ from transformers import AutoTokenizer
 
 data_indexer = DataIndexer()
 
-tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
-
-llm = HuggingFaceEndpoint(
-    repo_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
-
-    huggingfacehub_api_token=os.environ['HF_TOKEN'],
-    max_new_tokens=512,
-    stop_sequences=[tokenizer.eos_token],
+tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
+
+# llm = HuggingFaceEndpoint(
+#     # repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
+#     repo_id="deepseek-ai/DeepSeek-R1-0528",
+#     huggingfacehub_api_token=os.environ['HF_TOKEN'],
+#     max_new_tokens=512,
+#     stop_sequences=["<|eot_id|>"],
+#     streaming=True,
+#     # task="conversational",
+#     task="text-generation",
+#     # provider='novita',
+#     # temperature=0.7,
+# )
+# llm = HuggingFaceEndpoint(
+#     repo_id="deepseek-ai/DeepSeek-R1-0528",
+#     max_new_tokens=512,
+#     temperature=0.5,
+#     huggingfacehub_api_token=os.environ['HF_TOKEN'],
+#     provider="auto",
+# )
+
+llm_endpoint = HuggingFaceEndpoint(
+    repo_id="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
+    task="text-generation",
+    max_new_tokens=100,
     streaming=True,
-    # task="conversational",
-    task="text-generation",
+    do_sample=False,
+    temperature=0.9,
+    repetition_penalty=1.03,
+    provider="auto",
 )
 
+llm = ChatHuggingFace(llm=llm_endpoint)
+
 simple_chain = (raw_prompt | llm).with_types(input_type=schemas.UserQuestion)
 
 # TODO: create formatted_chain by piping raw_prompt_formatted and the LLM endpoint.
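
The key change in chains.py is wrapping the raw HuggingFaceEndpoint in ChatHuggingFace, which applies the model's own chat template and returns message objects instead of raw strings. A minimal sketch of that pattern under the same assumptions as the diff (a recent langchain_huggingface that supports the provider argument, and HF_TOKEN set in the environment); the model id, parameters, and prompt are only illustrative.

from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint

# Text-generation endpoint on the Hugging Face Inference API (illustrative settings).
llm_endpoint = HuggingFaceEndpoint(
    repo_id="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    task="text-generation",
    max_new_tokens=100,
    streaming=True,
)

# ChatHuggingFace applies the repo's chat template before calling the endpoint.
llm = ChatHuggingFace(llm=llm_endpoint)

# invoke() returns an AIMessage; stream() yields AIMessageChunk objects,
# which is why the rewritten generate_stream in main.py reads chunk.content.
for chunk in llm.stream("What is LangServe?"):
    print(chunk.content, end="", flush=True)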
data_indexing.py CHANGED
@@ -80,7 +80,7 @@ class DataIndexer:
         # values = None
 
         # TODO: create a list of unique identifiers for each element in the batch with the uuid package.
-        vector_ids = [uuid.uuid4() for _ in batch]
+        vector_ids = [str(uuid.uuid4()) for _ in batch]
 
         # TODO: create a list of dictionaries representing the metadata. Capture the text data
         # with the "text" key, and make sure to capture the rest of the doc.metadata.
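
This one-character change matters because vector stores generally require ids to be plain strings, and a raw uuid.UUID object is not JSON-serializable. A hedged sketch of the kind of upsert payload it enables, assuming a Pinecone-style index object and placeholder embeddings (the real batch, values, and metadata come from the surrounding DataIndexer code).

import uuid

batch = ["chunk one of text", "chunk two of text"]     # stand-in documents
vector_ids = [str(uuid.uuid4()) for _ in batch]        # string ids, as in the diff
embeddings = [[0.1, 0.2, 0.3] for _ in batch]          # placeholder vectors
metadatas = [{"text": text} for text in batch]         # metadata keyed by "text"

vectors = list(zip(vector_ids, embeddings, metadatas))
# index.upsert(vectors=vectors)  # Pinecone-style call, assuming an `index` client exists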
main.py CHANGED
@@ -1,6 +1,6 @@
 from langchain_core.runnables import Runnable
 from langchain_core.callbacks import BaseCallbackHandler
-from fastapi import FastAPI, Request, Depends
+from fastapi import FastAPI, Request, Depends, HTTPException
 from sse_starlette.sse import EventSourceResponse
 from langserve.serialization import WellKnownLCSerializer
 from typing import List
@@ -36,13 +36,27 @@ def get_db():
         db.close()
 
 
+# async def generate_stream(input_data: schemas.BaseModel, runnable: Runnable, callbacks: List[BaseCallbackHandler]=[]):
+#     for output in runnable.stream(input_data.dict(), config={"callbacks": callbacks}):
+#         data = WellKnownLCSerializer().dumps(output).decode("utf-8")
+#         yield {'data': data, "event": "data"}
+#     yield {"event": "end"}
+
 async def generate_stream(input_data: schemas.BaseModel, runnable: Runnable, callbacks: List[BaseCallbackHandler]=[]):
-    for output in runnable.stream(input_data.dict(), config={"callbacks": callbacks}):
-        data = WellKnownLCSerializer().dumps(output).decode("utf-8")
-        yield {'data': data, "event": "data"}
+    for chunk in runnable.stream(input_data.dict(), config={"callbacks": callbacks}):
+        # ChatHuggingFace returns message chunks with content attribute
+        if hasattr(chunk, 'content'):
+            content = chunk.content
+        else:
+            content = str(chunk)
+
+        if content:  # Only yield non-empty content
+            yield {'data': content, "event": "data"}
     yield {"event": "end"}
 
 
+
+
 @app.post("/simple/stream")
 async def simple_stream(request: Request):
     data = await request.json()
@@ -56,11 +70,13 @@ async def formatted_stream(request: Request):
     try:
         data = await request.json()
        user_question = schemas.UserQuestion(**data['input'])
-        return EventSourceResponse(
+        output = EventSourceResponse(
             generate_stream(
                 input_data = user_question,
                 runnable = formatted_chain )
         )
+        # print(output.generations[0][0].text)
+        return output
     except ValueError as e:
         raise HTTPException(status_code=400, detail=str(e))
     except Exception as e:
@@ -103,7 +119,11 @@ async def history_stream(request: Request, db: Session = Depends(get_db)):
 
     _ = crud.add_message(db, add_message, username = user_request.username)
     # chat history contains: [{ message, type, timestamp}]
-    return EventSourceResponse(generate_stream(history_input, history_chain))
+
+    output = EventSourceResponse(generate_stream(history_input, history_chain))
+    LogResponseCallback.on_llm_end(outputs = output)
+
+    return output
     # raise NotImplemented
 
 
@@ -119,7 +139,7 @@ async def rag_stream(request: Request, db: Session = Depends(get_db)):
     data = await request.json()
     user_request = schemas.UserRequest(**data['input'])
     messages = crud.get_user_chat_history(db, user_request.username)
-    chat_history = format_chat_history(messages)
+    chat_history = messages
 
     history_input = schemas.HistoryInput( chat_history = format_chat_history(chat_history), question=user_request.question)
     ## adding messgae to message database
@@ -149,14 +169,14 @@ async def filtered_rag_stream(request: Request, db: Session = Depends(get_db)):
     # - We create an instance of HistoryInput by using format_chat_history.
     # - We use the history input within the filtered rag chain.
     data = await request.json()
-    user_request = models.UserRequest(data)
+    user_request = models.UserRequest(**dat['input'])
 
     messages = db.Query(
         Message.message,
         Message.type,
         Message.timestamp
     ).filter(Message.user_id == user_request.username)
-    chat_history = format_chat_history(messages)
+    chat_history = messages
 
     history_input = schemas.HistoryInput(
         chat_history = format_chat_history(chat_history),
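
main.py now streams plain chunk.content strings over SSE instead of LangServe-serialized outputs. A self-contained sketch of that pattern, with a stand-in echo generator in place of the real LangChain runnables wired up in this repo:

from fastapi import FastAPI, Request
from sse_starlette.sse import EventSourceResponse

app = FastAPI()


async def generate_stream(question: str):
    # Stand-in for runnable.stream(...): yield dicts with "data"/"event" keys,
    # which sse_starlette turns into Server-Sent Events.
    for token in question.split():
        yield {"data": token, "event": "data"}
    yield {"event": "end"}  # tells the client the stream is finished


@app.post("/echo/stream")
async def echo_stream(request: Request):
    data = await request.json()
    return EventSourceResponse(generate_stream(data["input"]["question"]))

One design note: EventSourceResponse only wraps a lazily consumed generator, so anything that needs the final generated text (such as a LogResponseCallback instance) is normally passed through the callbacks list that generate_stream forwards to runnable.stream, rather than being invoked on the response object itself.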
prompts.py CHANGED
@@ -1,7 +1,7 @@
 from langchain_core.prompts import PromptTemplate
 from typing import List
 import models
-
+from transformers import AutoTokenizer
 
 def format_prompt(prompt) -> PromptTemplate:
     # TODO: format the input prompt by using the model specific instruction template
@@ -12,10 +12,20 @@ def format_prompt(prompt) -> PromptTemplate:
     {prompt}<|eot_id|>
     <|start_header_id|>assistant<|end_header_id|>
     """
+    raw_template = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "{{prompt}}"},
+    ]
+    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
+    formatted_template = tokenizer.apply_chat_template(
+        raw_template,
+        tokenize=False,
+        add_generation_prompt=True
+    )
 
-    prompt_template = PromptTemplate(
+    prompt_template = PromptTemplate.from_template(
         # input_variables=["question"], the variables will be auto detected by langchain package
-        template = template
+        formatted_template
     )
     # TODO: return a langchain PromptTemplate
     return prompt_template
@@ -26,7 +36,7 @@ def format_chat_history(messages: List[models.Message]):
     # the list of Message into a text of chat history.
     chat_history = ""
     for msg in messages:
-        chat_history += msg['message']
+        chat_history += '{}:{}'.format(msg.type, msg.message)
         chat_history += "\n---\n"
     # combined all messages from the list for sending it to the model prompt.
     return chat_history
@@ -48,6 +58,7 @@ def format_context(docs: List[str]):
 
 raw_prompt = "{question}"
 
+
 # TODO: Create the history_prompt prompt that will capture the question and the conversation history.
 # The history_prompt needs a {chat_history} placeholder and a {question} placeholder.
 history_prompt: str = """
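
format_prompt now builds the prompt string with tokenizer.apply_chat_template instead of a hand-written Llama wrapper. A hedged sketch of that approach in isolation; note that a single-brace {question} placeholder survives as a LangChain input variable, whereas the double-braced {{prompt}} form used in the diff is escaped to a literal by PromptTemplate's f-string parsing and yields no variables. Meta-Llama-3-8B-Instruct is a gated repo, so an approved HF token is assumed.

from langchain_core.prompts import PromptTemplate
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "{question}"},  # single braces: becomes a template variable
]

# Render the model-specific chat wrapping as plain text (no tokenization).
template_str = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

prompt_template = PromptTemplate.from_template(template_str)
print(prompt_template.input_variables)                # expected: ['question']
print(prompt_template.format(question="What is RAG?"))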
test.db CHANGED
Binary files a/test.db and b/test.db differ