|
from lightning.pytorch.callbacks import Callback
|
|
from pytorch_lightning.utilities import rank_zero_only
|
|
import wandb
|
|
|
|
|
|
class LogPredictionSamplesCallback(Callback):
    """Log a W&B table of segmentation predictions for the first validation batch.

    On validation batch 0 (on the global-zero process only), up to ``n_images``
    samples from the batch are logged as ``wandb.Image`` objects with two mask
    overlays: the model prediction and the ground truth.

    Args:
        wandb_logger: Logger exposing ``log_table(key=..., columns=..., data=...)``
            (e.g. Lightning's ``WandbLogger``).
        n_images: Maximum number of samples from the batch to log.
        class_labels: Mapping from integer class id to display name, used for
            both overlays. Defaults to ``{0: "Background", 1: "Red", 2: "Green"}``.
            Ground-truth ids are clamped to ``[0, max(class_labels)]``.
    """

    def __init__(self, wandb_logger, n_images=8, class_labels=None):
        super().__init__()
        self.n_images = n_images
        self.wandb_logger = wandb_logger
        # Copy so neither the caller's dict nor the default is shared/mutated.
        self.class_labels = dict(
            class_labels
            if class_labels is not None
            else {0: "Background", 1: "Red", 2: "Green"}
        )

    @rank_zero_only
    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
    ):
        """Log prediction/ground-truth overlays for validation batch 0.

        Args:
            trainer: The Lightning trainer; only ``is_global_zero`` is read.
            pl_module: The LightningModule (unused).
            outputs: Value returned by ``validation_step``. Assumed to be
                per-pixel class scores shaped (N, C, H, W) -- TODO confirm;
                the class id is taken with ``argmax`` over dim 1.
            batch: Mapping with an ``"image"`` tensor and a ``"mask"`` tensor.
            batch_idx: Index of the batch within the validation epoch.
            dataloader_idx: Supplied by Lightning when several validation
                dataloaders are configured; accepted so the hook does not
                fail in that case, otherwise unused.
        """
        # Only the first batch is logged, and only once across processes.
        if batch_idx != 0 or not trainer.is_global_zero:
            return

        n = self.n_images
        images = batch["image"][:n].float()
        targets = batch["mask"][:n]
        preds = outputs[:n].argmax(dim=1)

        # wandb masks must be 2-D per sample: drop a singleton channel dim.
        if targets.ndim == 4:
            targets = targets.squeeze(1)
        if preds.ndim == 4:
            preds = preds.squeeze(1)

        # Keep ground-truth ids inside the labelled range so every pixel maps
        # to a known class name. NOTE(review): out-of-range values (e.g. an
        # ignore index, if the dataset uses one) get the nearest valid label.
        targets = targets.clamp(0, max(self.class_labels))

        data = [
            [
                wandb.Image(
                    image,
                    masks={
                        "Prediction": {
                            "mask_data": pred.cpu().numpy(),
                            "class_labels": self.class_labels,
                        },
                        "Groundtruth": {
                            "mask_data": target.cpu().numpy(),
                            "class_labels": self.class_labels,
                        },
                    },
                )
            ]
            for image, target, pred in zip(images, targets, preds)
        ]

        self.wandb_logger.log_table(
            data=data, key=f"Validation Batch {batch_idx}", columns=["image"]
        )
|
|
|