File size: 1,819 Bytes
69591a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from lightning.pytorch.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_only
import wandb


class LogPredictionSamplesCallback(Callback):
    """Log a W&B table of validation images overlaid with predicted and ground-truth masks.

    Only the first validation batch (``batch_idx == 0``) is logged, and only on the
    global-zero process.

    Args:
        wandb_logger: Object exposing ``log_table(data=..., key=..., columns=...)``;
            presumably a Lightning ``WandbLogger`` -- confirm against the caller.
        n_images: Maximum number of samples taken from the batch.
        class_labels: Mapping from integer class id to display name, used for both
            the prediction and ground-truth overlays. Defaults to
            ``{0: "Background", 1: "Red", 2: "Green"}``. Ground-truth ids are
            clamped to ``[0, max(class_labels)]``.

    Raises:
        ValueError: If ``class_labels`` is empty (no upper clamp bound exists).
    """

    def __init__(self, wandb_logger, n_images=8, class_labels=None):
        super().__init__()
        self.n_images = n_images
        self.wandb_logger = wandb_logger
        # None sentinel instead of a mutable default; copy so later mutation of
        # the caller's dict does not change what gets logged.
        self.class_labels = (
            {0: "Background", 1: "Red", 2: "Green"}
            if class_labels is None
            else dict(class_labels)
        )
        if not self.class_labels:
            raise ValueError("class_labels must contain at least one entry")

    @rank_zero_only
    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
    ):
        """Log prediction / ground-truth mask overlays for the first validation batch.

        Args:
            trainer: The Lightning trainer; ``trainer.is_global_zero`` gates logging.
            pl_module: The module being validated (unused here).
            outputs: Return value of ``validation_step``. Assumed to be a logits
                tensor shaped ``(B, C, H, W)``, since argmax is taken over dim 1.
                NOTE(review): a dict return value would break this -- confirm.
            batch: Mapping with ``"image"`` and ``"mask"`` tensors.
            batch_idx: Index of the batch within the validation epoch.
            dataloader_idx: Index of the validation dataloader. Lightning only
                passes it when several validation dataloaders are used. Accepting
                it avoids a ``TypeError`` in that setup.
        """
        # The is_global_zero check duplicates @rank_zero_only; kept as a second guard.
        if batch_idx >= 1 or not trainer.is_global_zero:
            return

        n = self.n_images
        images = batch["image"][:n].float()
        targets = batch["mask"][:n]
        # Collapse per-class scores to one class id per pixel.
        preds = outputs[:n].argmax(dim=1)

        # Drop a singleton channel dim, e.g. (B, 1, H, W) -> (B, H, W).
        if targets.ndim == 4:
            targets = targets.squeeze(1)
        if preds.ndim == 4:
            preds = preds.squeeze(1)
        # Fold ground-truth ids into the labelled range (presumably to absorb
        # out-of-range values such as an ignore index -- confirm with the dataset).
        targets = targets.clamp(0, max(self.class_labels))

        labels = self.class_labels
        rows = [
            [
                wandb.Image(
                    image,
                    masks={
                        "Prediction": {
                            "mask_data": pred.cpu().numpy(),
                            "class_labels": labels,
                        },
                        "Groundtruth": {
                            "mask_data": target.cpu().numpy(),
                            "class_labels": labels,
                        },
                    },
                )
            ]
            for image, target, pred in zip(images, targets, preds)
        ]

        key = f"Validation Batch {batch_idx}"
        # With several validation dataloaders, batch 0 of each would otherwise be
        # logged under the same key; the default (single dataloader) key is unchanged.
        if dataloader_idx:
            key = f"{key} (dataloader {dataloader_idx})"
        self.wandb_logger.log_table(data=rows, key=key, columns=["image"])