import math

import torch


class EarlyStopping:
    """Signal that training should stop when validation loss stops improving.

    Call the instance once per validation round with the latest loss. A loss
    counts as an improvement only if it is lower than the best loss seen so
    far by more than ``delta``. After ``patience`` consecutive non-improving
    calls, ``early_stop`` is set to ``True`` (it is never reset to ``False``).

    Args:
        patience: Number of consecutive non-improving calls tolerated before
            ``early_stop`` is set.
        delta: Minimum decrease below ``best_loss`` required to count as an
            improvement.
    """

    def __init__(self, patience=5, delta=0):
        self.patience = patience
        self.delta = delta
        self.counter = 0  # consecutive calls without an improvement
        self.best_loss = None  # lowest loss recorded so far (a float once set)
        self.early_stop = False

    def __call__(self, val_loss):
        """Record ``val_loss`` and update the stopping state.

        Args:
            val_loss: Scalar validation loss -- a Python number, NumPy scalar
                or single-element tensor (anything ``float()`` accepts).
        """
        # Keep a plain float: storing a tensor would pin its storage (and any
        # autograd graph attached to it) for the lifetime of this object.
        loss = float(val_loss)
        # NaN never counts as an improvement. Otherwise a NaN on the first call
        # would become best_loss, every later comparison against it would be
        # False, and real improvements could never be recorded.
        improved = not math.isnan(loss) and (
            self.best_loss is None or loss < self.best_loss - self.delta
        )
        if improved:
            self.best_loss = loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True


def get_scheduler(optimizer, scheduler_type='plateau', **kwargs):
    """Build a learning-rate scheduler for ``optimizer``.

    Args:
        optimizer: The ``torch.optim.Optimizer`` whose learning rate is
            scheduled.
        scheduler_type: ``'plateau'`` for ``ReduceLROnPlateau`` or ``'step'``
            for ``StepLR``.
        **kwargs: Forwarded unchanged to the scheduler constructor (e.g.
            ``StepLR`` requires ``step_size``).

    Returns:
        The constructed scheduler. Note that ``ReduceLROnPlateau.step`` must
        be called with the monitored metric, unlike ``StepLR.step``.

    Raises:
        ValueError: If ``scheduler_type`` is not recognised.
    """
    if scheduler_type == 'plateau':
        return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **kwargs)
    if scheduler_type == 'step':
        return torch.optim.lr_scheduler.StepLR(optimizer, **kwargs)
    raise ValueError(f"Unknown scheduler type: {scheduler_type}")