import json
import multiprocessing
import os
import re
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
from tokenizers import Tokenizer
from torch.utils.data import DataLoader, Dataset, random_split
from tqdm import tqdm

# NOTE(review): the stored copy of this file had been collapsed onto a handful
# of lines and had every `<...>` tag stripped out, so the special tokens and the
# BPE end-of-word marker showed up as empty strings. They are restored here as
# "<pad>", "<unk>", "<eos>" and "</w>"; confirm these match the vocab/tokenizer
# files you actually train and load with.

# Runs at import time; force=True replaces any start method already chosen.
multiprocessing.set_start_method("spawn", force=True)


class ChatDataset(Dataset):
    """Fixed-length next-token windows cut from a JSONL instruction/response file.

    Every non-blank line of the file at ``data`` must be a JSON object with
    ``"instruction"`` and ``"output"`` keys. Each record is rendered as
    ``"^User: <instruction> MiniGPT: <output> <eos>"``, tokenized, and sliced
    into windows of ``block_size`` token ids whose starts are ``stride`` apart.

    Args:
        data: Path to the JSONL file.
        tokenizer: Object whose ``encode(text)`` returns something with an
            ``.ids`` list (presumably a ``tokenizers.Tokenizer``).
        block_size: Window length in tokens; ``__getitem__`` yields inputs and
            targets of length ``block_size - 1``.
        stride: Step between window starts. 1 gives maximal overlap (most
            samples from little data); ``block_size`` gives disjoint windows.
    """

    def __init__(self, data, tokenizer, block_size=64, stride=1):
        if stride < 1:
            raise ValueError(f"stride must be >= 1, got {stride}")
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.stride = stride
        self.data = self.tokenize_data(data)

    def tokenize_data(self, data):
        """Read the JSONL file at ``data`` and return its ``block_size``-long id windows."""
        chunks = []
        with open(data, "r", encoding="utf-8") as f:
            for raw_line in f:
                raw_line = raw_line.strip()
                if not raw_line:
                    # Blank lines (e.g. a trailing newline) are not JSON records.
                    continue
                line = json.loads(raw_line)
                text = "^User: " + line["instruction"].strip() + " MiniGPT: " + line["output"].strip() + " <eos>"
                tokens = self.tokenizer.encode(text).ids

                # Records shorter than one window are dropped rather than
                # padded, so short data needs a smaller block_size.
                if len(tokens) < self.block_size:
                    print(f"Skipping short example (length {len(tokens)} < block_size {self.block_size}): {text[:50]}...")
                    continue

                # The range bound guarantees every slice is a full window.
                last_start = len(tokens) - self.block_size
                chunks.extend(
                    tokens[start:start + self.block_size]
                    for start in range(0, last_start + 1, self.stride)
                )
        print(f"Dataset created with {len(chunks)} total training chunks.")
        return chunks

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(x, y)``: the window without its last id, and the window shifted by one."""
        chunk = self.data[idx]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        return x, y


class MiniBPETokenizr:
    """Small character-level byte-pair-encoding tokenizer.

    Words are lowercased, split into characters and terminated with
    ``END_OF_WORD``; ``train`` repeatedly merges the most frequent adjacent
    symbol pair inside words and records each merged symbol in the vocab.
    """

    END_OF_WORD = "</w>"
    SPECIAL_TOKENS = ("<pad>", "<unk>", "<eos>")

    def __init__(self):
        self.stoi = {}
        self.itos = {}
        self.vocab_size = 0

    def tokenize(self, text):
        """Split ``text`` into per-word symbol lists.

        ASCII alphanumeric runs become ``[c1, c2, ..., END_OF_WORD]`` and each
        punctuation mark becomes its own one-symbol word. Characters matched by
        neither branch (non-ASCII letters, ``_``) are silently dropped.
        """
        text = text.lower().strip()
        words = re.findall(r"[a-zA-Z0-9]+|[^\w\s]", text)
        return [list(w) + [self.END_OF_WORD] if w.isalnum() else [w] for w in words]

    def get_stats(self, corpus):
        """Count adjacent symbol pairs within each word of ``corpus``."""
        pairs = Counter()
        for tokens in corpus:
            for i in range(len(tokens) - 1):
                pairs[(tokens[i], tokens[i + 1])] += 1
        return pairs

    def merge_vocab(self, corpus, pair_to_merge):
        """Return ``corpus`` with every occurrence of ``pair_to_merge`` fused into one symbol."""
        bigram = re.escape(' '.join(pair_to_merge))
        # Only match the pair as two whole symbols, never inside a longer one.
        pattern = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
        merged = ''.join(pair_to_merge)
        # A callable replacement avoids re interpreting backslashes in `merged`.
        return [pattern.sub(lambda _m: merged, ' '.join(tokens)).split(' ') for tokens in corpus]

    def train(self, texts, num_merges=100):
        """Learn up to ``num_merges`` merges from ``texts`` and build ``stoi``/``itos``.

        NOTE(review): the body of this method (and the tail of merge_vocab)
        was lost in the stored copy of the file; this is a reconstruction of
        standard BPE training around the surviving vocab-building lines.
        ``num_merges`` is a defaulted knob — confirm against callers.
        """
        corpus = [word for text in texts for word in self.tokenize(text)]
        # Keep every base symbol so unseen words can still be spelled out.
        vocab = {symbol for word in corpus for symbol in word}
        for _ in range(num_merges):
            pairs = self.get_stats(corpus)
            if not pairs:
                break
            best_pair = max(pairs, key=pairs.get)
            corpus = self.merge_vocab(corpus, best_pair)
            vocab.add(''.join(best_pair))
        # NOTE(review): "^user:" / "minigpt:" can never be produced by encode()
        # (tokenize splits them and adds END_OF_WORD after "user"), so these
        # ids are currently unreachable.
        vocab.update(["<pad>", "<unk>", "<eos>", "^user:", "minigpt:"])
        self.stoi = {tok: i for i, tok in enumerate(sorted(vocab))}
        self.itos = {i: tok for tok, i in self.stoi.items()}
        self.vocab_size = len(self.stoi)

    def encode(self, text):
        """Greedy longest-match encoding of ``text`` over the learned vocab.

        Symbols with no vocab match map to the ``<unk>`` id (or 1 if absent).
        """
        symbols = [symbol for word in self.tokenize(text) for symbol in word]
        unk_id = self.stoi.get("<unk>", 1)
        # Every symbol is at least one character long, so no candidate spanning
        # more symbols than the longest vocab entry has characters can match.
        max_len = max((len(tok) for tok in self.stoi), default=0)
        output = []
        i = 0
        while i < len(symbols):
            for j in range(min(len(symbols), i + max_len), i, -1):
                candidate = ''.join(symbols[i:j])
                if candidate in self.stoi:
                    output.append(self.stoi[candidate])
                    i = j
                    break
            else:
                output.append(unk_id)
                i += 1
        return output

    def decode(self, token_ids):
        """Turn ids back into text, re-joining sub-word pieces of the same word.

        Special tokens are dropped and spaces before ``?.!,:;`` are removed.
        """
        skip = set(self.SPECIAL_TOKENS)
        words = []
        current = ""
        for idx in token_ids:
            tok = self.itos.get(idx, "<unk>")
            if tok in skip:
                continue
            if tok.endswith(self.END_OF_WORD):
                words.append(current + tok[:-len(self.END_OF_WORD)])
                current = ""
            elif not tok.isalnum():
                # Punctuation is its own word in tokenize(); it ends any open word.
                if current:
                    words.append(current)
                    current = ""
                words.append(tok)
            else:
                current += tok
        if current:
            words.append(current)
        text = ' '.join(words)
        text = re.sub(r'\s([?.!,:;])', r'\1', text)
        return text.strip()

    def save(self, path):
        """Write ``stoi``/``itos`` to ``path`` as JSON."""
        with open(path, "w", encoding="utf-8") as f:
            json.dump({"stoi": self.stoi, "itos": self.itos}, f)

    def load(self, path):
        """Load a vocab written by ``save``; ``itos`` is rebuilt from ``stoi``."""
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.stoi = {k: int(v) for k, v in data["stoi"].items()}
        self.itos = {int(v): k for k, v in self.stoi.items()}
        self.vocab_size = len(self.stoi)


class SimpleTokenizr:
    """Word-level tokenizer over lowercase words, digit runs and punctuation marks."""

    def __init__(self):
        self.stoi = {}
        self.itos = {}

    def tokenize(self, text):
        return re.findall(r"[a-zA-Z']+|\d+|[^\w\s]", text.lower())

    def train(self, texts):
        """Build a sorted vocab from every token in ``texts`` plus the special tokens."""
        vocab = set()
        for text in texts:
            vocab.update(self.tokenize(text))
        # NOTE(review): the multi-word entries below contain spaces/uppercase,
        # which tokenize() never produces, so their ids are unreachable.
        vocab.update(["<pad>", "<unk>", "<eos>", "^user :", "minigpt :", "MiniGPT :", ":"])
        self.stoi = {token: idx for idx, token in enumerate(sorted(vocab))}
        self.itos = {idx: token for token, idx in self.stoi.items()}

    def encode(self, text):
        """Map ``text`` to ids (unknown words -> ``<unk>``), terminated by ``<eos>``."""
        unk_id = self.stoi["<unk>"]
        return [self.stoi.get(tok, unk_id) for tok in self.tokenize(text)] + [self.stoi["<eos>"]]

    def decode(self, token_ids):
        """Join tokens with spaces (none before punctuation) and capitalize the result."""
        tokens = [self.itos.get(i, "<unk>") for i in token_ids]
        clean_tokens = [tok for tok in tokens if tok not in {"<pad>", "<unk>", "<eos>"}]
        text = ''
        for i, tok in enumerate(clean_tokens):
            if re.match(r"[.,!?;:]", tok):
                text += tok
            elif i > 0:
                text += ' ' + tok
            else:
                text += tok
        return text.strip().capitalize()

    def save(self, path):
        """Write ``stoi``/``itos`` to ``path`` as JSON."""
        with open(path, "w", encoding="utf-8") as f:
            json.dump({"stoi": self.stoi, "itos": self.itos}, f)

    def load(self, path):
        """Load a vocab written by ``save``; ``itos`` is rebuilt from ``stoi``."""
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.stoi = {k: int(v) for k, v in data["stoi"].items()}
        self.itos = {idx: tok for tok, idx in self.stoi.items()}

    def __len__(self):
        return len(self.stoi)

    @property
    def vocab_size(self):
        return len(self.stoi)


def validate(model, dataloader, device):
    """Return ``(mean batch loss, token accuracy in %)`` of ``model`` over ``dataloader``.

    Puts the model in eval mode (callers must switch back to train mode).
    Returns ``(nan, nan)`` when the loader yields no targets, e.g. an empty
    validation split.
    """
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            # reshape (not view) tolerates non-contiguous logits.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
            total_loss += loss.item()
            preds = torch.argmax(logits, dim=-1)
            correct += (preds == y).sum().item()
            total += y.numel()
    if total == 0:
        return float("nan"), float("nan")
    avg_loss = total_loss / len(dataloader)
    accuracy = 100 * correct / total
    return avg_loss, accuracy


def train(model, dataset, tokenizer, epochs, filepathh, start_epoch=0, start_step=0,
          learning_rate=5e-5, log_every=1, split_seed=42):
    """Train ``model`` on ``dataset`` with AdamW, checkpointing after every epoch.

    A 90/10 train/validation split is drawn from ``dataset``. If
    ``./trained-mini-gpt/checkpoint-mini-gpt.pth`` exists it is loaded first.
    A full checkpoint restores the model, the optimizer and the counters, which
    overrides ``start_epoch`` and ``start_step``. A bare state dict restores
    only the model weights. After training, the final weights are written to
    ``./trained-mini-gpt/mini-gpt.pth``.

    Args:
        model: Module mapping ``x`` to logits whose last dimension is the
            vocabulary (as consumed by ``F.cross_entropy`` below).
        dataset: Dataset of ``(x, y)`` id tensors, e.g. ``ChatDataset``.
        tokenizer: Object with ``decode(list_of_ids)``; used only for logging.
        epochs: Epoch index to stop at (exclusive), not a count of new epochs.
        filepathh: Unused; kept for backward compatibility with callers.
        start_epoch: First epoch index when no full checkpoint is found.
        start_step: Initial global step counter when no full checkpoint is found.
        learning_rate: AdamW learning rate; re-applied after resuming, since
            loading optimizer state would otherwise restore the saved rate.
        log_every: Print decoded input/target/prediction every N steps;
            0 or None disables the per-step dump.
        split_seed: Seed for the train/validation split, so the split is the
            same across runs (matters when resuming); None for an unseeded split.

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    if len(dataset) == 0:
        raise ValueError("Cannot train on an empty dataset.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    val_size = int(0.1 * len(dataset))
    train_size = len(dataset) - val_size
    split_kwargs = {}
    if split_seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(split_seed)
    # NOTE(review): with overlapping windows (ChatDataset stride < block_size)
    # neighbouring chunks share tokens, so validation is not fully held out.
    train_set, val_set = random_split(dataset, [train_size, val_size], **split_kwargs)

    # batch_size=1 / num_workers=0 keep tiny-dataset debugging simple.
    train_loader = DataLoader(train_set, batch_size=1, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=0)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    checkpoint_path = "./trained-mini-gpt/checkpoint-mini-gpt.pth"
    final_model_path = "./trained-mini-gpt/mini-gpt.pth"
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    if os.path.exists(checkpoint_path):
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        if "model_state_dict" in checkpoint:
            model.load_state_dict(checkpoint["model_state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            for group in optimizer.param_groups:
                group["lr"] = learning_rate
            # The checkpoint records the last *completed* epoch; resume after it.
            start_epoch = checkpoint["epoch"] + 1
            start_step = checkpoint["step"]
        else:
            model.load_state_dict(checkpoint)
    else:
        print("🚀 Starting from scratch.")

    total_steps = start_step
    for epoch in range(start_epoch, epochs):
        model.train()
        total_loss, correct, total = 0.0, 0, 0
        loop = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs}")
        for step, (x, y) in loop:
            x, y = x.to(device), y.to(device)
            verbose = bool(log_every) and step % log_every == 0

            if verbose:
                decoded_input = tokenizer.decode(x[0].cpu().tolist())
                decoded_target = tokenizer.decode(y[0].cpu().tolist())
                print(f"\n--- Epoch {epoch+1}, Step {step} ---")
                print(f"Input (decoded): '{decoded_input}'")
                print(f"Target (decoded): '{decoded_target}'")

            logits = model(x)
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_steps += 1

            batch_loss = loss.item()
            total_loss += batch_loss
            preds = torch.argmax(logits, dim=-1)
            batch_correct = (preds == y).sum().item()
            correct += batch_correct
            total += y.numel()
            acc = 100 * correct / total
            loop.set_postfix(loss=batch_loss, acc=acc)

            if verbose:
                # `logits` come from the forward pass made *before*
                # optimizer.step(), i.e. the model's pre-update prediction.
                predicted_ids = torch.argmax(logits[0].detach().cpu(), dim=-1).tolist()
                decoded_predicted = tokenizer.decode(predicted_ids)
                print(f"Predicted (decoded): '{decoded_predicted}'")
                print(f"Current Batch Loss: {batch_loss:.4f}")
                print(f"Current Batch Accuracy: {100 * batch_correct / y.numel():.2f}%")

        avg_train_loss = total_loss / max(len(train_loader), 1)
        train_acc = 100 * correct / total if total else float("nan")
        print(f"📉 Train Loss: {avg_train_loss:.4f} | Train Accuracy: {train_acc:.2f}%")

        val_loss, val_acc = validate(model, val_loader, device)
        print(f"✅ Val Loss: {val_loss:.4f} | Val Accuracy: {val_acc:.2f}%")

        torch.save({
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "epoch": epoch,
            "step": total_steps,
        }, checkpoint_path)

    torch.save(model.state_dict(), final_model_path)
    print("🎉 Training complete.")