import torch
import torch.nn as nn


class MiniGPT(nn.Module):
    """GPT-style decoder-only model built from a causally masked TransformerEncoder."""

    def __init__(self, vocab_size, d_model=1024, n_heads=16, n_layers=24, max_len=512):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        # 🎯 CHANGE 1: Set dropout to 0.0 for debugging underfitting on tiny data.
        # This allows the model to memorize the small dataset.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dropout=0.0, batch_first=False
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.ln = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def generate_causal_mask(self, T, device):
        # Correct for a TransformerEncoder used causally: True entries mask future tokens.
        return torch.triu(torch.ones(T, T, device=device), diagonal=1).bool()

    def forward(self, input_ids):
        B, T = input_ids.shape
        pos = torch.arange(0, T, device=input_ids.device).unsqueeze(0)  # [1, T]
        x = self.token_embed(input_ids) + self.pos_embed(pos)           # [B, T, D]
        x = x.transpose(0, 1)  # [T, B, D] to match batch_first=False
        # Causal mask: each position may attend only to itself and earlier tokens.
        mask = self.generate_causal_mask(T, input_ids.device)
        x = self.transformer(x, mask)
        x = x.transpose(0, 1)  # [B, T, D]
        x = self.ln(x)
        return self.fc_out(x)  # [B, T, vocab_size] logits

    def reset_params(self):
        # Recursively re-initialize every submodule that exposes reset_parameters.
        # (Iterating self.children() would skip the layers nested inside the
        # TransformerEncoder, leaving their weights untouched.)
        for module in self.modules():
            if hasattr(module, 'reset_parameters'):
                module.reset_parameters()
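

# --- Usage sketch: a minimal smoke test of the forward pass. The vocabulary,
# model sizes, and batch shape below are illustrative assumptions, not values
# taken from the original code.
if __name__ == "__main__":
    vocab_size = 100  # hypothetical vocabulary size for the test
    model = MiniGPT(vocab_size, d_model=128, n_heads=4, n_layers=2, max_len=64)
    input_ids = torch.randint(0, vocab_size, (2, 16))  # random token ids, [B=2, T=16]
    logits = model(input_ids)
    print(logits.shape)  # expected: torch.Size([2, 16, 100])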