cheremnm committed on
Commit 9eb3490 · verified · 1 Parent(s): be892b2

Upload app.py with huggingface_hub

Files changed (1): app.py +755 -0
app.py ADDED
@@ -0,0 +1,755 @@
# Import necessary libraries
import os  # Interacting with the operating system (reading/writing files)
import chromadb  # High-performance vector database for storing/querying dense vectors
from dotenv import load_dotenv  # Loading environment variables from a .env file
import json  # Parsing and handling JSON data

# LangChain imports
from langchain_openai import ChatOpenAI
from langchain_core.documents import Document  # Document data structures
from langchain_core.runnables import RunnablePassthrough  # LangChain core library for running pipelines
from langchain_core.output_parsers import StrOutputParser  # String output parser
from langchain.chains.query_constructor.base import AttributeInfo  # Base classes for query construction
from langchain.retrievers.self_query.base import SelfQueryRetriever  # Base classes for self-querying retrievers
from langchain.retrievers.document_compressors import LLMChainExtractor, CrossEncoderReranker  # Document compressors
from langchain.retrievers import ContextualCompressionRetriever  # Contextual compression retrievers

# LangChain community & experimental imports
from langchain_community.vectorstores import Chroma  # Implementations of vector stores like Chroma
from langchain_community.document_loaders import PyPDFDirectoryLoader, PyPDFLoader  # Document loaders for PDFs
from langchain_community.cross_encoders import HuggingFaceCrossEncoder  # Cross-encoders from HuggingFace
from langchain_experimental.text_splitter import SemanticChunker  # Experimental text splitting methods
from langchain.text_splitter import (
    CharacterTextSplitter,  # Splitting text by characters
    RecursiveCharacterTextSplitter  # Recursive splitting of text by characters
)
from langchain_core.tools import tool
from langchain.agents import create_tool_calling_agent, AgentExecutor
from langchain_core.prompts import ChatPromptTemplate  # Template for chat prompts

# LangChain OpenAI imports
from langchain_openai import AzureOpenAIEmbeddings, AzureChatOpenAI  # Azure OpenAI embeddings and chat models
from langchain_openai import OpenAIEmbeddings  # OpenAI embeddings for text vectors

# LlamaParse & LlamaIndex imports
from llama_parse import LlamaParse  # Document parsing library
from llama_index.core import Settings, SimpleDirectoryReader  # Core functionalities of LlamaIndex

# LangGraph import
from langgraph.graph import StateGraph, END, START  # State graph for managing agent state transitions

# Pydantic import
from pydantic import BaseModel  # Pydantic for data validation

# Typing imports
from typing import Dict, List, Tuple, Any, TypedDict  # Python typing for function annotations

# Other utilities
import numpy as np  # Numpy for numerical operations
from groq import Groq  # Groq client (used here to call Llama Guard)
from mem0 import MemoryClient  # mem0 client for long-term conversational memory
import streamlit as st
from datetime import datetime

#====================================SETUP=====================================#
# Load any local .env file, then fetch secrets; on Hugging Face Spaces these
# come from the Space's environment variables.
load_dotenv()
api_key = os.environ.get("API_KEY")
endpoint = os.environ.get("OPENAI_API_BASE")
llama_api_key = os.environ.get("GROQ_API_KEY")
MEM0_api_key = os.environ.get("MEM0_API_KEY")


# Initialize the OpenAI Embeddings
embedding_model = OpenAIEmbeddings(
    openai_api_base=endpoint,
    openai_api_key=api_key,
    model='text-embedding-ada-002'
)


# Initialize the Chat OpenAI model
llm = ChatOpenAI(
    openai_api_base=endpoint,
    openai_api_key=api_key,
    model="gpt-4o-mini",
    streaming=False
)
# This initializes the Chat OpenAI model with the provided endpoint and API key;
# streaming is disabled so each response is returned in full.

# Set the LLM and embedding model in the LlamaIndex settings.
Settings.llm = llm
Settings.embed_model = embedding_model

#================================Creating Langgraph agent======================#

class AgentState(TypedDict):
    query: str  # The current user query
    expanded_query: str  # The expanded version of the user query
    context: List[Dict[str, Any]]  # Retrieved documents (content and metadata)
    response: str  # The generated response to the user query
    precision_score: float  # The precision score of the response
    groundedness_score: float  # The groundedness score of the response
    groundedness_loop_count: int  # Counter for groundedness refinement loops
    precision_loop_count: int  # Counter for precision refinement loops
    feedback: str  # Suggestions for improving the response
    query_feedback: str  # Suggestions for improving the expanded query
    groundedness_check: bool  # Whether the groundedness check passed
    loop_max_iter: int  # Maximum number of refinement iterations allowed

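# Control flow implemented by create_workflow() below (a self-checking RAG loop):
#   START -> expand_query -> retrieve_context -> craft_response -> score_groundedness
#     groundedness >= 0.8 -> check_precision
#     groundedness <  0.8 -> refine_response -> craft_response (retry)
#   check_precision:
#     precision >= 0.8 -> END
#     precision <  0.8 -> refine_query -> expand_query (retry)
#   either loop exceeding loop_max_iter -> max_iterations_reached -> END
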
def expand_query(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Expands the user query to improve retrieval of nutrition-disorder information.

    Args:
        state: Workflow state containing at least 'query' and 'query_feedback'.

    Returns:
        Workflow state with an additional 'expanded_query' key.
    """
    print("---------Expanding Query---------")
    system_message = '''You are an assistant that reformulates vague or short user questions into detailed, domain-specific queries related to nutrition disorders.

Examples:
- Input: "What about iron?"
  Expanded: "What are the common symptoms and treatments for iron deficiency anemia?"

- Input: "Diets for gut issues?"
  Expanded: "What dietary recommendations are effective for managing irritable bowel syndrome and promoting gut microbiome health?"

- Input: "Sugar"
  Expanded: "What are the risks of high sugar intake in diabetic patients and how can it be managed nutritionally?"
'''
    expand_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Expand this query: {query} using the feedback: {query_feedback}")
    ])

    chain = expand_prompt | llm | StrOutputParser()
    expanded_query = chain.invoke({"query": state['query'], "query_feedback": state["query_feedback"]})
    print("expanded_query", expanded_query)
    state["expanded_query"] = expanded_query
    return state


chroma_client = chromadb.PersistentClient(path="./combined")

vector_store = Chroma(
    client=chroma_client,  # Pass the persistent client created above
    collection_name="combined",
    embedding_function=embedding_model,
)

# Create a retriever from the vector store
retriever = vector_store.as_retriever(
    search_type='similarity',
    search_kwargs={'k': 3}
)

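# Illustrative sketch (assumes the persisted "./combined" collection is populated):
#   docs = retriever.invoke("iron deficiency anemia")
#   # -> up to k=3 langchain Document objects, each with .page_content and .metadata
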
def retrieve_context(state):
    """
    Retrieves context from the vector store using the expanded or original query.

    Args:
        state (Dict): The current state of the workflow, containing the query and expanded query.

    Returns:
        Dict: The updated state with the retrieved context.
    """
    print("---------retrieve_context---------")
    query = state['expanded_query']  # Use the expanded query for retrieval
    #print("Query used for retrieval:", query)  # Debugging: Print the query

    # Retrieve documents from the vector store
    docs = retriever.invoke(query)
    print("Retrieved documents:", docs)  # Debugging: Print the raw docs object

    # Extract both page_content and metadata from each document
    context = [
        {
            "content": doc.metadata.get("original_content", doc.page_content),
            "metadata": doc.metadata
        }
        for doc in docs
    ]
    state['context'] = context  # Store the retrieved context in the state
    print("Extracted context with metadata:", context)  # Debugging: Print the extracted context
    #print(f"Groundedness loop count: {state['groundedness_loop_count']}")
    return state


def craft_response(state):
    """
    Generates a response using the retrieved context, focusing on nutrition disorders.

    Args:
        state (Dict): The current state of the workflow, containing the query and retrieved context.

    Returns:
        Dict: The updated state with the generated response.
    """
    print("---------craft_response---------")
    system_message = '''You are a helpful AI assistant trained to support healthcare providers in retrieving
accurate and relevant information related to nutrition disorders. Your responses must strictly adhere to the
retrieved context, which is extracted from the Nutritional Medical Reference or similar trusted sources.

Do not speculate or introduce external knowledge. Focus only on symptoms, diagnoses, treatment plans, or
other clinical details found in the context. If the context does not contain enough information to answer
accurately, clearly state that. Aim for clarity, factual grounding, and relevance to the user's query.'''

    response_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nContext: {context}\n\nfeedback: {feedback}")
    ])

    # Parse to a plain string so state['response'] matches its declared type
    chain = response_prompt | llm | StrOutputParser()
    response = chain.invoke({
        "query": state['query'],
        "context": "\n".join([doc["content"] for doc in state['context']]),
        "feedback": state.get('feedback', 'No feedback provided')  # Include refinement feedback in the prompt
    })
    state['response'] = response
    print("intermediate response: ", response)

    return state


def score_groundedness(state):
    """
    Checks whether the response is grounded in the retrieved context.

    Args:
        state (Dict): The current state of the workflow, containing the response and context.

    Returns:
        Dict: The updated state with the groundedness score.
    """
    print("---------check_groundedness---------")

    system_message = '''You are evaluating whether an AI-generated response is grounded in the retrieved context
provided from nutritional health documents. The context includes evidence and facts, and your task is to
assign a groundedness score between 0 and 1, where:

- 1.0 means the response is fully supported by the context,
- 0.0 means the response is entirely unsupported.

Be strict: if even a part of the response is not traceable to the context, reduce the score. Provide only
the numeric score.'''

    groundedness_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Context: {context}\nResponse: {response}\n\nGroundedness score:")
    ])

    chain = groundedness_prompt | llm | StrOutputParser()
    groundedness_score = float(chain.invoke({
        "context": "\n".join([doc["content"] for doc in state['context']]),
        "response": state['response']
    }))
    print("groundedness_score: ", groundedness_score)
    state['groundedness_loop_count'] += 1
    print("#########Groundedness Incremented###########")
    state['groundedness_score'] = groundedness_score

    return state

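# Note: float(...) above assumes the grader returns a bare number, as the prompt
# instructs. If it ever replied with extra prose, a defensive parse (a sketch,
# using a hypothetical raw_score_text variable) could look like:
#   match = re.search(r"\d+(?:\.\d+)?", raw_score_text)  # requires `import re`
#   score = float(match.group()) if match else 0.0
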

def check_precision(state: Dict) -> Dict:
    """
    Checks whether the response precisely addresses the user's query.

    Args:
        state (Dict): The current state of the workflow, containing the query and response.

    Returns:
        Dict: The updated state with the precision score.
    """
    print("---------check_precision---------")
    system_message = '''You are assessing whether an AI-generated response precisely answers the user's query,
especially within the domain of nutritional health and disorders. Provide a precision score between 0 and 1:

- 1.0: The response fully and directly answers the query with clear relevance.
- 0.0: The response is vague, unrelated, or fails to address the query.

Only return a numeric score.'''

    precision_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nResponse: {response}\n\nPrecision score:")
    ])

    chain = precision_prompt | llm | StrOutputParser()
    precision_score = float(chain.invoke({
        "query": state['query'],
        "response": state['response']
    }))
    state['precision_score'] = precision_score
    print("precision_score:", precision_score)
    state['precision_loop_count'] += 1
    print("#########Precision Incremented###########")

    return state

304
+ def refine_response(state: Dict) -> Dict:
305
+ """
306
+ Suggests improvements for the generated response.
307
+
308
+ Args:
309
+ state (Dict): The current state of the workflow, containing the query and response.
310
+
311
+ Returns:
312
+ Dict: The updated state with response refinement suggestions.
313
+ """
314
+ print("---------refine_response---------")
315
+
316
+ system_message = '''You are an expert assistant helping to improve AI-generated answers related to nutritional disorders.
317
+ Evaluate the response and suggest constructive improvements to enhance accuracy, specificity, and completeness.
318
+ Do not rewrite the response. Instead, point out what is vague, missing, or could be better explained.
319
+ Focus on clinical terminology, nutritional details, and relevance to the user query.'''
320
+
321
+ refine_response_prompt = ChatPromptTemplate.from_messages([
322
+ ("system", system_message),
323
+ ("user", "Query: {query}\nResponse: {response}\n\n"
324
+ "What improvements can be made to enhance accuracy and completeness?")
325
+ ])
326
+
327
+ chain = refine_response_prompt | llm| StrOutputParser()
328
+
329
+ # Store response suggestions in a structured format
330
+ feedback = f"Previous Response: {state['response']}\nSuggestions: {chain.invoke({'query': state['query'], 'response': state['response']})}"
331
+ print("feedback: ", feedback)
332
+ print(f"State: {state}")
333
+ state['feedback'] = feedback
334
+ return state
335
+
336
+
337
+

def refine_query(state: Dict) -> Dict:
    """
    Suggests improvements for the expanded query.

    Args:
        state (Dict): The current state of the workflow, containing the query and expanded query.

    Returns:
        Dict: The updated state with query refinement suggestions.
    """
    print("---------refine_query---------")
    system_message = '''You are an expert retrieval assistant helping to improve search queries related to nutrition disorders.
Analyze the original and expanded queries and provide suggestions to increase search precision.
Focus on identifying missing clinical terms, relevant nutritional keywords, or clarifying the scope of the query.
Do not rewrite the query. Provide suggestions only.'''

    refine_query_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Original Query: {query}\nExpanded Query: {expanded_query}\n\n"
                 "What improvements can be made for a better search?")
    ])

    chain = refine_query_prompt | llm | StrOutputParser()

    # Store refinement suggestions without modifying the original expanded query
    suggestions = chain.invoke({'query': state['query'], 'expanded_query': state['expanded_query']})
    query_feedback = f"Previous Expanded Query: {state['expanded_query']}\nSuggestions: {suggestions}"
    print("query_feedback: ", query_feedback)
    print(f"Groundedness loop count: {state['groundedness_loop_count']}")
    state['query_feedback'] = query_feedback
    return state


def should_continue_groundedness(state):
    """Decides if groundedness is sufficient or needs improvement."""
    print("---------should_continue_groundedness---------")
    print("groundedness loop count: ", state['groundedness_loop_count'])

    # Threshold logic: groundedness score should be at least 0.8
    if state['groundedness_score'] >= 0.8:
        print("Moving to precision")
        return "check_precision"
    else:
        # Allow up to loop_max_iter refinement loops
        if state['groundedness_loop_count'] > state['loop_max_iter']:
            print("Maximum groundedness iterations reached")
            return "max_iterations_reached"
        else:
            print("---------Groundedness Score Threshold Not Met. Refining Response-----------")
            return "refine_response"


def should_continue_precision(state: Dict) -> str:
    """Decides if precision is sufficient or needs improvement."""
    print("---------should_continue_precision---------")
    print("precision loop count: ", state['precision_loop_count'])

    # Threshold for acceptable precision score
    if state['precision_score'] >= 0.8:
        return "pass"  # Precision is acceptable; complete the workflow
    else:
        # Check if maximum refinement attempts have been reached
        if state['precision_loop_count'] > state['loop_max_iter']:
            return "max_iterations_reached"
        else:
            print("---------Precision Score Threshold Not Met. Refining Query-----------")
            return "refine_query"


def max_iterations_reached(state: Dict) -> Dict:
    """Handles the case where max iterations are reached."""
    print("---------max_iterations_reached---------")
    state['response'] = "We need more context to provide an accurate answer."
    return state


def create_workflow() -> StateGraph:
    """Creates the updated workflow for the AI nutrition agent."""
    workflow = StateGraph(AgentState)  # Workflow state follows the AgentState schema

    # Add processing nodes
    workflow.add_node("expand_query", expand_query)                       # Step 1: Expand user query
    workflow.add_node("retrieve_context", retrieve_context)              # Step 2: Retrieve relevant documents
    workflow.add_node("craft_response", craft_response)                  # Step 3: Generate a response from retrieved data
    workflow.add_node("score_groundedness", score_groundedness)          # Step 4: Evaluate response grounding
    workflow.add_node("refine_response", refine_response)                # Step 5: Improve response if it's weakly grounded
    workflow.add_node("check_precision", check_precision)                # Step 6: Evaluate response precision
    workflow.add_node("refine_query", refine_query)                      # Step 7: Improve query if response lacks precision
    workflow.add_node("max_iterations_reached", max_iterations_reached)  # Step 8: Handle max iterations

    # Main flow edges
    workflow.add_edge(START, "expand_query")
    workflow.add_edge("expand_query", "retrieve_context")
    workflow.add_edge("retrieve_context", "craft_response")
    workflow.add_edge("craft_response", "score_groundedness")

    # Groundedness logic
    workflow.add_conditional_edges(
        "score_groundedness",
        should_continue_groundedness,
        {
            "check_precision": "check_precision",
            "refine_response": "refine_response",
            "max_iterations_reached": "max_iterations_reached"
        }
    )

    # Edge to reprocess refined response
    workflow.add_edge("refine_response", "craft_response")

    # Precision logic
    workflow.add_conditional_edges(
        "check_precision",
        should_continue_precision,
        {
            "pass": END,
            "refine_query": "refine_query",
            "max_iterations_reached": "max_iterations_reached"
        }
    )

    # Edge to re-expand refined query and reenter flow
    workflow.add_edge("refine_query", "expand_query")
    workflow.add_edge("max_iterations_reached", END)

    return workflow

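# Optional: the compiled graph can be rendered for inspection, e.g.
#   create_workflow().compile().get_graph().draw_mermaid_png()
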

#=========================== Defining the agentic rag tool ====================#
WORKFLOW_APP = create_workflow().compile()

@tool
def agentic_rag(query: str) -> Dict[str, Any]:
    """
    Runs the agentic RAG workflow (query expansion, retrieval, and grounded
    response generation) for a nutrition-disorder query.
    """
    if not query or not isinstance(query, str):
        return {"error": "Invalid or empty query provided"}
    inputs = {
        "query": query,
        "expanded_query": "",          # Populated by expand_query with the reformulated query
        "context": [],                 # Retrieved documents (initially empty)
        "response": "",                # Populated by craft_response with the generated answer
        "precision_score": 0.0,        # Computed by check_precision (0 to 1)
        "groundedness_score": 0.0,     # Computed by score_groundedness (0 to 1)
        "groundedness_loop_count": 0,  # Incremented in score_groundedness
        "precision_loop_count": 0,     # Incremented in check_precision
        "feedback": "",                # Populated by refine_response with improvement suggestions
        "query_feedback": "",          # Populated by refine_query with query suggestions
        "loop_max_iter": 5             # Maximum refinement iterations per loop
    }
    output = WORKFLOW_APP.invoke(inputs)

    return output

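# Illustrative direct call (the support agent below normally invokes this as a tool):
#   result = agentic_rag.invoke({"query": "What are the symptoms of vitamin B12 deficiency?"})
#   print(result["response"])
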

#================================ Guardrails ===========================#
llama_guard_client = Groq(api_key=llama_api_key)

# Function to filter user input with Llama Guard
def filter_input_with_llama_guard(user_input, model="meta-llama/llama-guard-4-12b"):
    """
    Classifies user input with Llama Guard to check that it is safe.

    Parameters:
    - user_input: The input provided by the user.
    - model: The Llama Guard model to be used for filtering (default is "meta-llama/llama-guard-4-12b").

    Returns:
    - Llama Guard's verdict string ("safe", or "unsafe" plus a category code), or None on error.
    """
    try:
        # Send the user input to Llama Guard for classification
        response = llama_guard_client.chat.completions.create(
            messages=[{"role": "user", "content": user_input}],
            model=model,
        )
        # Return the verdict
        return response.choices[0].message.content.strip()
    except Exception as e:
        print(f"Error with Llama Guard: {e}")
        return None

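# In the MLCommons hazard taxonomy used by Llama Guard, S6 is "Specialized
# Advice" and S7 is "Privacy"; the UI below deliberately accepts
# "unsafe S6"/"unsafe S7" because legitimate nutrition questions routinely
# trigger the specialized-advice category.
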

#============================= Adding Memory to the agent using mem0 ===============================#

class NutritionBot:
    def __init__(self):
        """
        Initialize the NutritionBot class, setting up memory, the LLM client, tools, and the agent executor.
        """
        # Initialize a memory client to store and retrieve customer interactions
        self.memory = MemoryClient(api_key=os.environ.get("MEM0_API_KEY"))

        # Initialize the OpenAI client using the provided credentials
        self.client = ChatOpenAI(
            model_name="gpt-4o-mini",  # Specify the model to use
            api_key=os.environ.get("API_KEY"),  # API key for authentication
            openai_api_base=os.environ.get("OPENAI_API_BASE"),
            temperature=0  # Controls randomness in responses; 0 ensures deterministic results
        )

        # Define tools available to the chatbot; here, the agentic RAG workflow
        tools = [agentic_rag]

        # Define the system prompt to set the behavior of the chatbot
        system_prompt = """You are a caring and knowledgeable Medical Support Agent, specializing in nutrition disorder-related guidance. Your goal is to provide accurate, empathetic, and tailored nutritional recommendations while ensuring a seamless customer experience.
Guidelines for Interaction:
Maintain a polite, professional, and reassuring tone.
Show genuine empathy for customer concerns and health challenges.
Reference past interactions to provide personalized and consistent advice.
Engage with the customer by asking about their food preferences, dietary restrictions, and lifestyle before offering recommendations.
Ensure consistent and accurate information across conversations.
If any detail is unclear or missing, proactively ask for clarification.
Always use the agentic_rag tool to retrieve up-to-date and evidence-based nutrition insights.
Keep track of ongoing issues and follow-ups to ensure continuity in support.
Your primary goal is to help customers make informed nutrition decisions that align with their health conditions and personal preferences.
"""

        # Build the prompt template for the agent
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),             # System instructions
            ("human", "{input}"),                  # Placeholder for human input
            ("placeholder", "{agent_scratchpad}")  # Placeholder for intermediate reasoning steps
        ])

        # Create an agent capable of interacting with tools and executing tasks
        agent = create_tool_calling_agent(self.client, tools, prompt)

        # Wrap the agent in an executor to manage tool interactions and execution flow
        self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

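    # Each user turn flows: handle_customer_query -> agent_executor -> the
    # agentic_rag tool (the compiled LangGraph workflow) -> final answer,
    # with mem0 providing cross-session memory around the call.
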
    def store_customer_interaction(self, user_id: str, message: str, response: str, metadata: Dict = None):
        """
        Store customer interaction in memory for future reference.

        Args:
            user_id (str): Unique identifier for the customer.
            message (str): Customer's query or message.
            response (str): Chatbot's response.
            metadata (Dict, optional): Additional metadata for the interaction.
        """
        if metadata is None:
            metadata = {}

        # Add a timestamp to the metadata for tracking purposes
        metadata["timestamp"] = datetime.now().isoformat()

        # Format the conversation for storage
        conversation = [
            {"role": "user", "content": message},
            {"role": "assistant", "content": response}
        ]

        # Store the interaction in the memory client
        self.memory.add(
            conversation,
            user_id=user_id,
            output_format="v1.1",
            metadata=metadata
        )

    def get_relevant_history(self, user_id: str, query: str) -> List[Dict]:
        """
        Retrieve past interactions relevant to the current query.

        Args:
            user_id (str): Unique identifier for the customer.
            query (str): The customer's current query.

        Returns:
            List[Dict]: A list of relevant past interactions.
        """
        return self.memory.search(
            query=query,      # Search for interactions related to the query
            user_id=user_id,  # Restrict search to the specific user
            limit=5           # Cap the number of retrieved interactions
        )

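    # mem0's search returns a list of dicts whose "memory" field holds a
    # consolidated snippet (not separate user/assistant turns); the context
    # builder in handle_customer_query below renders one line per snippet.
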
    def handle_customer_query(self, user_id: str, query: str) -> str:
        """
        Process a customer's query and provide a response, taking into account past interactions.

        Args:
            user_id (str): Unique identifier for the customer.
            query (str): Customer's query.

        Returns:
            str: Chatbot's response.
        """
        # Retrieve relevant past interactions for context
        relevant_history = self.get_relevant_history(user_id, query)

        # Build a context string from the relevant history
        context = "Previous relevant interactions:\n"
        for memory in relevant_history:
            context += f"- {memory['memory']}\n"  # Each stored memory snippet
            context += "---\n"

        # Print context for debugging purposes
        print("Context: ", context)

        # Prepare a prompt combining past context and the current query
        prompt = f"""
        Context:
        {context}

        Current customer query: {query}

        Provide a helpful response that takes into account any relevant past interactions.
        """

        # Generate a response using the agent
        response = self.agent_executor.invoke({"input": prompt})

        # Store the current interaction for future reference
        self.store_customer_interaction(
            user_id=user_id,
            message=query,
            response=response["output"],
            metadata={"type": "support_query"}
        )

        # Return the chatbot's response
        return response['output']


#=====================User Interface using streamlit ===========================#
def nutrition_disorder_streamlit():
    """
    A Streamlit-based UI for the Nutrition Disorder Specialist Agent.
    """
    st.title("Nutrition Disorder Specialist")
    st.write("Ask me anything about nutrition disorders, symptoms, causes, treatments, and more.")
    st.write("Type 'exit' to end the conversation.")

    # Initialize session state for chat history and user_id if they don't exist
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    if 'user_id' not in st.session_state:
        st.session_state.user_id = None

    # Login form: only shown if the user is not logged in
    if st.session_state.user_id is None:
        with st.form("login_form", clear_on_submit=True):
            user_id = st.text_input("Please enter your name to begin:")
            submit_button = st.form_submit_button("Login")
            if submit_button and user_id:
                st.session_state.user_id = user_id
                st.session_state.chat_history.append({
                    "role": "assistant",
                    "content": f"Welcome, {user_id}! How can I help you with nutrition disorders today?"
                })
                st.session_state.login_submitted = True
        if st.session_state.get("login_submitted", False):
            st.session_state.pop("login_submitted")
            st.rerun()
    else:
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.write(message["content"])

        user_query = st.chat_input("Type your question here (or 'exit' to end)...")

        if user_query:
            if user_query.lower() == "exit":
                st.session_state.chat_history.append({"role": "user", "content": "exit"})
                with st.chat_message("user"):
                    st.write("exit")
                goodbye_msg = "Goodbye! Feel free to return if you have more questions about nutrition disorders."
                st.session_state.chat_history.append({"role": "assistant", "content": goodbye_msg})
                with st.chat_message("assistant"):
                    st.write(goodbye_msg)
                st.session_state.user_id = None
                st.rerun()
                return

            st.session_state.chat_history.append({"role": "user", "content": user_query})
            with st.chat_message("user"):
                st.write(user_query)

            filtered_result = filter_input_with_llama_guard(user_query)
            if filtered_result:
                filtered_result = filtered_result.replace("\n", " ")  # Normalize the verdict string

            # Accept "safe" input, plus the specialized-advice (S6) and privacy (S7)
            # categories that medical questions routinely trigger
            if filtered_result in ["safe", "unsafe S7", "unsafe S6"]:
                try:
                    if 'chatbot' not in st.session_state:
                        st.session_state.chatbot = NutritionBot()
                    response = st.session_state.chatbot.handle_customer_query(
                        st.session_state.user_id, user_query)
                    with st.chat_message("assistant"):
                        st.write(response)
                    st.session_state.chat_history.append({"role": "assistant", "content": response})
                except Exception as e:
                    error_msg = f"Sorry, I encountered an error while processing your query. Please try again. Error: {str(e)}"
                    with st.chat_message("assistant"):
                        st.write(error_msg)
                    st.session_state.chat_history.append({"role": "assistant", "content": error_msg})
            else:
                inappropriate_msg = "I apologize, but I cannot process that input as it may be inappropriate. Please try again."
                with st.chat_message("assistant"):
                    st.write(inappropriate_msg)
                st.session_state.chat_history.append({"role": "assistant", "content": inappropriate_msg})

if __name__ == "__main__":
    nutrition_disorder_streamlit()