# DiffBIR-OpenXLab / model/callbacks.py
# Uploaded by Dynamatrix ("Upload 61 files"), commit 3ed3379 (verified).
from typing import Dict, Any
import os
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.types import STEP_OUTPUT
import torch
import torchvision
from PIL import Image
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.distributed import rank_zero_only
from .mixins import ImageLoggerMixin
# Public API of this module: ImageLogger is defined below, while ModelCheckpoint
# is re-exported unchanged from pytorch_lightning.callbacks.
__all__ = ["ModelCheckpoint", "ImageLogger"]
class ImageLogger(Callback):
    """
    Log images during training or validating.

    Every ``log_every_n_steps`` global steps (checked at the end of each training
    batch), this callback calls ``pl_module.log_images(batch, **log_images_kwargs)``
    and writes each returned image batch as a PNG grid under
    ``<logger save_dir>/image_log/train``.

    TODO: Support validating.
    """

    def __init__(
        self,
        log_every_n_steps: int = 2000,
        max_images_each_step: int = 4,
        log_images_kwargs: Dict[str, Any] = None,
    ) -> None:
        """
        Args:
            log_every_n_steps: Log whenever ``global_step`` is a multiple of this value
                (so step 0 is logged as well).
            max_images_each_step: Maximum number of images per key placed in one grid.
            log_images_kwargs: Extra keyword arguments forwarded to
                ``pl_module.log_images``; ``None`` means no extra arguments.
        """
        super().__init__()
        self.log_every_n_steps = log_every_n_steps
        self.max_images_each_step = max_images_each_step
        self.log_images_kwargs = log_images_kwargs or dict()

    def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        # Fail at the start of fitting instead of at the first logging step if the
        # module does not provide the ImageLoggerMixin interface (log_images).
        assert isinstance(pl_module, ImageLoggerMixin), (
            "ImageLogger requires a LightningModule that inherits ImageLoggerMixin, "
            f"got {type(pl_module).__name__}"
        )

    @rank_zero_only
    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """
        Sample images from ``pl_module`` and save them every ``log_every_n_steps`` steps.

        Runs only on global rank 0. ``dataloader_idx`` is unused; it defaults to 0 so
        the hook also works with Lightning versions that no longer pass it.
        """
        if pl_module.global_step % self.log_every_n_steps != 0:
            return

        state = self._freeze(pl_module)
        try:
            with torch.no_grad():
                # returned images should be: nchw, rgb, [0, 1]
                images: Dict[str, torch.Tensor] = pl_module.log_images(
                    batch, **self.log_images_kwargs
                )
        finally:
            # Restore even if log_images raises, so training continues with the
            # exact parameter/module state it had before sampling.
            self._restore(state)

        self._save_images(
            images, self._resolve_save_dir(trainer, pl_module), pl_module, batch_idx
        )

    @staticmethod
    def _freeze(pl_module: pl.LightningModule):
        """Freeze ``pl_module`` for sampling; return what is needed to undo it exactly."""
        # LightningModule.unfreeze() would set requires_grad=True and train mode on
        # *every* parameter/submodule, un-freezing parts that were intentionally
        # frozen. Record the per-parameter / per-module state instead.
        state = (
            [(param, param.requires_grad) for param in pl_module.parameters()],
            [(module, module.training) for module in pl_module.modules()],
        )
        pl_module.freeze()  # requires_grad=False on all params + eval()
        return state

    @staticmethod
    def _restore(state) -> None:
        """Restore the parameter and module state captured by ``_freeze``."""
        param_flags, module_modes = state
        for param, requires_grad in param_flags:
            param.requires_grad = requires_grad
        # Assign the flag directly (not module.train()) so modules whose train()
        # is overridden or recursive are put back exactly as they were.
        for module, was_training in module_modes:
            module.training = was_training

    @staticmethod
    def _resolve_save_dir(trainer: pl.Trainer, pl_module: pl.LightningModule) -> str:
        """Return the directory for training image logs, creating nothing."""
        # Fall back to the trainer's root dir when there is no logger or the
        # logger exposes no save_dir.
        root = getattr(pl_module.logger, "save_dir", None)
        if root is None:
            root = trainer.default_root_dir
        return os.path.join(root, "image_log", "train")

    def _save_images(
        self,
        images: Dict[str, torch.Tensor],
        save_dir: str,
        pl_module: pl.LightningModule,
        batch_idx: int,
    ) -> None:
        """Write one PNG grid per key of ``images`` into ``save_dir``."""
        os.makedirs(save_dir, exist_ok=True)
        for image_key, image in images.items():
            # .float() because numpy cannot represent bfloat16 (mixed precision).
            image = image.detach().float().cpu()
            n = min(self.max_images_each_step, image.shape[0])
            grid = torchvision.utils.make_grid(image[:n], nrow=4)
            # chw -> hwc; squeeze(-1) drops a trailing single channel if present.
            grid = grid.permute(1, 2, 0).squeeze(-1).numpy()
            grid = (grid * 255).clip(0, 255).astype(np.uint8)
            filename = "{}_step-{:06}_e-{:06}_b-{:06}.png".format(
                image_key, pl_module.global_step, pl_module.current_epoch, batch_idx
            )
            Image.fromarray(grid).save(os.path.join(save_dir, filename))