hsila's picture
add system message, ruff format, change pro to flash
32c39ee
import os
from typing import Union

from dotenv import load_dotenv
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
from langgraph.graph import StateGraph, START, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
# Populate the process environment from a local .env file before any key is
# read (get_llm looks up GEMINI_API_KEY via os.getenv).
load_dotenv()

# Answer-format instructions for the model. call_model prepends this as a
# SystemMessage whenever the conversation does not already start with one.
# NOTE(review): "Report your thoughts" conflicts with "output only your final
# answer"; confirm which behaviour is intended before changing the text.
SYSTEM_PROMPT = (
    "You are a general AI assistant. I will ask you a question. "
    "Report your thoughts, and output only your final answer, no prefixes, "
    "suffixes, or extra text. "
    "YOUR FINAL ANSWER should be a number OR as few words as possible OR a "
    "comma separated list of numbers and/or strings. "
    "If you are asked for a number, don't use comma to write your number "
    "neither use units such as $ or percent sign unless specified otherwise. "
    "If you are asked for a string, don't use articles, neither abbreviations "
    "(e.g. for cities), and write the digits in plain text unless specified "
    "otherwise. "
    "If you are asked for a comma separated list, apply the above rules "
    "depending of whether the element to be put in the list is a number or a "
    "string."
)
@tool
def add(a: float, b: float) -> float:
    """Add two numbers together.
    Args:
        a: First number
        b: Second number
    """
    # @tool exposes the docstring above to the model as the tool description,
    # so it is kept unchanged; this tool simply returns a + b.
    total = a + b
    return total
@tool
def subtract(a: float, b: float) -> float:
    """Subtract b from a.
    Args:
        a: Number to subtract from
        b: Number to subtract
    """
    # @tool exposes the docstring above to the model as the tool description,
    # so it is kept unchanged; this tool returns a - b (order matters).
    difference = a - b
    return difference
@tool
def multiply(a: float, b: float) -> float:
    """Multiply two numbers together.
    Args:
        a: First number
        b: Second number
    """
    # @tool exposes the docstring above to the model as the tool description,
    # so it is kept unchanged; this tool returns the product a * b.
    product = a * b
    return product
@tool
def divide(a: float, b: float) -> Union[float, str]:
    """Divide a by b.
    Args:
        a: Dividend
        b: Divisor
    """
    # @tool exposes the docstring above to the model as the tool description,
    # so implementation notes live in comments instead.
    # Returns a / b. A zero divisor is reported as the text
    # "Error: Division by zero" rather than raised, so the model receives it as
    # ordinary tool output; the return annotation reflects both outcomes.
    if b == 0:
        return "Error: Division by zero"
    return a / b
@tool
def modulo(a: float, b: float) -> Union[float, str]:
    """Return the remainder of a divided by b.
    Args:
        a: Dividend
        b: Divisor
    """
    # @tool exposes the docstring above to the model as the tool description,
    # so implementation notes live in comments instead.
    # Returns a % b using Python semantics (the result takes the sign of b).
    # A zero divisor is reported as the text "Error: Division by zero" rather
    # than raised; the return annotation reflects both outcomes.
    if b == 0:
        return "Error: Division by zero"
    return a % b
@tool
def power(a: float, b: float) -> Union[float, str]:
    """Raise a to the power of b.
    Args:
        a: Base number
        b: Exponent
    """
    # @tool exposes the docstring above to the model as the tool description,
    # so implementation notes live in comments instead.
    # Returns a ** b, or an "Error: ..." string, matching how divide and
    # square_root report invalid input instead of raising.
    try:
        result = a**b
    except ZeroDivisionError:
        # 0 raised to a negative power.
        return "Error: Division by zero"
    except OverflowError:
        # Float result exceeds the representable range (e.g. 10.0 ** 400).
        return "Error: Result too large"
    if isinstance(result, complex):
        # A negative base with a fractional exponent (e.g. (-8) ** (1/3))
        # yields a complex number in Python 3, not a real float.
        return "Error: Result is not a real number"
    return result
@tool
def square_root(a: float) -> Union[float, str]:
    """Calculate the square root of a number.
    Args:
        a: Number to calculate square root of
    """
    # @tool exposes the docstring above to the model as the tool description,
    # so implementation notes live in comments instead.
    # Negative input is rejected with an error string, because a ** 0.5 would
    # otherwise return a complex number; the annotation reflects both outcomes.
    if a < 0:
        return "Error: Cannot calculate square root of negative number"
    return a**0.5
