| """ | |
| LLM Agent Graph Implementation | |
| ============================= | |
| This module defines a graph-based LLM agent workflow with various tools and retrieval capabilities. | |
| The agent can: | |
| - Perform mathematical operations | |
| - Search Wikipedia, web, and arXiv | |
| - Retrieve similar questions from a vector database | |
| - Process user queries using different LLM providers | |
| Components: | |
| - Tool definitions: Math operations, search tools | |
| - Vector database retrieval | |
| - Graph construction with different LLM options | |
| - Workflow management with LangGraph | |
| """ | |

import os
import logging
from typing import Dict, List

from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
from langchain.tools.retriever import create_retriever_tool
from supabase.client import Client, create_client

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()

# ===================
# Math Operation Tools
# ===================

@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers and return the result.

    Args:
        a: First integer to multiply
        b: Second integer to multiply

    Returns:
        The product of a and b
    """
    return a * b


@tool
def add(a: int, b: int) -> int:
    """Add two integers and return the result.

    Args:
        a: First integer to add
        b: Second integer to add

    Returns:
        The sum of a and b
    """
    return a + b


@tool
def subtract(a: int, b: int) -> int:
    """Subtract the second integer from the first and return the result.

    Args:
        a: Integer to subtract from
        b: Integer to subtract

    Returns:
        The difference (a - b)
    """
    return a - b


@tool
def divide(a: int, b: int) -> float:
    """Divide the first integer by the second and return the result.

    Args:
        a: Numerator (dividend)
        b: Denominator (divisor)

    Returns:
        The quotient (a / b) as a float

    Raises:
        ValueError: If b is zero (division by zero)
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b


@tool
def modulus(a: int, b: int) -> int:
    """Calculate the remainder when the first integer is divided by the second.

    Args:
        a: Dividend
        b: Divisor

    Returns:
        The remainder of a divided by b

    Raises:
        ValueError: If b is zero (modulo by zero)
    """
    if b == 0:
        raise ValueError("Cannot calculate modulus with divisor zero.")
    return a % b
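
# A minimal sketch of how the @tool-decorated math operations above can be
# invoked directly. LangChain tools take their arguments as a dict via
# .invoke(), which mirrors how ToolNode calls them inside the graph. This
# helper is purely illustrative and is not part of the workflow.
def _demo_math_tools() -> None:
    """Log sample invocations of the math tools (illustrative only)."""
    logger.info("multiply(6, 7) = %s", multiply.invoke({"a": 6, "b": 7}))
    logger.info("divide(7, 2) = %s", divide.invoke({"a": 7, "b": 2}))
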

# ===================
# Search Tools
# ===================

@tool
def wiki_search(query: str) -> Dict[str, str]:
    """Search Wikipedia for a query and return formatted results.

    Args:
        query: The search term to look up on Wikipedia

    Returns:
        Dictionary with formatted Wikipedia search results
    """
    logger.info(f"Searching Wikipedia for: {query}")
    try:
        search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
        if not search_docs:
            return {"wiki_results": "No Wikipedia results found for this query."}
        formatted_search_docs = "\n\n---\n\n".join(
            [
                f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}">\n{doc.page_content}\n</Document>'
                for doc in search_docs
            ]
        )
        logger.info(f"Found {len(search_docs)} Wikipedia results")
        return {"wiki_results": formatted_search_docs}
    except Exception as e:
        logger.error(f"Error searching Wikipedia: {e}", exc_info=True)
        return {"wiki_results": f"Error searching Wikipedia: {str(e)}"}

@tool
def web_search(query: str) -> Dict[str, str]:
    """Search the web using Tavily for a query and return formatted results.

    Args:
        query: The search term to look up on the web

    Returns:
        Dictionary with formatted web search results
    """
    logger.info(f"Searching the web for: {query}")
    try:
        # BaseTool.invoke takes the input as its first positional argument
        search_results = TavilySearchResults(max_results=3).invoke(query)
        if not search_results:
            return {"web_results": "No web results found for this query."}
        formatted_search_docs = "\n\n---\n\n".join(
            [
                f'<Document source="{result["url"]}">\n{result["content"]}\n</Document>'
                for result in search_results
            ]
        )
        logger.info(f"Found {len(search_results)} web search results")
        return {"web_results": formatted_search_docs}
    except Exception as e:
        logger.error(f"Error searching the web: {e}", exc_info=True)
        return {"web_results": f"Error searching the web: {str(e)}"}

@tool
def arxiv_search(query: str) -> Dict[str, str]:
    """Search arXiv for academic papers and return formatted results.

    Args:
        query: The search term to look up on arXiv

    Returns:
        Dictionary with formatted arXiv search results
    """
    logger.info(f"Searching arXiv for: {query}")
    try:
        search_docs = ArxivLoader(query=query, load_max_docs=3).load()
        if not search_docs:
            return {"arxiv_results": "No arXiv results found for this query."}
        formatted_search_docs = "\n\n---\n\n".join(
            [
                # entry_id is not guaranteed in ArxivLoader metadata, so use .get()
                f'<Document source="{doc.metadata.get("entry_id", "")}" title="{doc.metadata.get("Title", "")}">\n{doc.page_content[:1000]}\n</Document>'
                for doc in search_docs
            ]
        )
        logger.info(f"Found {len(search_docs)} arXiv results")
        return {"arxiv_results": formatted_search_docs}
    except Exception as e:
        logger.error(f"Error searching arXiv: {e}", exc_info=True)
        return {"arxiv_results": f"Error searching arXiv: {str(e)}"}

# ===================
# Vector Store Setup
# ===================

def setup_vector_store() -> SupabaseVectorStore:
    """
    Set up and configure the Supabase vector store for question retrieval.

    Returns:
        Configured SupabaseVectorStore instance

    Raises:
        ValueError: If required environment variables are missing
    """
    # Check for required environment variables
    supabase_url = os.environ.get("SUPABASE_URL")
    supabase_key = os.environ.get("SUPABASE_SERVICE_KEY")
    if not supabase_url or not supabase_key:
        raise ValueError(
            "Missing required environment variables: SUPABASE_URL and/or SUPABASE_SERVICE_KEY"
        )

    # Initialize embeddings model
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

    # Initialize Supabase client
    supabase_client: Client = create_client(supabase_url, supabase_key)

    # Create vector store
    vector_store = SupabaseVectorStore(
        client=supabase_client,
        embedding=embeddings,
        table_name="documents",
        query_name="match_documents_langchain",
    )
    logger.info("Vector store initialized successfully")
    return vector_store
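
# setup_vector_store() assumes the Supabase project already contains a
# "documents" table with an embedding column and a "match_documents_langchain"
# SQL function, following LangChain's SupabaseVectorStore conventions. A
# minimal, hypothetical smoke test of the store:
def _demo_vector_store() -> None:
    """Run one similarity search against the store (illustrative only)."""
    store = setup_vector_store()
    docs = store.similarity_search("example question", k=1)
    logger.info("Top match: %s", docs[0].page_content[:200] if docs else "none")
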

# ===================
# LLM Provider Setup
# ===================

def get_llm(provider: str = "google"):
    """
    Initialize and return an LLM based on the specified provider.

    Args:
        provider: The LLM provider to use ('google', 'groq', or 'huggingface')

    Returns:
        Initialized LLM instance

    Raises:
        ValueError: If an invalid provider is specified
    """
    if provider == "google":
        logger.info("Using Google Gemini as LLM provider")
        return ChatGoogleGenerativeAI(model="gemini-2.5-flash-preview-04-17", temperature=0)
    elif provider == "groq":
        logger.info("Using Groq as LLM provider with qwen-qwq-32b model")
        return ChatGroq(model="qwen-qwq-32b", temperature=0)
    elif provider == "huggingface":
        logger.info("Using Hugging Face as LLM provider with llama-2-7b-chat-hf model")
        return ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                # HuggingFaceEndpoint expects endpoint_url, not url
                endpoint_url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
                task="text-generation",
                temperature=0,
            ),
        )
    else:
        available_providers = ["google", "groq", "huggingface"]
        raise ValueError(f"Invalid provider: '{provider}'. Choose from {available_providers}")

# ===================
# Graph Building
# ===================

def build_graph(provider: str = "groq"):
    """
    Build and compile the agent workflow graph.

    This function creates a LangGraph workflow that includes:
    - A retriever node to find similar questions
    - An assistant node that uses an LLM to generate responses
    - A tools node for executing various tools

    Args:
        provider: The LLM provider to use ('google', 'groq', or 'huggingface')

    Returns:
        Compiled StateGraph ready for execution
    """
    logger.info(f"Building agent graph with {provider} as LLM provider")

    # Load system prompt
    try:
        with open("system_prompt.txt", "r", encoding="utf-8") as f:
            system_prompt = f.read()
        logger.info("Loaded system prompt from file")
    except FileNotFoundError:
        system_prompt = """You are a helpful AI assistant that answers questions accurately and concisely.
Use the available tools when appropriate to find information or perform calculations.
Always cite your sources when you use search tools."""
        logger.warning("system_prompt.txt not found, using default system prompt")

    # Initialize system message
    sys_msg = SystemMessage(content=system_prompt)

    # Set up vector store and retriever tool
    try:
        vector_store = setup_vector_store()
        retriever_tool = create_retriever_tool(
            retriever=vector_store.as_retriever(),
            # Tool names must be valid identifiers (no spaces) for most providers
            name="question_search",
            description="A tool to retrieve similar questions from a vector store.",
        )
        logger.info("Vector store retrieval tool initialized")
    except Exception as e:
        logger.error(f"Failed to set up vector store: {e}", exc_info=True)
        # Define both names so later checks don't raise NameError
        vector_store = None
        retriever_tool = None

    # Define available tools
    tools = [
        multiply,
        add,
        subtract,
        divide,
        modulus,
        wiki_search,
        web_search,
        arxiv_search,
    ]

    # Add the retriever tool if available
    if retriever_tool:
        tools.append(retriever_tool)

    # Get LLM and bind tools
    llm = get_llm(provider)
    llm_with_tools = llm.bind_tools(tools)

    # Define graph nodes
    def assistant(state: MessagesState) -> Dict[str, List]:
        """
        Assistant node that processes messages with the LLM.

        Args:
            state: Current message state

        Returns:
            Updated message state with LLM response
        """
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    def retriever(state: MessagesState) -> Dict[str, List]:
        """
        Retriever node that finds similar questions from the vector store.

        Args:
            state: Current message state

        Returns:
            Updated message state with the system message and, if available,
            a retrieved example appended
        """
        # Only use retrieval if the vector store is available
        if vector_store:
            try:
                similar_questions = vector_store.similarity_search(state["messages"][0].content)
                if similar_questions:
                    example_msg = HumanMessage(
                        content=f"Here is a similar question and answer for reference:\n\n{similar_questions[0].page_content}",
                    )
                    return {"messages": [sys_msg] + state["messages"] + [example_msg]}
            except Exception as e:
                logger.error(f"Error in retriever node: {e}", exc_info=True)
        # If the vector store is unavailable or retrieval fails, just add the system message
        return {"messages": [sys_msg] + state["messages"]}

    # Build graph
    builder = StateGraph(MessagesState)

    # Add nodes
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))

    # Add edges
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    # tools_condition routes to the "tools" node when the last message
    # contains tool calls, and ends the run otherwise
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    builder.add_edge("tools", "assistant")

    # Compile graph
    compiled_graph = builder.compile()
    logger.info("Agent graph compiled successfully")
    return compiled_graph
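
# Besides graph.invoke(), a compiled LangGraph graph can also be streamed node
# by node, which is useful for watching tool calls as they happen. A minimal
# sketch, assuming the Groq provider is configured:
def _demo_stream(question: str) -> None:
    """Stream the graph's per-node updates for one question (illustrative only)."""
    graph = build_graph(provider="groq")
    for update in graph.stream(
        {"messages": [HumanMessage(content=question)]}, stream_mode="updates"
    ):
        logger.info("Node update from: %s", list(update.keys()))
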

# ===================
# Testing
# ===================

if __name__ == "__main__":
    test_question = "When was the wiki entry of Boethius on De Philosophiae Consolatione first added?"

    # Build the graph
    logger.info("Starting test run")
    graph = build_graph(provider="groq")

    # Run the graph
    logger.info(f"Testing with question: {test_question}")
    messages = [HumanMessage(content=test_question)]
    result_messages = graph.invoke({"messages": messages})

    # Display results
    logger.info("Test completed, printing messages:")
    for message in result_messages["messages"]:
        message.pretty_print()