|
from typing import List, TypedDict

from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
from transformers import pipeline
|
|
|
|
|
# Hugging Face hub id of the conversational model backing the assistant node.
_CHAT_MODEL_ID = "microsoft/DialoGPT-medium"

# Text-generation pipeline built once at import time (this loads the model,
# fetching it from the hub if it is not cached locally) and reused by every
# call to `assistant`.
chat_pipe = pipeline(task="text-generation", model=_CHAT_MODEL_ID)
|
|
|
# State dictionary passed between graph nodes. Its single key, "messages",
# holds the conversation so far as LangChain message objects. This uses the
# functional TypedDict form and produces the same type as the class-based form.
# NOTE(review): no reducer (e.g. langgraph's `add_messages` via `Annotated`) is
# attached to "messages", so a node returning {"messages": [...]} presumably
# replaces the list rather than appending to it — confirm against the graph.
AgentState = TypedDict("AgentState", {"messages": List[AnyMessage]})
|
|
|
def assistant(state: AgentState) -> dict:
    """Generate a DialoGPT reply to the most recent message in ``state``.

    Args:
        state: Graph state whose ``"messages"`` list holds the conversation.
            Only the last message's ``content`` is used as the prompt. An
            empty list yields a prompt consisting of just the EOS separator.

    Returns:
        A partial state update ``{"messages": [AIMessage(...)]}`` containing
        the model's reply with surrounding whitespace stripped.
    """
    messages = state["messages"]
    # NOTE(review): assumes `content` is a plain string; LangChain also allows
    # a list of content blocks, which would be stringified here.
    user_text = messages[-1].content if messages else ""

    tokenizer = chat_pipe.tokenizer
    # DialoGPT was trained on turns separated by the EOS token. Without it the
    # model tends to continue the user's sentence instead of replying to it.
    prompt = f"{user_text}{tokenizer.eos_token}"

    outputs = chat_pipe(
        prompt,
        # Budget for the reply alone. `max_length` would also count the
        # prompt, leaving little or no room for long user messages.
        max_new_tokens=100,
        do_sample=True,
        top_p=0.9,
        temperature=0.8,
        # GPT-2-family tokenizers define no pad token; reuse EOS explicitly.
        pad_token_id=tokenizer.eos_token_id,
        # Let the pipeline drop the prompt. Slicing by len(prompt) breaks when
        # decoding does not reproduce the prompt text exactly, e.g. when the
        # EOS marker is removed as a special token.
        return_full_text=False,
    )
    reply = outputs[0]["generated_text"].strip()

    # The reply comes from the model, so it is an AI message, not a human one.
    return {"messages": [AIMessage(content=reply)]}
|
|