Stockai / training_utils.py
rmanzo28's picture
Upload 4 files
cd7d7d2 verified
raw
history blame
1 kB
import math

import torch
class EarlyStopping:
    """Track a validation metric and flag when it stops improving.

    The object does not halt training by itself: call it once per epoch with
    the latest metric value, then check ``early_stop`` (or the value returned
    by the call) to decide whether to leave the training loop.

    Args:
        patience: Number of consecutive calls without improvement after which
            ``early_stop`` becomes ``True``.
        delta: Minimum margin that counts as an improvement. In ``"min"`` mode
            a new value must be strictly below ``best_loss - delta``; in
            ``"max"`` mode it must be strictly above ``best_loss + delta``.
        mode: ``"min"`` (default) when lower values are better, e.g. a loss;
            ``"max"`` when higher values are better, e.g. an accuracy.

    Attributes:
        counter: Consecutive calls since the last improvement.
        best_loss: Best value seen so far, or ``None`` until the first
            non-NaN value has been recorded.
        early_stop: Set to ``True`` once ``counter`` reaches ``patience``;
            this class never resets it back to ``False``.

    Raises:
        ValueError: If ``mode`` is neither ``"min"`` nor ``"max"``.
    """

    def __init__(self, patience=5, delta=0, mode="min"):
        if mode not in ("min", "max"):
            raise ValueError(f"Unknown mode: {mode!r} (expected 'min' or 'max')")
        self.patience = patience
        self.delta = delta
        self.mode = mode
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def _is_improvement(self, val_loss):
        """Return True if ``val_loss`` beats ``best_loss`` by more than ``delta``."""
        # NaN compares False against everything. If NaN were stored as
        # best_loss, no later value could ever register as an improvement,
        # so a NaN epoch is counted as "no improvement" instead.
        if math.isnan(val_loss):
            return False
        if self.best_loss is None:
            return True
        if self.mode == "min":
            return val_loss < self.best_loss - self.delta
        return val_loss > self.best_loss + self.delta

    def __call__(self, val_loss):
        """Record one metric value and update the stopping state.

        Args:
            val_loss: Scalar metric for the current epoch. This is a Python
                number or anything ``math.isnan`` accepts, such as a
                one-element tensor.

        Returns:
            The current value of ``early_stop``.
        """
        if self._is_improvement(val_loss):
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop
def get_scheduler(optimizer, scheduler_type='plateau', **kwargs):
    """Build a ``torch.optim.lr_scheduler`` instance from a short name.

    Args:
        optimizer: The optimizer whose learning rate the scheduler will drive.
        scheduler_type: One of the following names:
            ``'plateau'`` (ReduceLROnPlateau), ``'step'`` (StepLR),
            ``'multistep'`` (MultiStepLR), ``'exponential'`` (ExponentialLR)
            or ``'cosine'`` (CosineAnnealingLR).
        **kwargs: Forwarded unchanged to the scheduler constructor. Examples
            are ``step_size`` for ``'step'``, ``milestones`` for
            ``'multistep'``, ``gamma`` for ``'exponential'`` and ``T_max``
            for ``'cosine'``.

    Returns:
        The constructed scheduler bound to ``optimizer``.

    Raises:
        ValueError: If ``scheduler_type`` is not one of the names above.
    """
    schedulers = torch.optim.lr_scheduler
    factories = {
        'plateau': schedulers.ReduceLROnPlateau,
        'step': schedulers.StepLR,
        'multistep': schedulers.MultiStepLR,
        'exponential': schedulers.ExponentialLR,
        'cosine': schedulers.CosineAnnealingLR,
    }
    try:
        factory = factories[scheduler_type]
    except KeyError:
        raise ValueError(f"Unknown scheduler type: {scheduler_type}") from None
    # Constructed outside the try so that argument errors (e.g. a missing
    # step_size) surface as the scheduler's own TypeError, as before.
    return factory(optimizer, **kwargs)