rakesh-dvg commited on
Commit
bcd63da
·
verified ·
1 Parent(s): 3aa06c2

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +16 -30
agent.py CHANGED
@@ -1,36 +1,22 @@
1
- from typing import TypedDict, Annotated
2
- from langchain_core.messages import AnyMessage, HumanMessage
3
- from langgraph.graph.message import add_messages
4
- from langgraph.graph import START, StateGraph
5
- from langgraph.prebuilt import ToolNode, tools_condition
6
  from transformers import pipeline
7
- from tools import search_tool, weather_info_tool, hub_stats_tool
8
- from dotenv import load_dotenv
9
-
10
- load_dotenv()
11
-
12
- # Initialize conversational pipeline (DialoGPT)
13
- chat_pipe = pipeline("conversational", model="microsoft/DialoGPT-medium")
14
-
15
- tools = [search_tool, weather_info_tool, hub_stats_tool]
16
 
17
- # Wrapper function to call conversational pipeline
18
- def chat_llm(messages: list[AnyMessage]) -> str:
19
- last_message = messages[-1].content if messages else ""
20
- response = chat_pipe(last_message)
21
- return response[0]['generated_text']
22
 
23
  class AgentState(TypedDict):
24
- messages: Annotated[list[AnyMessage], add_messages]
25
 
26
  def assistant(state: AgentState):
27
- reply = chat_llm(state["messages"])
28
- return {"messages": [HumanMessage(content=reply)]}
29
-
30
- builder = StateGraph(AgentState)
31
- builder.add_node("assistant", assistant)
32
- builder.add_node("tools", ToolNode(tools))
33
- builder.add_edge(START, "assistant")
34
- builder.add_conditional_edges("assistant", tools_condition)
35
- builder.add_edge("tools", "assistant")
36
- alfred = builder.compile()
 
 
 
 
 
 
 
1
  from transformers import pipeline
2
+ from langchain_core.messages import AnyMessage, HumanMessage
3
+ from typing import TypedDict, List
 
 
 
 
 
 
 
4
 
5
+ # Load text-generation pipeline with DialoGPT-medium (chat-like model)
6
+ chat_pipe = pipeline("text-generation", model="microsoft/DialoGPT-medium")
 
 
 
7
 
8
  class AgentState(TypedDict):
9
+ messages: List[AnyMessage]
10
 
11
  def assistant(state: AgentState):
12
+ # Extract last user message content
13
+ last_message = state["messages"][-1].content if state["messages"] else ""
14
+
15
+ # Generate response from the model
16
+ outputs = chat_pipe(last_message, max_length=100, do_sample=True, top_p=0.9, temperature=0.8)
17
+ generated_text = outputs[0]['generated_text']
18
+
19
+ # Remove the prompt from output if present
20
+ response = generated_text[len(last_message):].strip()
21
+
22
+ return {"messages": [HumanMessage(content=response)]}