# Interactive GPT-2 chatbot: loads a GPT-2 checkpoint from the local
# "chatbot" directory and chats with the user in a simple terminal loop.
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import pipeline
import torch
import os
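
# Resolve the model checkpoint directory ("chatbot") relative to this script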
base_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(base_dir, "chatbot")
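
# Load the tokenizer and model weights from the local checkpoint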
tokenizer = GPT2Tokenizer.from_pretrained(model_path)
model = GPT2LMHeadModel.from_pretrained(model_path)
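
# Run on GPU when available and put the model in inference (eval) mode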
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
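
# Wrap the model and tokenizer in a text-generation pipeline (GPU device 0 if available)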
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1
)
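

# chat(): terminal loop that builds an "A:" (user) / "B:" (bot) dialogue context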
def chat(temp=0.5):
    print(f"\n🤖 Chatbot is ready! (temperature={temp}) — type 'exit' to quit.")
    context = ""
    while True:
        user_input = input("You: ")
        if user_input.lower() == "exit":
            break
        # Add the user's turn to the running context and prompt for the bot's turn
        context += f"A: {user_input}\nB:"
        result = generator(
            context,
            max_length=len(tokenizer.encode(context)) + 50,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=temp
        )[0]["generated_text"]
        # Keep only the newly generated text, cut at the first line break
        reply = result[len(context):].split("\n")[0].strip()
        print(f"Bot: {reply}")
        context += f"{reply}\n"

if __name__ == "__main__":
    chat(temp=0.8)