# Interactive chat REPL for a trained MiniGPT model.
import os

import torch
import torch.nn.functional as F
from tokenizers import Tokenizer

from model import MiniGPT
from dataset import MiniBPETokenizr, SimpleTokenizr  # only needed for the legacy tokenizer path below

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load tokenizer.
# Legacy path using the custom tokenizer from dataset.py:
#   tokenizer = SimpleTokenizr()
#   tokenizer.load("./customchatbot-v1/trained-mini-gpt/tokenizer.json")
tokenizer = Tokenizer.from_file("./trained-mini-gpt/tokenizer.json")

# Load model.
model = MiniGPT(vocab_size=tokenizer.get_vocab_size())
# Legacy loading logic that fell back to a training checkpoint when the
# final weights were missing:
#   model.load_state_dict(
#       torch.load("./customchatbot-v1/trained-mini-gpt/mini-gpt.pth", map_location=device)
#       if os.path.exists("./customchatbot-v1/trained-mini-gpt/mini-gpt.pth")
#       else torch.load("./customchatbot-v1/trained-mini-gpt/checkpoint-mini-gpt.pth", map_location=device)["model_state_dict"]
#   )
checkpoint = torch.load("./trained-mini-gpt/mini-gpt.pth", map_location=device)
model.load_state_dict(checkpoint)
model.eval().to(device)

total_params = sum(p.numel() for p in model.parameters())
print(f"Model total params: {total_params:,}")


def sample_token(logits, temperature=1.0):
    """Sample a token id from the last-position logits with temperature scaling."""
    logits = logits / temperature
    logits = torch.nan_to_num(logits, nan=-1e9)  # mask NaNs so softmax stays valid
    probs = F.softmax(logits, dim=-1)
    if torch.any(torch.isnan(probs)) or torch.any(probs < 0):
        print("⚠️ Invalid probs detected. Using uniform fallback.")
        probs = torch.ones_like(probs) / probs.size(-1)
    return torch.multinomial(probs, num_samples=1).item()


def generate_reply(prompt, max_tokens=100):
    """Autoregressively generate up to max_tokens tokens, streaming them to stdout."""
    tokens = tokenizer.encode(prompt).ids
    if not tokens:
        print("⚠️ Empty prompt after encoding.")
        return
    input_ids = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(device)
    generated = []
    with torch.no_grad():
        for _ in range(max_tokens):
            logits = model(input_ids)   # (1, seq_len, vocab_size)
            logits = logits[:, -1, :]   # keep only the final position
            next_token = sample_token(logits)
            generated.append(next_token)
            # Round-trip the raw token string through encode/decode so BPE
            # pieces print as readable text instead of raw subword markers.
            next_str = tokenizer.id_to_token(next_token)
            encoded_text = tokenizer.encode(next_str).ids
            decoded_text = tokenizer.decode(encoded_text)
            print(decoded_text, end=" ", flush=True)
            if next_str == "":  # stop condition; presumably the tokenizer's end-of-sequence marker
                break
            # Append the sampled token and feed the extended context back in.
            input_ids = torch.cat([input_ids, torch.tensor([[next_token]]).to(device)], dim=1)
    print()


# Chat loop
print("🧠 MiniGPT Chat (type 'exit' to quit)")
while True:
    user_input = input("User: ")
    if user_input.lower() == "exit":
        break
    prompt = f"^User: {user_input}\nMiniGPT:"
    print("MiniGPT: ", end="", flush=True)
    generate_reply(prompt)
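
# --- Optional: top-k sampling (sketch; not part of the original script) -----
# sample_token above draws from the full softmax, so low-probability tokens
# occasionally slip through. A common variant restricts sampling to the k most
# likely tokens, which often tames rambling output from small models. This is
# a minimal, hypothetical drop-in replacement; the function name and the k=40
# default are assumptions, not project conventions.
#
# def sample_token_topk(logits, temperature=1.0, k=40):
#     logits = torch.nan_to_num(logits / temperature, nan=-1e9)
#     top_values, top_indices = torch.topk(logits, k=min(k, logits.size(-1)), dim=-1)
#     probs = F.softmax(top_values, dim=-1)             # softmax over the top-k only
#     choice = torch.multinomial(probs, num_samples=1)  # index into the top-k slice
#     return top_indices.gather(-1, choice).item()      # map back to a vocab id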