Update agent.py
Browse files
agent.py
CHANGED
@@ -1,36 +1,22 @@
|
|
1 |
-
from typing import TypedDict, Annotated
|
2 |
-
from langchain_core.messages import AnyMessage, HumanMessage
|
3 |
-
from langgraph.graph.message import add_messages
|
4 |
-
from langgraph.graph import START, StateGraph
|
5 |
-
from langgraph.prebuilt import ToolNode, tools_condition
|
6 |
from transformers import pipeline
|
7 |
-
from
|
8 |
-
from
|
9 |
-
|
10 |
-
load_dotenv()
|
11 |
-
|
12 |
-
# Initialize conversational pipeline (DialoGPT)
|
13 |
-
chat_pipe = pipeline("conversational", model="microsoft/DialoGPT-medium")
|
14 |
-
|
15 |
-
tools = [search_tool, weather_info_tool, hub_stats_tool]
|
16 |
|
17 |
-
#
|
18 |
-
|
19 |
-
last_message = messages[-1].content if messages else ""
|
20 |
-
response = chat_pipe(last_message)
|
21 |
-
return response[0]['generated_text']
|
22 |
|
23 |
class AgentState(TypedDict):
|
24 |
-
messages:
|
25 |
|
26 |
def assistant(state: AgentState):
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from transformers import pipeline
|
2 |
+
from langchain_core.messages import AnyMessage, HumanMessage
|
3 |
+
from typing import TypedDict, List
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
+
# Load text-generation pipeline with DialoGPT-medium (chat-like model)
|
6 |
+
chat_pipe = pipeline("text-generation", model="microsoft/DialoGPT-medium")
|
|
|
|
|
|
|
7 |
|
8 |
class AgentState(TypedDict):
|
9 |
+
messages: List[AnyMessage]
|
10 |
|
11 |
def assistant(state: AgentState):
|
12 |
+
# Extract last user message content
|
13 |
+
last_message = state["messages"][-1].content if state["messages"] else ""
|
14 |
+
|
15 |
+
# Generate response from the model
|
16 |
+
outputs = chat_pipe(last_message, max_length=100, do_sample=True, top_p=0.9, temperature=0.8)
|
17 |
+
generated_text = outputs[0]['generated_text']
|
18 |
+
|
19 |
+
# Remove the prompt from output if present
|
20 |
+
response = generated_text[len(last_message):].strip()
|
21 |
+
|
22 |
+
return {"messages": [HumanMessage(content=response)]}
|