alifyad committed on
Commit a01c8a5 · verified
Parent(s): 6633a06

Upload app.py with huggingface_hub

Files changed (1)
  1. app.py +1220 -0
app.py ADDED
@@ -0,0 +1,1220 @@
# Import necessary libraries
import os                       # Interacting with the operating system (reading/writing files)
import json                     # Parsing and handling JSON data
import numpy as np              # Numpy for numerical operations
import chromadb                 # High-performance vector database for storing/querying dense vectors
from dotenv import load_dotenv  # Loading environment variables from a .env file
from datetime import datetime

# LangChain core imports
from langchain_core.documents import Document                 # Document data structures
from langchain_core.runnables import RunnablePassthrough      # Running pipelines
from langchain_core.output_parsers import StrOutputParser     # String output parser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder  # Templates for chat prompts
from langchain_core.tools import tool                         # Tool decorator for agents

# LangChain imports
from langchain.agents import create_tool_calling_agent, AgentExecutor
from langchain.chains.query_constructor.base import AttributeInfo            # Base classes for query construction
from langchain.retrievers.self_query.base import SelfQueryRetriever          # Self-querying retrievers
from langchain.retrievers.document_compressors import LLMChainExtractor, CrossEncoderReranker  # Document compressors
from langchain.retrievers import ContextualCompressionRetriever              # Contextual compression retrievers
from langchain.memory import ConversationSummaryBufferMemory
from langchain.text_splitter import (
    CharacterTextSplitter,           # Splitting text by characters
    RecursiveCharacterTextSplitter   # Recursive splitting of text by characters
)

# LangChain community & experimental imports
from langchain_community.vectorstores import Chroma                           # Chroma vector store
from langchain_community.document_loaders import PyPDFDirectoryLoader, PyPDFLoader  # Document loaders for PDFs
from langchain_community.cross_encoders import HuggingFaceCrossEncoder       # Cross-encoders from HuggingFace
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_community.agent_toolkits import create_sql_agent
from langchain_experimental.text_splitter import SemanticChunker             # Experimental text splitting methods

# LangChain OpenAI imports
from langchain_openai import ChatOpenAI, OpenAIEmbeddings, AzureOpenAIEmbeddings, AzureChatOpenAI

# LlamaParse & LlamaIndex imports
from llama_parse import LlamaParse                              # Document parsing library
from llama_index.core import Settings, SimpleDirectoryReader    # Core functionalities of LlamaIndex

# LangGraph import
from langgraph.graph import StateGraph, END, START  # State graph for managing agent states

# Pydantic import
from pydantic import BaseModel  # Pydantic for data validation

# Typing imports
from typing import Dict, List, Tuple, Any, TypedDict  # Python typing for function annotations

# Other utilities
from groq import Groq            # Groq client (used here for Llama Guard)
from mem0 import MemoryClient    # Long-term memory store for user interactions
import streamlit as st           # Web UI framework

#====================================SETUP=====================================#
# Fetch secrets from Hugging Face Spaces
api_key = os.environ["API_KEY"]
endpoint = os.environ["OPENAI_API_BASE"]
llama_api_key = os.environ['GROQ_API_KEY']
mem0_api_key = os.environ['mem0']

# Initialize the OpenAI embedding function for Chroma, using the provided
# endpoint and API key
embedding_function = chromadb.utils.embedding_functions.OpenAIEmbeddingFunction(
    api_base=endpoint,                    # API base endpoint
    api_key=api_key,                      # API key for authentication
    model_name='text-embedding-ada-002'   # Embedding model (fixed value)
)

# Initialize the OpenAI embeddings used by the LangChain vector store
embedding_model = OpenAIEmbeddings(
    openai_api_base=endpoint,
    openai_api_key=api_key,
    model='text-embedding-ada-002'
)

# Initialize the Chat OpenAI model with the provided endpoint and API key;
# temperature=0 keeps responses deterministic
llm = ChatOpenAI(
    openai_api_base=endpoint,
    openai_api_key=api_key,
    model="gpt-4o-mini",
    temperature=0,
    streaming=False
)

# Set the LLM and embedding model in the LlamaIndex settings
Settings.llm = llm
Settings.embed_model = embedding_model

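# The four secrets above are the only configuration the app needs. When running
# outside Hugging Face Spaces, one way to supply them is a local .env file read
# with the already-imported load_dotenv() (a sketch; call it before the
# os.environ lookups above, and keep the variable names identical):
#
#   API_KEY=<your OpenAI-compatible API key>
#   OPENAI_API_BASE=<your API base URL>
#   GROQ_API_KEY=<your Groq API key>
#   mem0=<your mem0 API key>
#
#   load_dotenv()  # then os.environ["API_KEY"] etc. resolve as above
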
#================================Creating Langgraph agent======================#

class AgentState(TypedDict):
    query: str                      # The current user query
    expanded_query: str             # The expanded version of the user query
    context: List[Dict[str, Any]]   # Retrieved documents (content and metadata)
    response: str                   # The generated response to the user query
    precision_score: float          # The precision score of the response
    groundedness_score: float       # The groundedness score of the response
    groundedness_loop_count: int    # Counter for groundedness refinement loops
    precision_loop_count: int       # Counter for precision refinement loops
    feedback: str                   # Suggestions for improving the response
    query_feedback: str             # Suggestions for improving the expanded query
    groundedness_check: bool        # Whether the groundedness check passed
    loop_max_iter: int              # Maximum number of refinement iterations

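# How this state drives the graph below: score_groundedness and check_precision
# each bump their loop counter and store a 0/1 score; the should_continue_*
# routers then either pass the state forward (score >= 1), loop back through a
# refine_* node, or bail out to max_iterations_reached once a counter exceeds
# loop_max_iter.
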
def expand_query(state):
    """
    Expands the user query to improve retrieval of climate-related information.

    Args:
        state (Dict): The current state of the workflow, containing the user query.

    Returns:
        Dict: The updated state with the expanded query.
    """
    print("---------Expanding Query---------")
    system_message = '''You are a climate change domain expert assisting in answering questions related to climate change mitigation and climate change solutions information.
    Convert the user query into something that an environment professional would understand. Use domain related words.
    Perform query expansion on the question received. If there are multiple common ways of phrasing a user question \
    or common synonyms for key words in the question, make sure to return multiple versions \
    of the query with the different phrasings.

    If the query has multiple parts, split them into separate simpler queries. This is the only case where you can generate more than 3 queries.

    If there are acronyms or words you are not familiar with, do not try to rephrase them.

    Return only 3 versions of the question as a list.
    Generate only a list of questions. Do not mention anything before or after the list.

    Question:
    {query}'''

    expand_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Expand this query: {query} using the feedback: {query_feedback}")
    ])

    chain = expand_prompt | llm | StrOutputParser()
    expanded_query = chain.invoke({"query": state['query'], "query_feedback": state["query_feedback"]})
    print("expanded_query", expanded_query)
    state["expanded_query"] = expanded_query
    return state

# Initialize the Chroma vector store for retrieving documents
vector_store = Chroma(
    collection_name="climateBot",
    persist_directory="./climateBot_db",
    embedding_function=embedding_model
)

# Create a retriever from the vector store
retriever = vector_store.as_retriever(
    search_type='similarity',
    search_kwargs={'k': 3}   # Return the top 3 most similar chunks
)

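# A minimal sketch of what the retriever returns (assuming the persisted
# "climateBot" collection has been populated): a list of LangChain Document
# objects, each carrying the chunk text in .page_content and its source info
# in .metadata.
#
#   docs = retriever.invoke("What is carbon offsetting?")
#   for doc in docs:
#       print(doc.metadata.get("source"), doc.page_content[:80])
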
def retrieve_context(state):
    """
    Retrieves context from the vector store using the expanded or original query.

    Args:
        state (Dict): The current state of the workflow, containing the query and expanded query.

    Returns:
        Dict: The updated state with the retrieved context.
    """
    print("---------retrieve_context---------")
    query = state['expanded_query']   # Use the expanded query for retrieval
    # print("Query used for retrieval:", query)  # Debugging: print the query

    # Retrieve documents from the vector store
    docs = retriever.invoke(query)
    print("Retrieved documents:", docs)  # Debugging: print the raw docs object

    # Extract both page_content and metadata from each document
    context = [
        {
            "content": doc.page_content,   # The actual content of the document
            "metadata": doc.metadata       # The metadata (e.g., source, page number)
        }
        for doc in docs
    ]
    state['context'] = context   # Store the retrieved context in the state
    print("Extracted context with metadata:", context)

    return state

def craft_response(state: Dict) -> Dict:
    """
    Generates a response using the retrieved context, focusing on climate change solutions.

    Args:
        state (Dict): The current state of the workflow, containing the query and retrieved context.

    Returns:
        Dict: The updated state with the generated response.
    """
    print("---------craft_response---------")
    system_message = '''You are a smart climate solutions specialist. Use the context and feedback to respond to the query.
    The answer you provide must come from the context and feedback provided.
    If the information provided is not enough to answer the query, respond with "I don't know the answer. Not in my records."'''

    response_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nContext: {context}\n\nfeedback: {feedback}")
    ])

    # Parse to a plain string so state['response'] matches its declared type
    chain = response_prompt | llm | StrOutputParser()
    response = chain.invoke({
        "query": state['query'],
        "context": "\n".join([doc["content"] for doc in state['context']]),
        "feedback": state['feedback']   # Include refinement feedback in the prompt
    })
    state['response'] = response
    print("intermediate response: ", response)

    return state

def score_groundedness(state: Dict) -> Dict:
    """
    Checks whether the response is grounded in the retrieved context.

    Args:
        state (Dict): The current state of the workflow, containing the response and context.

    Returns:
        Dict: The updated state with the groundedness score.
    """
    print("---------check_groundedness---------")
    system_message = '''Your task is to judge the groundedness of the response based on the context.
    Return a verdict of 1 if the response is completely grounded in the context and 0 if the response is completely hallucinated.
    Return only the number.'''

    groundedness_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Context: {context}\nResponse: {response}\n\nGroundedness score:")
    ])

    chain = groundedness_prompt | llm | StrOutputParser()
    groundedness_score = float(chain.invoke({
        "context": "\n".join([doc["content"] for doc in state['context']]),
        "response": state['response']
    }))   # float() assumes the model returns a bare number, as instructed above
    print("groundedness_score: ", groundedness_score)
    state['groundedness_loop_count'] += 1
    print("#########Groundedness Incremented###########")
    state['groundedness_score'] = groundedness_score

    return state

def check_precision(state: Dict) -> Dict:
    """
    Checks whether the response precisely addresses the user's query.

    Args:
        state (Dict): The current state of the workflow, containing the query and response.

    Returns:
        Dict: The updated state with the precision score.
    """
    print("---------check_precision---------")
    system_message = '''Given the query, response and context, verify whether the context was useful in arriving at the given answer.
    Give a precision score of "1" if it was useful and "0" if it was not. Return only the number.'''

    precision_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nResponse: {response}\n\nPrecision score:")
    ])

    chain = precision_prompt | llm | StrOutputParser()
    precision_score = float(chain.invoke({
        "query": state['query'],
        "response": state['response']
    }))
    state['precision_score'] = precision_score
    print("precision_score:", precision_score)
    state['precision_loop_count'] += 1
    print("#########Precision Incremented###########")
    return state

def refine_response(state: Dict) -> Dict:
    """
    Suggests improvements for the generated response.

    Args:
        state (Dict): The current state of the workflow, containing the query and response.

    Returns:
        Dict: The updated state with response refinement suggestions.
    """
    print("---------refine_response---------")

    system_message = '''Your task is to refine the AI generated response by improving its accuracy and completeness based on the context.'''

    refine_response_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nResponse: {response}\n\n"
                 "What improvements can be made to enhance accuracy and completeness?")
    ])

    chain = refine_response_prompt | llm | StrOutputParser()

    # Store response suggestions in a structured format
    feedback = f"Previous Response: {state['response']}\nSuggestions: {chain.invoke({'query': state['query'], 'response': state['response']})}"
    print("feedback: ", feedback)
    print(f"State: {state}")
    state['feedback'] = feedback
    return state

def refine_query(state: Dict) -> Dict:
    """
    Suggests improvements for the expanded query.

    Args:
        state (Dict): The current state of the workflow, containing the query and expanded query.

    Returns:
        Dict: The updated state with query refinement suggestions.
    """
    print("---------refine_query---------")
    system_message = '''Your task is to refine the expanded query to improve the precision of the AI generated response.'''

    refine_query_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Original Query: {query}\nExpanded Query: {expanded_query}\n\n"
                 "What improvements can be made for a better search?")
    ])

    chain = refine_query_prompt | llm | StrOutputParser()

    # Store refinement suggestions without modifying the original expanded query
    query_feedback = f"Previous Expanded Query: {state['expanded_query']}\nSuggestions: {chain.invoke({'query': state['query'], 'expanded_query': state['expanded_query']})}"
    print("query_feedback: ", query_feedback)
    print(f"Groundedness loop count: {state['groundedness_loop_count']}")
    state['query_feedback'] = query_feedback
    return state

def should_continue_groundedness(state):
    """Decides if groundedness is sufficient or needs improvement."""
    print("---------should_continue_groundedness---------")
    print("groundedness loop count: ", state['groundedness_loop_count'])
    if state['groundedness_score'] >= 1:   # Threshold for groundedness
        print("Moving to precision")
        return "check_precision"
    elif state["groundedness_loop_count"] > state['loop_max_iter']:   # Maximum allowed loops
        return "max_iterations_reached"
    else:
        print("---------Groundedness Score Threshold Not met. Refining Response-----------")
        return "refine_response"

def should_continue_precision(state: Dict) -> str:
    """Decides if precision is sufficient or needs improvement."""
    print("---------should_continue_precision---------")
    print("precision loop count: ", state['precision_loop_count'])
    if state['precision_score'] >= 1:   # Threshold for precision
        return "pass"                   # Complete the workflow
    elif state['precision_loop_count'] > state['loop_max_iter']:   # Maximum allowed loops
        return "max_iterations_reached"
    else:
        print("---------Precision Score Threshold Not met. Refining Query-----------")  # Debugging
        return "refine_query"           # Refine the query

def max_iterations_reached(state: Dict) -> Dict:
    """Handles the case when the maximum number of iterations is reached."""
    print("---------max_iterations_reached---------")
    response = "I'm unable to refine the response further. Please provide more context or clarify your question."
    state['response'] = response
    return state

def create_workflow() -> StateGraph:
    """Creates the updated workflow for the AI climate agent."""

    workflow = StateGraph(AgentState)

    # Add processing nodes
    workflow.add_node("expand_query", expand_query)                   # Step 1: Expand the user query.
    workflow.add_node("retrieve_context", retrieve_context)          # Step 2: Retrieve relevant documents.
    workflow.add_node("craft_response", craft_response)               # Step 3: Generate a response from the retrieved data.
    workflow.add_node("score_groundedness", score_groundedness)       # Step 4: Evaluate response grounding.
    workflow.add_node("refine_response", refine_response)             # Step 5: Improve the response if it is weakly grounded.
    workflow.add_node("check_precision", check_precision)             # Step 6: Evaluate response precision.
    workflow.add_node("refine_query", refine_query)                   # Step 7: Improve the query if the response lacks precision.
    workflow.add_node("max_iterations_reached", max_iterations_reached)  # Step 8: Handle max iterations.

    # Main flow edges
    workflow.add_edge(START, "expand_query")
    workflow.add_edge("expand_query", "retrieve_context")
    workflow.add_edge("retrieve_context", "craft_response")
    workflow.add_edge("craft_response", "score_groundedness")

    # Conditional edges based on the groundedness check
    workflow.add_conditional_edges(
        "score_groundedness",
        should_continue_groundedness,   # Use the conditional function
        {
            "check_precision": "check_precision",   # If well-grounded, proceed to the precision check.
            "refine_response": "refine_response",   # If not, refine the response.
            "max_iterations_reached": "max_iterations_reached"   # If max loops reached, set the fallback message.
        }
    )

    workflow.add_edge("refine_response", "craft_response")   # Refined responses are reprocessed.

    # Conditional edges based on the precision check
    workflow.add_conditional_edges(
        "check_precision",
        should_continue_precision,   # Use the conditional function
        {
            "pass": END,                      # If precise, complete the workflow.
            "refine_query": "refine_query",   # If imprecise, refine the query.
            "max_iterations_reached": "max_iterations_reached"   # If max loops reached, set the fallback message.
        }
    )

    workflow.add_edge("refine_query", "expand_query")   # Refined queries go through expansion again.
    workflow.add_edge("max_iterations_reached", END)

    return workflow

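# A minimal sketch of exercising the graph directly, outside the agent
# (assuming the "climateBot" Chroma collection is populated):
#
#   app = create_workflow().compile()
#   result = app.invoke({
#       "query": "How can a small business cut its emissions?",
#       "expanded_query": "", "context": [], "response": "",
#       "precision_score": 0, "groundedness_score": 0,
#       "groundedness_loop_count": 0, "precision_loop_count": 0,
#       "feedback": "", "query_feedback": "", "loop_max_iter": 5,
#   })
#   print(result["response"])
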
#=========================== Defining the agentic rag tool ====================#
WORKFLOW_APP = create_workflow().compile()

@tool
def agentic_rag(query: str):
    """
    Runs the RAG-based agent with conversation history for context-aware responses.

    Args:
        query (str): The current user query.

    Returns:
        Dict[str, Any]: The updated state with the generated response and conversation history.
    """
    # Initialize the state; the loop counters start at 0 so the refinement
    # loops can actually run up to loop_max_iter times
    inputs = {
        "query": query,
        "expanded_query": "",
        "context": [],
        "response": "",
        "precision_score": 0,
        "groundedness_score": 0,
        "groundedness_loop_count": 0,
        "precision_loop_count": 0,
        "feedback": "",
        "query_feedback": "",
        "loop_max_iter": 5
    }

    output = WORKFLOW_APP.invoke(inputs)

    return output

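# Because @tool wraps agentic_rag as a LangChain tool, a direct call goes
# through the tool interface, e.g. (a sketch):
#
#   state = agentic_rag.invoke({"query": "What is net zero?"})
#   print(state["response"])
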
#================================ Guardrails ===========================#
llama_guard_client = Groq(api_key=llama_api_key)

# Function to filter user input with Llama Guard
def filter_input_with_llama_guard(user_input, model="llama-guard-3-8b"):
    """
    Filters user input using Llama Guard to ensure it is safe.

    Parameters:
    - user_input: The input provided by the user.
    - model: The Llama Guard model to be used for filtering (default is "llama-guard-3-8b").

    Returns:
    - The Llama Guard verdict for the input, or None if the call fails.
    """
    try:
        # Ask Llama Guard to classify the user input
        response = llama_guard_client.chat.completions.create(
            messages=[{"role": "user", "content": user_input}],
            model=model,
        )
        # Return the verdict text
        return response.choices[0].message.content.strip()
    except Exception as e:
        print(f"Error with Llama Guard: {e}")
        return None

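# Llama Guard replies with a short verdict rather than a rewritten input:
# "safe", or "unsafe" followed by the violated category code on the next line
# (e.g. "unsafe\nS6"). The Streamlit UI below normalizes that newline to a
# space and accepts "safe", "unsafe S6" and "unsafe S7" before passing the
# query on to the agent.
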
#============================= Adding Memory to the agent using mem0 ===============================#

class climateBot:
    def __init__(self):
        """
        Initialize the climateBot class, setting up memory, the LLM client, tools, and the agent executor.
        """

        # Initialize a memory client to store and retrieve customer interactions
        self.memory = MemoryClient(api_key=mem0_api_key)

        # Initialize the OpenAI client using the provided credentials
        self.client = ChatOpenAI(
            model_name="gpt-4o-mini",                       # Model to use
            api_key=os.environ["API_KEY"],                  # API key for authentication
            openai_api_base=os.environ["OPENAI_API_BASE"],
            temperature=0                                   # 0 ensures deterministic results
        )

        # Define tools available to the chatbot
        tools = [agentic_rag]

        # Define the system prompt to set the behavior of the chatbot
        system_prompt = """You are a caring and knowledgeable Climate Agent, specializing in climate change mitigation strategies and climate action recommendations. Your goal is to provide accurate, empathetic, and tailored climate action recommendations while ensuring a seamless customer experience.
        Guidelines for Interaction:
        Maintain a polite, professional, and reassuring tone.
        Show genuine empathy for customer concerns and business challenges.
        Reference past interactions to provide personalized and consistent advice.
        Engage with the customer by asking about their location, top climate business priorities and company size before offering recommendations.
        Ensure consistent and accurate information across conversations.
        If any detail is unclear or missing, proactively ask for clarification.
        Always use the agentic_rag tool to retrieve up-to-date and evidence-based climate solution insights.
        Keep track of ongoing issues and follow-ups to ensure continuity in support.
        Your primary goal is to help customers make informed climate solution decisions that align with their specific circumstances and business preferences.
        """

        # Build the prompt template for the agent
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),             # System instructions
            ("human", "{input}"),                  # Placeholder for human input
            ("placeholder", "{agent_scratchpad}")  # Placeholder for intermediate reasoning steps
        ])

        # Create an agent capable of interacting with tools and executing tasks
        agent = create_tool_calling_agent(self.client, tools, prompt)

        # Wrap the agent in an executor to manage tool interactions and execution flow
        self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

    def store_customer_interaction(self, user_id: str, message: str, response: str, metadata: Dict = None):
        """
        Store a customer interaction in memory for future reference.

        Args:
            user_id (str): Unique identifier for the customer.
            message (str): Customer's query or message.
            response (str): Chatbot's response.
            metadata (Dict, optional): Additional metadata for the interaction.
        """
        if metadata is None:
            metadata = {}

        # Add a timestamp to the metadata for tracking purposes
        metadata["timestamp"] = datetime.now().isoformat()

        # Format the conversation for storage
        conversation = [
            {"role": "user", "content": message},
            {"role": "assistant", "content": response}
        ]

        # Store the interaction in the memory client
        self.memory.add(
            conversation,
            user_id=user_id,
            output_format="v1.1",
            metadata=metadata
        )

    def get_relevant_history(self, user_id: str, query: str) -> List[Dict]:
        """
        Retrieve past interactions relevant to the current query.

        Args:
            user_id (str): Unique identifier for the customer.
            query (str): The customer's current query.

        Returns:
            List[Dict]: A list of relevant past interactions.
        """
        return self.memory.search(
            query=query,       # Search for interactions related to the query
            user_id=user_id,   # Restrict the search to the specific user
            limit=5            # Retrieve at most 5 past interactions
        )

    def handle_customer_query(self, user_id: str, query: str) -> str:
        """
        Process a customer's query and provide a response, taking into account past interactions.

        Args:
            user_id (str): Unique identifier for the customer.
            query (str): Customer's query.

        Returns:
            str: Chatbot's response.
        """

        # Retrieve relevant past interactions for context
        relevant_history = self.get_relevant_history(user_id, query)

        # Build a context string from the relevant history
        context = "Previous relevant interactions:\n"
        for memory in relevant_history:
            context += f"Interaction: {memory['memory']}\n"   # The stored memory text
            context += "---\n"

        # Print context for debugging purposes
        print("Context: ", context)

        # Prepare a prompt combining past context and the current query
        prompt = f"""
        Context:
        {context}

        Current customer query: {query}

        Provide a helpful response that takes into account any relevant past interactions.
        """

        # Generate a response using the agent
        response = self.agent_executor.invoke({"input": prompt})

        # Store the current interaction for future reference
        self.store_customer_interaction(
            user_id=user_id,
            message=query,
            response=response["output"],
            metadata={"type": "support_query"}
        )

        # Return the chatbot's response
        return response["output"]

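# A minimal usage sketch (assuming the secrets above are set and the Chroma
# collection is populated):
#
#   bot = climateBot()
#   answer = bot.handle_customer_query("user-42", "Where should a retailer start on net zero?")
#   print(answer)
#
# Each call also writes the exchange to mem0, so follow-up questions from
# "user-42" are answered with that history in context.
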
#=====================User Interface using streamlit ===========================#
def climate_streamlit():
    """
    A Streamlit-based UI for the Climate Agent.
    """
    st.title("ClimateBot")
    st.write("Ask me anything about climate solutions and actions to help your business achieve net zero.")
    st.write("Type 'exit' to end the conversation.")

    # Initialize session state for chat history and user_id if they don't exist
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    if 'user_id' not in st.session_state:
        st.session_state.user_id = None

    # Login form: only shown if the user is not logged in
    if st.session_state.user_id is None:
        with st.form("login_form", clear_on_submit=True):
            user_id = st.text_input("Please enter your name to begin:")
            submit_button = st.form_submit_button("Login")
            if submit_button and user_id:
                st.session_state.user_id = user_id
                st.session_state.chat_history.append({
                    "role": "assistant",
                    "content": f"Welcome, {user_id}! How can I help you with climate action recommendations today?"
                })
                st.session_state.login_submitted = True   # Set flag to trigger rerun
        if st.session_state.get("login_submitted", False):
            st.session_state.pop("login_submitted")
            st.rerun()
    else:
        # Display chat history
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.write(message["content"])

        # Chat input with custom placeholder text
        user_query = st.chat_input("Type your question here (or 'exit' to end): ", key="chat_input")
        if user_query:
            if user_query.lower() == "exit":
                st.session_state.chat_history.append({"role": "user", "content": "exit"})
                with st.chat_message("user"):
                    st.write("exit")
                goodbye_msg = "Goodbye! Feel free to return if you have more questions about climate action recommendations."
                st.session_state.chat_history.append({"role": "assistant", "content": goodbye_msg})
                with st.chat_message("assistant"):
                    st.write(goodbye_msg)
                st.session_state.user_id = None
                st.rerun()
                return

            st.session_state.chat_history.append({"role": "user", "content": user_query})
            with st.chat_message("user"):
                st.write(user_query)

            # Filter the input with Llama Guard; the guard returns None on error,
            # so fall back to an empty string before normalizing the verdict
            filtered_result = filter_input_with_llama_guard(user_query)
            filtered_result = (filtered_result or "").replace("\n", " ")

            # Check if the input is safe based on the allowed verdicts
            if filtered_result in ["safe", "unsafe S6", "unsafe S7"]:
                try:
                    if 'chatbot' not in st.session_state:
                        st.session_state.chatbot = climateBot()
                    response = st.session_state.chatbot.handle_customer_query(st.session_state.user_id, user_query)
                    st.write(response)
                    st.session_state.chat_history.append({"role": "assistant", "content": response})
                except Exception as e:
                    error_msg = f"Sorry, I encountered an error while processing your query. Please try again. Error: {str(e)}"
                    st.write(error_msg)
                    st.session_state.chat_history.append({"role": "assistant", "content": error_msg})
            else:
                inappropriate_msg = "I apologize, but I cannot process that input as it may be inappropriate. Please try again."
                st.write(inappropriate_msg)
                st.session_state.chat_history.append({"role": "assistant", "content": inappropriate_msg})

if __name__ == "__main__":
    climate_streamlit()