Save History as Summary
- app.py +1 -5
- rag_agent.py +70 -12
app.py CHANGED
@@ -16,7 +16,7 @@ class ChatBot:
     def get_response(self, message):
         return self.rag_agent.get_response(message)
 
-    def chat(self, message
+    def chat(self, message):
         time.sleep(1)
         bot_response = self.get_response(message)
         self.message_history.append((message, bot_response))
@@ -99,10 +99,6 @@ def create_chat_interface(rag_agent=rag_agent):
         history[-1][1] = bot_response
         return history
 
-    txt_msg = txt.submit(user_message, [txt, chatbot_component], [txt, chatbot_component], queue=False).then(
-        bot_message, chatbot_component, chatbot_component
-    )
-
    submit_btn.click(user_message, [txt, chatbot_component], [txt, chatbot_component], queue=False).then(
        bot_message, chatbot_component, chatbot_component
    )
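
The app.py change is twofold: the old line 19 of ChatBot.chat was missing its closing "):" (a syntax error), and the duplicate txt.submit binding is removed so each message is handled once, by the submit button alone. Below is a minimal sketch of the wiring the diff converges on; it assumes Gradio's Blocks API, reuses the diff's component names, and stubs the bot reply in place of the RAG agent:

# Minimal sketch; assumes gradio 4.x. The stub reply stands in for
# rag_agent.get_response.
import time

import gradio as gr


def user_message(message, history):
    # Append the user's turn with an empty bot slot, then clear the textbox.
    return "", history + [[message, None]]


def bot_message(history):
    # Fill in the bot's reply for the newest turn.
    time.sleep(1)
    history[-1][1] = f"(stub reply to: {history[-1][0]})"
    return history


with gr.Blocks() as demo:
    chatbot_component = gr.Chatbot()
    txt = gr.Textbox(placeholder="Type a message")
    submit_btn = gr.Button("Submit")

    # A single binding, as after this commit: the click updates the history,
    # then the bot fills in its answer.
    submit_btn.click(
        user_message, [txt, chatbot_component], [txt, chatbot_component], queue=False
    ).then(bot_message, chatbot_component, chatbot_component)

demo.launch()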
rag_agent.py CHANGED
@@ -17,16 +17,18 @@ class RAGAgent:
         retriever=retriever,
         reranker=reranker,
         anthropic_api_key: str = os.environ["anthropic_api_key"],
-
+        model_name: str = "claude-3-5-sonnet-20241022",
         max_tokens: int = 1024,
         temperature: float = 0.0,
     ):
         self.retriever = retriever
         self.reranker = reranker
         self.client = Anthropic(api_key=anthropic_api_key)
-        self.
+        self.model_name = model_name
         self.max_tokens = max_tokens
         self.temperature = temperature
+        self.conversation_summary = ""
+        self.messages = []
 
     def get_context(self, query: str) -> List[str]:
         # Get initial candidates from retriever
@@ -37,29 +39,85 @@ class RAGAgent:
 
         return context
 
-    def generate_prompt(self, context: List[str]) -> str:
+    def generate_prompt(self, context: List[str], conversation_summary: str = "") -> str:
         context = "\n".join(context)
+        summary_context = f"\nSummary of the conversation so far:\n{conversation_summary}" if conversation_summary else ""
+
         prompt = f"""
-
+        You are a dentist who speaks Hebrew. Your name is 'the electronic dentist for first aid'.{summary_context}
+        Answer the patient's question based on this context: {context}.
+        Add as many details as possible, and make sure the wording is correct and polished.
+        Stop when you feel you have exhausted what you have to say. Do not make things up.
+        And do not answer in languages other than Hebrew.
         """
         return prompt
 
+    def update_summary(self, question: str, answer: str) -> str:
+        """Update the conversation summary with the new interaction"""
+        summary_prompt = {
+            "model": self.model_name,
+            "max_tokens": 500,
+            "temperature": 0.0,
+            "messages": [
+                {
+                    "role": "user",
+                    "content": f"""Summarize the conversation in Hebrew. Here is the summary of the conversation so far:
+{self.conversation_summary if self.conversation_summary else "No previous conversation."}
+
+New interaction:
+Patient's question: {question}
+Doctor's answer: {answer}
+
+Please provide an updated summary that keeps the medical information from the previous summary and adds what is new in this interaction. The summary must be concise, at most 100 words.
+Leave out information from the previous summaries that is no longer relevant"""
+                }
+            ]
+        }
+
+        try:
+            response = self.client.messages.create(**summary_prompt)
+            self.conversation_summary = response.content[0].text
+            return self.conversation_summary
+        except Exception as e:
+            print(f"Error updating summary: {e}")
+            return self.get_basic_summary()
+
+    def get_basic_summary(self) -> str:
+        """Fallback method for basic summary"""
+        summary = []
+        for i in range(0, len(self.messages), 2):
+            if i + 1 < len(self.messages):
+                summary.append(f"Patient's question: {self.messages[i]['content']}")
+                summary.append(f"Dentist's answer: {self.messages[i + 1]['content']}\n")
+        return "\n".join(summary)
+
     def get_response(self, question: str) -> str:
         # Get relevant context
-        context = self.get_context(question)
+        context = self.get_context(question + self.conversation_summary)
 
-        # Generate prompt with context
-        prompt = self.generate_prompt(context)
+        # Generate prompt with context and current conversation summary
+        prompt = self.generate_prompt(context, self.conversation_summary)
 
         # Get response from Claude
         response = self.client.messages.create(
-            model=self.
+            model=self.model_name,
             max_tokens=self.max_tokens,
             temperature=self.temperature,
             messages=[
-
-
-
+                {"role": "assistant", "content": prompt},
+                {"role": "user", "content": f"{question}"}
+            ]
         )
 
-
+        answer = response.content[0].text
+
+        # Store messages for history
+        self.messages.extend([
+            {"role": "user", "content": question},
+            {"role": "assistant", "content": answer}
+        ])
+
+        # Update conversation summary
+        self.update_summary(question, answer)
+
+        return answer
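
The rag_agent.py side introduces rolling-summary memory: each exchange is appended to self.messages, folded into self.conversation_summary by update_summary (with get_basic_summary as a fallback), and the summary then feeds both retrieval (get_context(question + self.conversation_summary)) and the generated prompt. One caveat: the Anthropic Messages API requires the first entry in messages to use the "user" role, so the leading assistant turn in get_response above is rejected by the API; instructions are normally passed through the top-level system parameter instead. Below is a condensed sketch of the same summary pattern with that adjustment; it assumes the anthropic package and an ANTHROPIC_API_KEY environment variable, and the retrieval stub and English prompt text are illustrative rather than taken from the commit:

# Condensed sketch of the rolling-summary pattern, with the instructions moved
# to the `system` parameter (messages[0] must be a "user" turn).
from anthropic import Anthropic

MODEL_NAME = "claude-3-5-sonnet-20241022"


class SummaryMemoryAgent:
    def __init__(self) -> None:
        self.client = Anthropic()  # reads ANTHROPIC_API_KEY from the environment
        self.conversation_summary = ""

    def get_context(self, query: str) -> str:
        # Stand-in for the retriever + reranker pipeline.
        return f"(retrieved passages for: {query})"

    def get_response(self, question: str) -> str:
        # As in the diff, the running summary enriches the retrieval query.
        context = self.get_context(f"{question} {self.conversation_summary}")
        system = (
            f"Answer the patient's question from this context: {context}\n"
            f"Summary of the conversation so far: {self.conversation_summary or 'none'}"
        )
        response = self.client.messages.create(
            model=MODEL_NAME,
            max_tokens=1024,
            temperature=0.0,
            system=system,
            messages=[{"role": "user", "content": question}],
        )
        answer = response.content[0].text
        self.update_summary(question, answer)
        return answer

    def update_summary(self, question: str, answer: str) -> None:
        # Fold the newest exchange into a compact running summary, mirroring
        # update_summary in the diff (about 100 words, medical details kept).
        prompt = (
            f"Current summary:\n{self.conversation_summary or 'No previous conversation.'}\n\n"
            f"New exchange:\nPatient: {question}\nDoctor: {answer}\n\n"
            "Provide an updated summary of at most 100 words that keeps the "
            "medical details and drops anything no longer relevant."
        )
        response = self.client.messages.create(
            model=MODEL_NAME,
            max_tokens=500,
            temperature=0.0,
            messages=[{"role": "user", "content": prompt}],
        )
        self.conversation_summary = response.content[0].text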