DNAI / dnafiber /trainee.py
ClementP's picture
Upload 55 files
69591a9 verified
from lightning import LightningModule
import segmentation_models_pytorch as smp
from monai.losses.dice import GeneralizedDiceLoss
from monai.losses.cldice import SoftDiceclDiceLoss
from torchmetrics.classification import Dice, JaccardIndex
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchmetrics import MetricCollection
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin
import torch
import torchvision
from dnafiber.metric import DNAFIBERMetric
class Trainee(LightningModule, PyTorchModelHubMixin):
def __init__(
self, learning_rate=0.001, weight_decay=0.0002, num_classes=3, **model_config
):
super().__init__()
self.model_config = model_config
if (
self.model_config.get("arch", None) is None
or self.model_config["arch"] == "maskrcnn"
):
self.model = None
else:
self.model = smp.create_model(classes=3, **self.model_config, dropout=0.2)
self.loss = GeneralizedDiceLoss(to_onehot_y=False, softmax=False)
self.metric = MetricCollection(
{
"dice": Dice(num_classes=num_classes, ignore_index=0),
"jaccard": JaccardIndex(
num_classes=num_classes,
task="multiclass" if num_classes > 2 else "binary",
ignore_index=0,
),
"detection": DNAFIBERMetric(),
}
)
self.weight_decay = weight_decay
self.learning_rate = learning_rate
self.save_hyperparameters()
def forward(self, x):
yhat = self.model(x)
return yhat
def training_step(self, batch, batch_idx):
x, y = batch["image"], batch["mask"]
y = y.clamp(0, 2)
y_hat = self(x)
loss = self.get_loss(y_hat, y)
self.log("train_loss", loss)
return loss
def get_loss(self, y_hat, y):
y_hat = F.softmax(y_hat, dim=1)
y = F.one_hot(y.long(), num_classes=3)
y = y.permute(0, 3, 1, 2).float()
loss = self.loss(y_hat, y)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch["image"], batch["mask"]
y = y.clamp(0, 2)
y_hat = self(x)
loss = self.get_loss(y_hat, y)
self.log("val_loss", loss, on_step=False, on_epoch=True, sync_dist=True)
self.metric.update(y_hat, y)
return y_hat
def on_validation_epoch_end(self):
scores = self.metric.compute()
self.log_dict(scores, sync_dist=True)
self.metric.reset()
def configure_optimizers(self):
optimizer = AdamW(
self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
)
scheduler = CosineAnnealingLR(
optimizer,
T_max=self.trainer.max_epochs, # type: ignore
eta_min=self.learning_rate / 25,
)
scheduler = {
"scheduler": scheduler,
"interval": "epoch",
}
return [optimizer], [scheduler]
class TraineeMaskRCNN(Trainee):
def __init__(self, learning_rate=0.001, weight_decay=0.0002, **model_config):
super().__init__(learning_rate, weight_decay, **model_config)
self.model = torchvision.models.get_model("maskrcnn_resnet50_fpn_v2")
def forward(self, x):
yhat = self.model(x)
return yhat
def training_step(self, batch, batch_idx):
image = batch["image"]
targets = batch["targets"]
loss_dict = self.model(image, targets)
losses = sum(loss for loss in loss_dict.values())
self.log("train_loss", losses, on_step=True, on_epoch=False, sync_dist=True)
return losses
def validation_step(self, batch, batch_idx):
image = batch["image"]
targets = batch["targets"]
predictions = self.model(image)
b = len(predictions)
predicted_masks = []
gt_masks = []
for i in range(b):
scores = predictions[i]["scores"]
masks = predictions[i]["masks"]
good_masks = masks[scores > 0.5]
# Combined into a single mask
good_masks = torch.sum(good_masks, dim=0)
predicted_masks.append(good_masks)
gt_masks.append(targets[i]["masks"].sum(dim=0))
gt_masks = torch.stack(gt_masks).squeeze(1) > 0
predicted_masks = torch.stack(predicted_masks).squeeze(1) > 0
self.metric.update(predicted_masks, gt_masks)
return predictions
def configure_optimizers(self):
optimizer = AdamW(
self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
)
scheduler = CosineAnnealingLR(
optimizer,
T_max=self.trainer.max_epochs, # type: ignore
eta_min=self.learning_rate / 25,
)
scheduler = {
"scheduler": scheduler,
"interval": "epoch",
}
return [optimizer], [scheduler]