hanshan1988 committed (verified)
Commit 202e355 · Parent(s): eb0a110

added langgraph implementation

Files changed (1): app.py (+179, -18)
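
This commit swaps the llama_index AgentWorkflow for a LangGraph assistant/tools loop: START -> assistant -> (tools_condition) -> tools -> assistant, repeating until the model answers without requesting a tool. Below is a condensed, runnable sketch of that control flow using the same state definition as the diff; the stubbed assistant node (a fixed AIMessage in place of the tool-bound chat model) is an assumption added here so the wiring can be exercised without Hugging Face credentials.

from typing import Annotated, TypedDict

from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
from langgraph.graph import START, END, StateGraph
from langgraph.graph.message import add_messages


class AgentState(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]


def assistant(state: AgentState):
    # Stand-in for llm_with_tools.invoke([sys_msg] + state["messages"]) in app.py.
    return {"messages": [AIMessage(content="stub answer, no tool call")]}


builder = StateGraph(AgentState)
builder.add_node("assistant", assistant)
builder.add_edge(START, "assistant")
builder.add_edge("assistant", END)  # app.py routes through tools_condition / ToolNode here
graph = builder.compile()

result = graph.invoke({"messages": [HumanMessage(content="Who is Barack Obama?")]})
print(result["messages"][-1].content)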
app.py CHANGED
@@ -1,4 +1,5 @@
 import os
+import time
 import gradio as gr
 import asyncio
 import requests
@@ -11,21 +12,175 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
 
 # --- Basic Agent Definition ---
 # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
-from llama_index.tools.duckduckgo import DuckDuckGoSearchToolSpec
-from llama_index.core.tools import FunctionTool
-from llama_index.core.agent.workflow import AgentWorkflow
-from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
-from llama_index.tools.wikipedia import WikipediaToolSpec
+
+from typing import TypedDict, List, Dict, Any, Optional, Annotated
+from langgraph.graph import StateGraph, START, END
+from langchain_openai import ChatOpenAI
+# from langchain_huggingface.llms import HuggingFaceEndpoint
+from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage
+from langgraph.graph.message import add_messages
+from langchain_community.utilities import WikipediaAPIWrapper
+from langchain_community.tools import WikipediaQueryRun
+from langchain_community.document_loaders import YoutubeLoader, WebBaseLoader
+from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
+from langgraph.prebuilt import ToolNode, tools_condition
+from langchain.tools import tool
 
 # Initialize the Hugging Face model
-llm = HuggingFaceInferenceAPI(model_name="Qwen/Qwen2.5-Coder-32B-Instruct")
 
-# Initialize the DuckDuckGo search tool
-tool_spec = DuckDuckGoSearchToolSpec()
-search_tool = FunctionTool.from_defaults(tool_spec.duckduckgo_full_search)
-# Initialize the wikipedia tool
-wiki_spec = WikipediaToolSpec()
-wiki_tools = wiki_spec.to_tool_list()
+llm = HuggingFaceEndpoint(
+    repo_id="Qwen/Qwen3-30B-A3B",  # "Qwen/Qwen2.5-72B-Instruct",
+    provider="nebius",  # "hf-inference",
+    max_new_tokens=8192,
+    do_sample=False,
+    # temperature=0.,
+)
+
+chat_model = ChatHuggingFace(llm=llm)
+
+# Define tools
+
+@tool
+def fetch_website(url: str) -> str:
+    """Fetch the content of a website.
+    Args:
+        url: The URL of the website to fetch.
+    Returns:
+        The title and content of the website.
+    """
+    loader = WebBaseLoader(url)
+    docs = loader.load()
+    return docs[0].page_content
+
+@tool
+def ask_wiki(query: str) -> str:
+    """Retrieve information from Wikipedia based on a user query.
+    Args:
+        query: A user query.
+    Returns:
+        A single string containing the retrieved article from Wikipedia.
+    """
+    if not query.strip():
+        return "Please provide a valid query."
+    try:
+        wiki_toolapi_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=1000)
+        wiki_tool = WikipediaQueryRun(api_wrapper=wiki_toolapi_wrapper)
+        result = wiki_tool.run(query)
+        return result
+    except Exception as e:
+        return f"Error retrieving information: {str(e)}"
+
+@tool
+def youtube_transcript(url: str) -> str:
+    """Retrieve the transcript of a YouTube video from its URL.
+    Args:
+        url: the input YouTube URL.
+    Returns:
+        A single string containing the transcript of the YouTube video.
+    """
+    max_attempts = 5  # Set a maximum number of attempts
+    attempts = 0
+    loader = YoutubeLoader.from_youtube_url(url, add_video_info=False)
+    while attempts < max_attempts:
+        try:
+            docs = loader.load()
+            return docs[0].page_content
+        except Exception as e:
+            attempts += 1
+            print(f"Attempt {attempts} failed: {e}")
+            # Brief delay before retrying
+            time.sleep(1)
+    return "Failed to retrieve transcript after multiple attempts."
+
+# @tool
+# def divide(a: int, b: int) -> float:
+#     """Divide a and b for occasional calculations.
+#     Args:
+#         a: integer
+#         b: integer
+#     Returns:
+#         A single float containing the result of the division.
+#     """
+#     return a / b
+
+# Equip the LLM with tools
+tools_list = [
+    fetch_website,
+    ask_wiki,
+    youtube_transcript,
+]
+
+llm_with_tools = chat_model.bind_tools(
+    tools_list,
+    # parallel_tool_calls=False
+)
+
+# Define Agent Workflow
+
+class AgentState(TypedDict):
+    messages: Annotated[list[AnyMessage], add_messages]
+
+
+def assistant(state: AgentState):
+    # System message
+    textual_description_of_tool = """
+fetch_website(url: str) -> str:
+    Fetch the content of a website.
+    Args:
+        url: The URL of the website to fetch.
+    Returns:
+        The title and content of the website.
+
+ask_wiki(query: str) -> str:
+    Retrieve information from Wikipedia based on a user query.
+    Args:
+        query: A user query.
+    Returns:
+        A single string containing the retrieved article from Wikipedia.
+
+youtube_transcript(url: str) -> str:
+    Args:
+        url: the input YouTube URL.
+    Returns:
+        A single string containing the transcript of the YouTube video.
+"""
+    sys_msg = SystemMessage(
+        content=f"You are a helpful assistant for answering user questions. You can access the provided tools:\n{textual_description_of_tool}\n"
+    )
+
+    return {
+        "messages": [llm_with_tools.invoke([sys_msg] + state["messages"])],
+    }
+
+# Build the StateGraph for the agent
+builder = StateGraph(AgentState)
+
+# Define nodes: these do the work
+builder.add_node("assistant", assistant)
+builder.add_node("tools", ToolNode(tools_list))
+
+# Define edges: these determine how the control flow moves
+builder.add_edge(START, "assistant")
+builder.add_conditional_edges(
+    "assistant",
+    # If the latest message requires a tool, route to tools
+    # Otherwise, provide a direct response
+    tools_condition,
+)
+builder.add_edge("tools", "assistant")
+agent_graph = builder.compile()
+
+messages = [
+    HumanMessage(
+        # content="Who is Barack Obama?"
+        # content="Divide 6790 by 5"
+        content=user_prompt
+    )
+]
+messages = agent_graph.invoke({"messages": messages}, config={"callbacks": [langfuse_handler]})
+
 
 class BasicAgent:
     def __init__(self):
@@ -35,13 +190,19 @@ class BasicAgent:
         # fixed_answer = "This is a default answer."
         # print(f"Agent returning fixed answer: {fixed_answer}")
         # Create agent with all the tools
-        agent = AgentWorkflow.from_tools_or_functions(
-            wiki_tools,
-            llm=llm
-        )
+
        # Example query agent might receive
-        fixed_answer = await agent.run(question)
-        return fixed_answer
+        # fixed_answer = await agent.run(question)
+        messages = [
+            HumanMessage(
+                # content="Who is Barack Obama?"
+                # content="Divide 6790 by 5"
+                content=question
+            )
+        ]
+        response_text = messages['messages'][-1].content
+        return response_text.split('</think>')[-1]
+
 
 def run_and_submit_all( profile: gr.OAuthProfile | None):
     """