rakesh-dvg's picture
Update agent.py
bcd63da verified
raw
history blame
856 Bytes
from typing import List, TypedDict

from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
from transformers import pipeline

# Hugging Face model id of the chat-like model used to produce replies.
_CHAT_MODEL_ID = "microsoft/DialoGPT-medium"

# Text-generation pipeline built once, when this module is imported, and
# shared by every call to ``assistant``.
chat_pipe = pipeline(task="text-generation", model=_CHAT_MODEL_ID)


# State handed to ``assistant``: a dict whose single key, "messages", holds
# the conversation as a list of LangChain messages (last entry = latest turn).
AgentState = TypedDict("AgentState", {"messages": List[AnyMessage]})


def assistant(state: AgentState):
    """Generate a model reply to the most recent message in ``state``.

    The content of the last entry in ``state["messages"]`` is sent to the
    module-level DialoGPT text-generation pipeline (``chat_pipe``), and the
    generated continuation is returned as one new message.

    Args:
        state: Agent state whose ``"messages"`` list holds the conversation
            so far. Only the last message is used as the prompt; an empty
            list yields a prompt consisting of just the end-of-turn token.

    Returns:
        A partial state update ``{"messages": [AIMessage(content=reply)]}``
        containing only the newly generated reply.
    """
    messages = state["messages"]
    # NOTE(review): assumes ``.content`` is a plain string; LangChain messages
    # may also carry a list of content blocks, which is not handled here.
    last_message = messages[-1].content if messages else ""

    tokenizer = chat_pipe.tokenizer
    # DialoGPT separates conversational turns with the EOS token; without it
    # the model tends to continue the user's turn instead of answering it.
    prompt = last_message + tokenizer.eos_token

    outputs = chat_pipe(
        prompt,
        # Bound only the newly generated tokens. ``max_length`` also counts
        # the prompt, so a long prompt would leave little or no room to reply.
        max_new_tokens=100,
        do_sample=True,
        top_p=0.9,
        temperature=0.8,
        # Ask the pipeline for the continuation only, instead of slicing the
        # prompt off by character count (which breaks whenever decoding does
        # not reproduce the prompt text exactly).
        return_full_text=False,
        # DialoGPT's tokenizer has no dedicated pad token; reuse EOS
        # explicitly rather than relying on the implicit fallback.
        pad_token_id=tokenizer.eos_token_id,
    )
    response = outputs[0]["generated_text"].strip()
    # The reply is produced by the model, so it is an AI message, not a
    # human one.
    return {"messages": [AIMessage(content=response)]}