Spaces:
Build error
Build error
from typing import Dict, Any | |
import os | |
import numpy as np | |
import pytorch_lightning as pl | |
from pytorch_lightning.callbacks import ModelCheckpoint | |
from pytorch_lightning.utilities.types import STEP_OUTPUT | |
import torch | |
import torchvision | |
from PIL import Image | |
from pytorch_lightning.callbacks import Callback | |
from pytorch_lightning.utilities.distributed import rank_zero_only | |
from .mixins import ImageLoggerMixin | |
# Public API of this module: the re-exported Lightning checkpoint callback
# and the image-logging callback defined below.
__all__ = ["ModelCheckpoint", "ImageLogger"]
class ImageLogger(Callback):
    """
    Log images produced by ``pl_module.log_images`` during training.

    Whenever ``pl_module.global_step`` is a multiple of ``log_every_n_steps``,
    the module's ``log_images(batch, **log_images_kwargs)`` is called under
    ``torch.no_grad()``. For every key of the returned dict, the first
    ``max_images_each_step`` images are tiled into a grid (4 per row) and
    saved as a PNG under ``<save_root>/image_log/train``. ``save_root`` is the
    logger's ``save_dir``, or ``trainer.default_root_dir`` when there is none.

    The hook is wrapped in ``rank_zero_only`` so that only rank 0 writes
    files in distributed runs.

    TODO: Support validating.
    """

    def __init__(
        self,
        log_every_n_steps: int = 2000,
        max_images_each_step: int = 4,
        log_images_kwargs: Dict[str, Any] = None,
    ) -> None:
        """
        Args:
            log_every_n_steps: Save images whenever ``global_step`` is a
                multiple of this value. Must be positive.
            max_images_each_step: Maximum number of images per key placed in
                each saved grid. Must be positive.
            log_images_kwargs: Extra keyword arguments forwarded to
                ``pl_module.log_images``. ``None`` means no extra arguments.

        Raises:
            ValueError: If ``log_every_n_steps`` or ``max_images_each_step``
                is not positive.
        """
        super().__init__()
        if log_every_n_steps <= 0:
            raise ValueError(
                f"log_every_n_steps must be positive, got {log_every_n_steps}"
            )
        if max_images_each_step <= 0:
            raise ValueError(
                f"max_images_each_step must be positive, got {max_images_each_step}"
            )
        self.log_every_n_steps = log_every_n_steps
        self.max_images_each_step = max_images_each_step
        self.log_images_kwargs = log_images_kwargs or dict()

    def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """
        Check that the module provides ``log_images`` via ``ImageLoggerMixin``.

        Raises:
            TypeError: If ``pl_module`` is not an ``ImageLoggerMixin``.
        """
        # An explicit raise (not assert) so the check survives ``python -O``.
        if not isinstance(pl_module, ImageLoggerMixin):
            raise TypeError(
                "ImageLogger requires a module implementing ImageLoggerMixin, "
                f"got {type(pl_module).__name__}"
            )

    @rank_zero_only
    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """
        Save image grids for ``batch`` every ``log_every_n_steps`` steps.

        ``trainer`` is only used to find a fallback output directory.
        ``outputs`` and ``dataloader_idx`` are unused. ``dataloader_idx`` has a
        default so the hook also accepts the five-argument call form.
        """
        if pl_module.global_step % self.log_every_n_steps != 0:
            return

        saved_state = None
        if pl_module.training:
            # Snapshot exact flags instead of relying on unfreeze(), which would
            # set requires_grad=True on *every* parameter and train() *every*
            # submodule, un-freezing parts the model froze on purpose.
            saved_state = self._snapshot_state(pl_module)
            pl_module.freeze()
        try:
            with torch.no_grad():
                # returned images should be: nchw, rgb, [0, 1]
                images: Dict[str, torch.Tensor] = pl_module.log_images(
                    batch, **self.log_images_kwargs
                )
        finally:
            # Restore even if log_images raises, so the module is never left frozen.
            if saved_state is not None:
                self._restore_state(pl_module, saved_state)

        save_dir = os.path.join(
            self._resolve_save_root(trainer, pl_module), "image_log", "train"
        )
        os.makedirs(save_dir, exist_ok=True)
        for image_key, tensor in images.items():
            # float() guards against bfloat16/half outputs, which .numpy() rejects.
            image = tensor.detach().float().cpu()
            n_show = min(self.max_images_each_step, len(image))
            grid = torchvision.utils.make_grid(image[:n_show], nrow=4)
            # chw -> hwc; squeeze(-1) drops a trailing singleton channel if present.
            grid = grid.permute(1, 2, 0).squeeze(-1).numpy()
            grid = (grid * 255).clip(0, 255).astype(np.uint8)
            filename = "{}_step-{:06}_e-{:06}_b-{:06}.png".format(
                image_key, pl_module.global_step, pl_module.current_epoch, batch_idx
            )
            Image.fromarray(grid).save(os.path.join(save_dir, filename))

    @staticmethod
    def _snapshot_state(module: torch.nn.Module) -> tuple:
        """Return (requires_grad per parameter, training flag per submodule)."""
        grad_flags = [param.requires_grad for param in module.parameters()]
        train_flags = [submodule.training for submodule in module.modules()]
        return grad_flags, train_flags

    @staticmethod
    def _restore_state(module: torch.nn.Module, state: tuple) -> None:
        """Re-apply flags captured by ``_snapshot_state`` to the same module."""
        grad_flags, train_flags = state
        for param, flag in zip(module.parameters(), grad_flags):
            param.requires_grad_(flag)
        # Assign ``training`` directly: calling train() would recurse and
        # override submodules that were deliberately kept in another mode.
        for submodule, flag in zip(module.modules(), train_flags):
            submodule.training = flag

    @staticmethod
    def _resolve_save_root(trainer: pl.Trainer, pl_module: pl.LightningModule) -> str:
        """Return the logger's ``save_dir``, else ``trainer.default_root_dir``."""
        # getattr on a None logger (Trainer(logger=False)) also yields None.
        save_root = getattr(pl_module.logger, "save_dir", None)
        if save_root is None:
            save_root = trainer.default_root_dir
        return save_root