File size: 1,819 Bytes
69591a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
from lightning.pytorch.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_only
import wandb
class LogPredictionSamplesCallback(Callback):
    """Log image / prediction / ground-truth mask overlays to Weights & Biases.

    On the first validation batch (``batch_idx == 0``) the first ``n_images``
    samples of that batch are logged through ``wandb_logger.log_table`` as a
    one-column table ("image"), one ``wandb.Image`` per row, each carrying a
    "Prediction" and a "Groundtruth" segmentation-mask overlay.

    Args:
        wandb_logger: Object exposing ``log_table(key=..., columns=..., data=...)``
            (presumably a Lightning ``WandbLogger``; only ``log_table`` is used).
        n_images: Maximum number of samples taken from the batch.
        class_labels: Mapping from integer class id to display name for the
            W&B mask overlay. Defaults to
            ``{0: "Background", 1: "Red", 2: "Green"}``. Ground-truth ids are
            clamped into ``[0, max(class_labels)]`` before logging.

    Raises:
        ValueError: If ``class_labels`` is empty.
    """

    _DEFAULT_CLASS_LABELS = {0: "Background", 1: "Red", 2: "Green"}

    def __init__(self, wandb_logger, n_images=8, class_labels=None):
        super().__init__()
        self.n_images = n_images
        self.wandb_logger = wandb_logger
        # Copy so neither the caller's dict nor the class default is shared.
        self.class_labels = dict(
            self._DEFAULT_CLASS_LABELS if class_labels is None else class_labels
        )
        if not self.class_labels:
            raise ValueError("class_labels must not be empty")

    @rank_zero_only
    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
    ):
        """Log sample overlays for the first validation batch.

        Args:
            trainer: The Lightning trainer; only ``is_global_zero`` is read.
            pl_module: Unused; part of the Lightning hook signature.
            outputs: Value returned by ``validation_step``. Assumed to be a
                per-class score tensor of shape (B, C, H, W) — TODO confirm
                against the LightningModule.
            batch: Mapping with at least ``"image"`` and ``"mask"`` tensors.
            batch_idx: Index of the batch within the current dataloader.
            dataloader_idx: Index of the validation dataloader. Lightning only
                passes it when several validation dataloaders are configured;
                the default keeps the single-dataloader call working.
        """
        # Only batch 0 is logged; the rank check duplicates @rank_zero_only
        # but is kept as a cheap guard.
        if batch_idx >= 1 or not trainer.is_global_zero:
            return

        n = self.n_images
        images = batch["image"][:n].float()
        targets = batch["mask"][:n]
        # argmax over dim 1: class-score axis under the (B, C, H, W) assumption.
        preds = outputs[:n].argmax(dim=1)

        # Drop a singleton channel axis, e.g. (B, 1, H, W) -> (B, H, W).
        if targets.dim() == 4:
            targets = targets.squeeze(1)
        if preds.dim() == 4:
            preds = preds.squeeze(1)

        # Keep ground-truth ids inside the labelled range.
        # NOTE(review): ids above the largest label (e.g. an ignore index such
        # as 255) are displayed as the highest class — confirm this is intended.
        targets = targets.clamp(0, max(self.class_labels))

        rows = [
            [
                wandb.Image(
                    image,
                    masks={
                        "Prediction": {
                            "mask_data": pred.cpu().numpy(),
                            "class_labels": self.class_labels,
                        },
                        "Groundtruth": {
                            "mask_data": target.cpu().numpy(),
                            "class_labels": self.class_labels,
                        },
                    },
                )
            ]
            for image, target, pred in zip(images, targets, preds)
        ]

        key = f"Validation Batch {batch_idx}"
        # Disambiguate extra dataloaders so their tables don't overwrite
        # dataloader 0's; the key for dataloader 0 is unchanged.
        if dataloader_idx:
            key = f"{key} (dataloader {dataloader_idx})"
        self.wandb_logger.log_table(data=rows, key=key, columns=["image"])
|