#!/usr/bin/env python3
# train_wersa_single_gpu.py - FIXED VERSION
# ---------------------------------------------------------------------
import os, sys, json, math, shutil, logging, argparse
from datetime import datetime

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from accelerate import Accelerator
from accelerate.utils import set_seed
import bitsandbytes.optim as bnb_optim
from transformers import (
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)
from datasets import load_from_disk, Dataset, DatasetDict
from wersa import WersaConfig, WersaForCausalLM
import numpy as np

# ---------------------------------------------------------------------
# logging
# ---------------------------------------------------------------------
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)


# =====================================================================
#  D A T A   C O L L A T O R
# =====================================================================
class TruncatingDataCollator:
    """
    • truncate to max_length
    • pad to max_length
    • labels: pad tokens → -100
    • force every id to be within the vocabulary
    """

    def __init__(self, tokenizer, max_length: int = 1024):
        self.max_length = max_length
        self.pad_id = tokenizer.pad_token_id
        self.unk_id = tokenizer.unk_token_id if tokenizer.unk_token_id is not None else 0
        self.vocab_size = len(tokenizer)

    def _pad(self, seq, value=None):
        # Pad with pad_id by default; callers may override (e.g. 0 for the attention mask).
        value = self.pad_id if value is None else value
        return seq + [value] * (self.max_length - len(seq))

    def _sanitize(self, seq):
        return [
            tok if (0 <= tok < self.vocab_size) or tok == -100 else self.unk_id
            for tok in seq
        ]

    def __call__(self, examples):
        input_ids, attn_mask, labels = [], [], []
        for ex in examples:
            ids = ex["input_ids"][: self.max_length]
            mask = ex.get("attention_mask", [1] * len(ids))[: self.max_length]
            lbls = ex.get("labels", ids.copy())[: self.max_length]  # copy to avoid modifying original

            # Pad sequences (the attention mask is padded with 0, not pad_id)
            ids = self._pad(ids)
            mask = self._pad(mask, value=0)
            lbls = self._pad(lbls)

            # Set pad tokens to -100 in labels
            lbls = [-100 if tok == self.pad_id else tok for tok in lbls]

            # Sanitize to ensure all tokens are valid
            ids = self._sanitize(ids)
            lbls = self._sanitize(lbls)

            input_ids.append(ids)
            attn_mask.append(mask)
            labels.append(lbls)

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attn_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }


# =====================================================================
#  A C C U R A C Y   M E T R I C
# =====================================================================
def compute_accuracy(logits, labels, vocab_size=None):
    try:
        # Shift for next token prediction: position t predicts token t+1
        lg, lb = logits[..., :-1, :], labels[..., 1:]
        mask = lb.ne(-100)

        if torch.isnan(lg).any() or torch.isinf(lg).any():
            logger.warning("NaN / Inf in logits – accuracy set to 0")
            return 0.0

        preds = lg.argmax(-1)
        if vocab_size:
            preds = torch.clamp(preds, 0, vocab_size - 1)

        correct = (preds.eq(lb) & mask).sum().item()
        total = mask.sum().item()
        return correct / total if total else 0.0

    except RuntimeError as e:
        if "device-side assert" in str(e):
            logger.error(f"CUDA device assert caught in accuracy: {e}")
            return 0.0
        raise
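
# ---------------------------------------------------------------------
# Collator behaviour at a glance (illustrative sketch, assuming a GPT-2
# style tokenizer where pad_token == eos_token; token ids are arbitrary):
#
#     tok = AutoTokenizer.from_pretrained("gpt2")
#     tok.pad_token = tok.eos_token
#     collator = TruncatingDataCollator(tok, max_length=8)
#     batch = collator([{"input_ids": [464, 2068, 7586]}])
#     batch["input_ids"].shape        # torch.Size([1, 8])
#     batch["attention_mask"][0]      # tensor([1, 1, 1, 0, 0, 0, 0, 0])
#     batch["labels"][0, 3:]          # all -100 (padding ignored by the loss)
# ---------------------------------------------------------------------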

# =====================================================================
#  C H E C K P O I N T   ( 2 ×  R O T A T I N G   S L O T S )
# =====================================================================
def save_to_alternating_slots(accelerator, model, tokenizer, slot1_dir,
                              slot2_dir, completed_steps, is_main):
    """Save checkpoint to alternating slots."""
    if not is_main:
        return

    logger.info(f"{'='*60}\nSaving checkpoint @ step {completed_steps}\n{'='*60}")

    # Rotate: slot2 → slot1
    if os.path.exists(slot2_dir):
        if os.path.exists(slot1_dir):
            shutil.rmtree(slot1_dir)
        shutil.move(slot2_dir, slot1_dir)

    os.makedirs(slot2_dir, exist_ok=True)

    # 1) Accelerator state
    accelerator.save_state(slot2_dir)

    # 2) HF model + tokenizer
    unwrapped = accelerator.unwrap_model(model)
    unwrapped.save_pretrained(
        slot2_dir,
        save_function=accelerator.save,
        state_dict=accelerator.get_state_dict(model),
    )
    tokenizer.save_pretrained(slot2_dir)

    # Also save config explicitly
    unwrapped.config.save_pretrained(slot2_dir)

    # 3) Metadata
    info_path = os.path.join(slot2_dir, "training_info.json")
    json.dump(
        {
            "completed_steps": completed_steps,
            "timestamp": datetime.now().isoformat(),
            "is_most_recent": True,
        },
        open(info_path, "w"),
        indent=2,
    )

    # Mark older slot as not recent
    if os.path.exists(slot1_dir):
        old_info_path = os.path.join(slot1_dir, "training_info.json")
        if os.path.exists(old_info_path):
            old_info = json.load(open(old_info_path))
            old_info["is_most_recent"] = False
            json.dump(old_info, open(old_info_path, "w"), indent=2)

    logger.info("Checkpoint saved ✓\n")


def find_latest_checkpoint(output_dir):
    """Find the latest checkpoint from the alternating slots (slot2 is always the newer one)."""
    slot1 = os.path.join(output_dir, "slot1")
    slot2 = os.path.join(output_dir, "slot2")

    def load_info(path):
        try:
            info = json.load(open(os.path.join(path, "training_info.json")))
            return int(info.get("completed_steps", 0))
        except Exception:
            return None

    step2 = load_info(slot2)
    step1 = load_info(slot1)

    if step2 is not None:
        return slot2, step2
    if step1 is not None:
        return slot1, step1
    return None, 0


# =====================================================================
#  T R A I N
# =====================================================================
def main(args):
    accelerator = Accelerator(
        mixed_precision="bf16" if torch.cuda.is_bf16_supported() else "fp16",
        gradient_accumulation_steps=args.gradient_accumulation_steps,
    )
    set_seed(42)
    is_main = accelerator.is_main_process

    # ------------------ TOKENIZER ----------------------------------
    tok = AutoTokenizer.from_pretrained(args.tokenizer_name)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    # Ensure unk_token exists (used by the collator to replace out-of-range ids)
    if tok.unk_token is None:
        tok.add_special_tokens({"unk_token": "<unk>"})

    vocab_size = len(tok)

    # Warn if vocab size is very large
    if vocab_size > 50000 and is_main:
        logger.warning(
            f"Tokenizer has {vocab_size} tokens. Consider using a smaller vocabulary "
            f"for better efficiency with the WERSA architecture."
        )

    # ------------------ DATASET SANITY CHECK -----------------------
    raw_ds = load_from_disk(args.dataset_path)
    if isinstance(raw_ds, DatasetDict):
        train_split = raw_ds["train"] if "train" in raw_ds else next(iter(raw_ds.values()))
    else:
        train_split = raw_ds

    if is_main:
        logger.info("Scanning dataset for token ids ≥ vocab size …")

    def scan_batch(batch):
        all_ids = []
        for seq in batch["input_ids"]:
            all_ids.extend(seq)
        max_in_batch = max(all_ids) if all_ids else -1
        batch_size = len(batch["input_ids"])
        return {"max": [max_in_batch] * batch_size}

    scan_ds = train_split.map(
        scan_batch,
        batched=True,
        batch_size=1024,
        num_proc=min(os.cpu_count() or 1, 8),
        desc="scan-ids",
        keep_in_memory=False,
    )
    max_seen = max(scan_ds["max"])

    if max_seen >= vocab_size:
        logger.warning(
            f"Dataset contains token id {max_seen} ≥ vocab size {vocab_size}. "
            "Those tokens will be replaced by the unk token during collation."
        )
    del scan_ds

    # ------------------ MODEL --------------------------------------
    # FIXED: Include WERSA-specific parameters
    cfg = WersaConfig(
        vocab_size=vocab_size,
        pad_token_id=tok.pad_token_id,
        bos_token_id=tok.bos_token_id if tok.bos_token_id is not None else 0,
        eos_token_id=tok.eos_token_id if tok.eos_token_id is not None else 1,
        hidden_size=1600,
        num_hidden_layers=20,
        num_attention_heads=32,
        intermediate_size=6400,
        # WERSA-specific parameters
        wersa_decomp_levels=2,       # From paper
        wersa_random_features=1024,  # From paper
        hidden_act="gelu",
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        dropout=0.1,
        use_cache=False,  # Disable KV cache for training
    )
    model = WersaForCausalLM(cfg)

    # Log model size
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if is_main:
        logger.info(f"Model size: {total_params/1e6:.1f}M parameters ({trainable_params/1e6:.1f}M trainable)")

    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()
        if is_main:
            logger.info("Gradient checkpointing enabled")

    # ------------------ DATA LOADER --------------------------------
    ds = load_from_disk(args.dataset_path)
    if isinstance(ds, DatasetDict):
        ds = ds["train"] if "train" in ds else next(iter(ds.values()))

    # Remove unnecessary columns
    cols_to_drop = [c for c in ds.column_names if c not in ("input_ids", "attention_mask", "labels")]
    if cols_to_drop:
        ds = ds.remove_columns(cols_to_drop)

    ds = ds.shuffle(seed=42)

    collator = TruncatingDataCollator(tok, args.max_length)
    dl = DataLoader(
        ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        collate_fn=collator,
        persistent_workers=True,  # Keep workers alive between epochs
    )

    steps_per_epoch = len(dl) // args.gradient_accumulation_steps
    total_steps = min(args.max_steps, steps_per_epoch * args.num_epochs)

    # ------------------ OPTIMIZER & SCHEDULER ----------------------
    # Separate weight-decay settings for decay / no-decay parameter groups
    no_decay = ["bias", "LayerNorm.weight", "layer_norm"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.01,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optim = bnb_optim.AdamW8bit(
        optimizer_grouped_parameters,
        lr=args.learning_rate,
        betas=(0.9, 0.999),
        eps=1e-8,
    )

    # Linear schedule with warmup
    sched = get_linear_schedule_with_warmup(
        optim,
        num_warmup_steps=min(1000, int(0.1 * total_steps)),
        num_training_steps=total_steps,
    )

    # Prepare with accelerator
    model, optim, dl, sched = accelerator.prepare(model, optim, dl, sched)

    # ------------------ RESUME LOGIC -------------------------------
    slot1_dir = os.path.join(args.output_dir, "slot1")
    slot2_dir = os.path.join(args.output_dir, "slot2")

    completed_steps, start_epoch, skip_batches = 0, 0, 0
    ckpt_dir = args.resume_from_checkpoint

    if not ckpt_dir and args.auto_resume:
        ckpt_dir, completed_steps = find_latest_checkpoint(args.output_dir)

    if ckpt_dir and os.path.exists(ckpt_dir):
        if is_main:
            logger.info(f"Resuming training from {ckpt_dir}")
        accelerator.load_state(ckpt_dir)

        info_path = os.path.join(ckpt_dir, "training_info.json")
        if os.path.exists(info_path):
            info = json.load(open(info_path))
            completed_steps = info.get("completed_steps", 0)

        if completed_steps > 0:
            start_epoch = completed_steps // steps_per_epoch
            skip_batches = (completed_steps % steps_per_epoch) * args.gradient_accumulation_steps
            if is_main:
                logger.info(f"Resuming from step {completed_steps}, epoch {start_epoch}")
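    # NOTE: with the default arguments (batch_size=4, gradient_accumulation_steps=8)
    # each optimizer step sees an effective batch of 32 sequences, and training
    # stops after min(max_steps, steps_per_epoch * num_epochs) optimizer steps.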

    # ------------------ TRAIN LOOP ---------------------------------
    running_loss = 0.0
    running_acc = 0.0
    batches_in_stat = 0
    model.train()

    if is_main:
        logger.info(f"Starting training for {total_steps} steps")
        logger.info(f"Steps per epoch: {steps_per_epoch}")

    for epoch in range(start_epoch, args.num_epochs):
        epoch_iter = tqdm(
            enumerate(dl),
            total=len(dl),
            disable=not is_main,
            desc=f"Epoch {epoch + 1}/{args.num_epochs}",
        )

        for step, batch in epoch_iter:
            # Skip already processed batches when resuming
            if skip_batches > 0:
                if step < skip_batches:
                    continue
                skip_batches = 0

            if completed_steps >= total_steps:
                break

            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                logits = outputs.logits

                # Track metrics
                running_loss += loss.detach().float()
                running_acc += compute_accuracy(logits.detach(), batch["labels"], vocab_size)
                batches_in_stat += 1

                # Backward pass
                accelerator.backward(loss)

                # Step if we've accumulated enough gradients
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                    optim.step()
                    sched.step()
                    optim.zero_grad(set_to_none=True)

                    completed_steps += 1

                    # Log stats
                    if batches_in_stat > 0:
                        avg_loss = running_loss / batches_in_stat
                        avg_acc = running_acc / batches_in_stat
                        ppl = math.exp(avg_loss) if avg_loss < 10 else float("inf")
                        lr = sched.get_last_lr()[0]

                        if is_main:
                            logger.info(
                                f"step {completed_steps}/{total_steps} | "
                                f"loss {avg_loss:.4f} | ppl {ppl:.2f} | "
                                f"acc {avg_acc*100:.2f}% | lr {lr:.2e}"
                            )
                            epoch_iter.set_postfix(
                                loss=f"{avg_loss:.3f}",
                                ppl=f"{ppl:.1f}",
                                acc=f"{avg_acc*100:.1f}%",
                            )

                        # Reset stats
                        running_loss = 0.0
                        running_acc = 0.0
                        batches_in_stat = 0

                    # Save checkpoint periodically
                    if completed_steps % args.save_steps == 0:
                        # FIXED: Correct parameter order
                        save_to_alternating_slots(
                            accelerator, model, tok,
                            slot1_dir, slot2_dir,
                            completed_steps, is_main,
                        )

        if completed_steps >= total_steps:
            break

    # ------------------ FINAL SAVE ----------------------------------
    if is_main and completed_steps % args.save_steps != 0:
        save_to_alternating_slots(
            accelerator, model, tok,
            slot1_dir, slot2_dir,
            completed_steps, True,
        )

    if is_main:
        logger.info(f"Training finished after {completed_steps} steps ✓")


# =====================================================================
#  A R G U M E N T S
# =====================================================================
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train WERSA model on a dataset")
    parser.add_argument("--dataset_path", type=str, required=True,
                        help="Path to the dataset directory")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Directory to save checkpoints")
    parser.add_argument("--tokenizer_name", type=str, default="gpt2",
                        help="Tokenizer name or path (default: gpt2)")
    parser.add_argument("--max_length", type=int, default=1024,
                        help="Maximum sequence length")
    parser.add_argument("--num_epochs", type=int, default=20,
                        help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=4,
                        help="Batch size per GPU")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8,
                        help="Gradient accumulation steps")
    parser.add_argument("--learning_rate", type=float, default=3e-5,
                        help="Learning rate")
    parser.add_argument("--max_steps", type=int, default=65000,
                        help="Maximum training steps")
    parser.add_argument("--save_steps", type=int, default=500,
                        help="Save checkpoint every N steps")
    parser.add_argument("--gradient_checkpointing", action="store_true",
                        help="Enable gradient checkpointing")
help="Enable gradient checkpointing") parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Path to resume training from") parser.add_argument("--auto_resume", action="store_true", help="Auto-resume from latest checkpoint") args = parser.parse_args() os.makedirs(args.output_dir, exist_ok=True) main(args)