@tool
def web_search(query: str) -> str:
    """Search the web for current information and facts.
    Args:
        query: Search query string
    """
    # @tool exposes the docstring above to the model as the tool description,
    # so implementation notes live in comments instead.
    # Returns up to 3 Tavily hits formatted as "N. title\ncontent\nSource: url"
    # joined by a "====" separator, or a human-readable message when there are
    # no hits or the search fails. Failures are returned, never raised, so the
    # model sees them as tool output.
    try:
        search_tool = TavilySearchResults(max_results=3)
        results = search_tool.invoke(query)
        if not results:
            return "No search results found."
        # NOTE(review): TavilySearchResults may return the API error as a plain
        # string instead of raising (confirm against the installed
        # langchain_community version). Iterating that string and calling
        # .get() on each character would hide the real error behind an
        # AttributeError, so surface it directly.
        if isinstance(results, str):
            return f"Error performing search: {results}"
        formatted_results = [
            f"{i}. {result.get('title', 'No title')}\n"
            f"{result.get('content', 'No content')}\n"
            f"Source: {result.get('url', 'No URL')}"
            for i, result in enumerate(results, 1)
        ]
        return "\n\n ==== \n\n".join(formatted_results)
    except Exception as e:
        return f"Error performing search: {e!s}"
@tool
def wikipedia_search(query: str) -> str:
    """Search Wikipedia for factual information.
    Args:
        query: Wikipedia search query
    """
    # @tool exposes the docstring above to the model as the tool description,
    # so it is kept unchanged.
    # Loads at most 2 Wikipedia pages for the query and renders each as
    # "N. title\npage text", joined by a "====" separator. Any failure is
    # returned as an error string rather than raised.
    try:
        pages = WikipediaLoader(query=query, load_max_docs=2).load()
        if not pages:
            return "No Wikipedia results found."
        sections = [
            f"{rank}. {page.metadata.get('title', 'No title')}\n{page.page_content}"
            for rank, page in enumerate(pages, 1)
        ]
        return "\n\n ==== \n\n".join(sections)
    except Exception as exc:
        return f"Error searching Wikipedia: {exc!s}"
# Every tool the agent can use: bound to the model in call_model and executed
# by the ToolNode in build_graph. Order is unchanged: the arithmetic helpers
# come first, then the two retrieval tools.
tools = [add, subtract, multiply, divide, modulo, power, square_root] + [
    web_search,
    wikipedia_search,
]
def get_llm():
    """Build the Gemini chat model used by the agent.

    Uses ``gemini-2.5-flash`` at temperature 0, authenticated with the
    GEMINI_API_KEY environment variable (None if it is unset). A fresh client
    is created on every call.

    Returns:
        A new ``ChatGoogleGenerativeAI`` instance.
    """
    gemini_key = os.getenv("GEMINI_API_KEY")
    return ChatGoogleGenerativeAI(
        model="gemini-2.5-flash",
        temperature=0,
        api_key=gemini_key,
    )
def call_model(state: MessagesState):
    """Run one agent turn by sending the conversation to the tool-enabled model.

    SYSTEM_PROMPT is prepended for this call only when the history does not
    already begin with a SystemMessage; it is not written back into the state.

    Args:
        state: Current graph state; ``state["messages"]`` is the conversation.

    Returns:
        A state update ``{"messages": [reply]}`` carrying the model's reply.
    """
    model = get_llm().bind_tools(tools)
    history = state["messages"]
    if history and isinstance(history[0], SystemMessage):
        prompt = history
    else:
        prompt = [SystemMessage(content=SYSTEM_PROMPT)] + history
    reply = model.invoke(prompt)
    return {"messages": [reply]}
def build_graph():
    """Assemble and compile the agent/tool loop.

    Flow: START -> "agent". After each agent turn, ``tools_condition`` picks
    the next step (the "tools" node or the end of the run), and "tools" always
    hands control back to "agent".

    Returns:
        The compiled LangGraph runnable.
    """
    builder = StateGraph(MessagesState)
    nodes = {"agent": call_model, "tools": ToolNode(tools)}
    for node_name, node in nodes.items():
        builder.add_node(node_name, node)
    builder.add_edge(START, "agent")
    builder.add_conditional_edges("agent", tools_condition)
    builder.add_edge("tools", "agent")
    return builder.compile()
if __name__ == "__main__":
    # Smoke test: push one arithmetic question through the full graph and
    # print the content of the last message in the resulting state.
    app = build_graph()
    final_state = app.invoke(
        {"messages": [HumanMessage(content="What is 15 + 27?")]}
    )
    answer = final_state["messages"][-1].content
    print(f"Test result: {answer}")