lisaterumi commited on
Commit
b6fb167
·
verified ·
1 Parent(s): 81917a3

Create agent.py

Browse files
Files changed (1) hide show
  1. agent.py +109 -0
agent.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LangGraph Agent"""
2
+ import os
3
+
4
+ from langgraph.graph import START, StateGraph, MessagesState
5
+ from langgraph.prebuilt import tools_condition
6
+ from langgraph.prebuilt import ToolNode
7
+ from langchain_google_genai import ChatGoogleGenerativeAI
8
+ from langchain_groq import ChatGroq
9
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
10
+ from langchain_community.tools.tavily_search import TavilySearchResults
11
+ # from langchain_community.document_loaders import WikipediaLoader
12
+ # from langchain_community.document_loaders import ArxivLoader
13
+ from langchain_community.vectorstores import SupabaseVectorStore
14
+ from langchain_core.messages import SystemMessage, HumanMessage
15
+ from langchain_core.tools import tool
16
+ # from langchain.tools.retriever import create_retriever_tool
17
+ # from supabase.client import Client, create_client
18
+ from langchain_core.messages import AIMessage
19
+ from difflib import SequenceMatcher
20
+ import time # Add time import
21
+
22
+ from tools import add, subtract, multiply, divide, modulus, wiki_search, web_search, arvix_search, search_metadata
23
+
24
+ # load the system prompt from the file
25
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
26
+ system_prompt = f.read()
27
+
28
+ # System message
29
+ sys_msg = SystemMessage(content=system_prompt)
30
+
31
+ tools = [
32
+ multiply,
33
+ add,
34
+ subtract,
35
+ divide,
36
+ modulus,
37
+ wiki_search,
38
+ web_search,
39
+ arvix_search,
40
+ search_metadata,
41
+ ]
42
+
43
+ # Build graph function
44
+ def build_graph(provider: str = "google"):
45
+ """Build the graph"""
46
+ # Load environment variables from .env file
47
+ if provider == "google":
48
+ # Google Gemini
49
+ llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash-preview-05-20", temperature=1)
50
+ elif provider == "groq":
51
+ # Groq https://console.groq.com/docs/models
52
+ llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
53
+ elif provider == "huggingface":
54
+ # TODO: Add huggingface endpoint
55
+ llm = ChatHuggingFace(
56
+ llm=HuggingFaceEndpoint(
57
+ url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
58
+ temperature=0,
59
+ ),
60
+ )
61
+ else:
62
+ raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
63
+
64
+ # Bind tools to LLM
65
+ llm_with_tools = llm.bind_tools(tools)
66
+
67
+ # Node
68
+ def assistant(state: MessagesState):
69
+ """Assistant node"""
70
+ messages = state["messages"]
71
+ # If we have retrieved information, use it
72
+ if len(messages) > 1 and "No matching results found in metadata" not in messages[-1].content:
73
+ # Create a new message list with proper structure
74
+ new_messages = [
75
+ SystemMessage(content="You are a helpful assistant. Use the following retrieved information to answer the question. If the information is relevant, use it directly. If not, use your own knowledge."),
76
+ HumanMessage(content=f"Question: {messages[-2].content}\n\nRetrieved Information:\n{messages[-1].content}")
77
+ ]
78
+ time.sleep(2) # Add 2 second sleep before tool call
79
+ return {"messages": [llm_with_tools.invoke(new_messages)]}
80
+ else:
81
+ # If no retrieved information, just use the original message
82
+ time.sleep(2) # Add 2 second sleep before tool call
83
+ return {"messages": [llm_with_tools.invoke(messages)]}
84
+
85
+ def retriever(state: MessagesState):
86
+ query = state["messages"][-1].content
87
+ result = search_metadata(query)
88
+ return {"messages": [AIMessage(content=result)]}
89
+
90
+ builder = StateGraph(MessagesState)
91
+ builder.add_node("retriever", retriever)
92
+ builder.add_node("assistant", assistant)
93
+ builder.add_node("tools", ToolNode(tools))
94
+
95
+ # Start with retriever
96
+ builder.set_entry_point("retriever")
97
+
98
+ # After retriever, go to assistant
99
+ builder.add_edge("retriever", "assistant")
100
+
101
+ # Assistant can either use tools or finish
102
+ builder.add_conditional_edges(
103
+ "assistant",
104
+ tools_condition,
105
+ )
106
+ builder.add_edge("tools", "assistant")
107
+
108
+ # Compile graph
109
+ return builder.compile()