Image colorization
A UNet architecture that uses a pretrained ResNet-34 as its encoder (transfer learning).
Try the model on Google Colab or in the Hugging Face Space.

The model takes a 1x224x224 L-channel tensor as input and outputs the two ab channels as a 2x224x224 tensor. The decoder was trained from scratch. The encoder (ResNet-34) was initially frozen so the decoder could adapt to the task, then progressively unfrozen layer by layer. Only the deeper layers were fine-tuned; the initial layers stayed frozen. The design draws on several research papers (see References). Training took more than 20 hours on Google Colab and Kaggle T4 GPUs.
There is no dedicated dataset for image colorization, so I curated my own and used it to train the model. The COCO 2017 dataset was filtered to remove grayscale images, images with heavy filters applied, and other artifacts unsuitable for training a natural colorization model. The remaining images were center-cropped and resized to 224x224. The dataset can be found here. This repository contains the model weights and the UNet architecture to load the weights into.
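The center-crop step comes down to a little arithmetic. The helper below is a hypothetical sketch (the name `center_crop_box` is not from this repository), and it assumes the crop takes the largest centered square before resizing to 224x224. The box uses the (left, upper, right, lower) order that Pillow's `Image.crop` expects.

```python
def center_crop_box(width: int, height: int) -> tuple:
    """Return (left, upper, right, lower) for the largest centered square."""
    side = min(width, height)
    left = (width - side) // 2
    upper = (height - side) // 2
    return (left, upper, left + side, upper + side)

# A 640x480 landscape image keeps its full height and loses 80 px per side.
box = center_crop_box(640, 480)  # (80, 0, 560, 480)
```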
Usage
Download the architecture file and model weights
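from huggingface_hub import hf_hub_download

REPO_ID = "ayushshah/imagecolorization"

# Download the architecture file into the working directory so it can be imported
hf_hub_download(
    repo_id=REPO_ID,
    filename="model.py",
    local_dir=".",
    local_dir_use_symlinks=False
)
# Download the weights into the Hugging Face cache and keep the local path
weights_path = hf_hub_download(
    repo_id=REPO_ID,
    filename="model.safetensors"
)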
Make sure the input images are 224x224 and convert them to the LAB color space, for example with kornia.
Isolate the L channel and scale it to the range [0, 1]. The L channel is originally in the range [0, 100].
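import torch
from model import UNet
from safetensors.torch import load_file

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

model = UNet().to(DEVICE)
state_dict = load_file(weights_path)
model.load_state_dict(state_dict)
model.eval()

# L_normalized: (B, 1, 224, 224) tensor in [0, 1], already on DEVICE
with torch.no_grad():
    ab_pred = model(L_normalized)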
The outputs are in the range [-1, 1]. Convert the ab channels back to their original range with a linear scaling, then concatenate the original L channel (range [0, 100]) with the ab channels to get the LAB image.
# Map [-1, 1] to [-128, 127]
ab = (ab_pred + 1) * 255.0 / 2 - 128.0
ab = torch.clamp(ab, -128, 127)
# L is the unnormalized L channel, on the same device as ab
lab = torch.cat((L, ab), dim=1)
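As a quick sanity check of the linear scaling above, the same formula can be evaluated at the endpoints of the output range. The function name `rescale_ab` is hypothetical and used only for illustration.

```python
def rescale_ab(x: float) -> float:
    """Map a model output in [-1, 1] to the ab range [-128, 127]."""
    return (x + 1) * 255.0 / 2 - 128.0

low = rescale_ab(-1.0)   # -128.0
mid = rescale_ab(0.0)    # -0.5, close to a neutral (gray) color
high = rescale_ab(1.0)   # 127.0
```

Because the endpoints land exactly on -128 and 127, the clamp only matters if the model produces values slightly outside [-1, 1].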
References
- Let there be Color!: Joint End-to-end Learning of Global and Local Image Priors for Automatic Image Colorization with Simultaneous Classification
- Colorful Image Colorization
- Color and Attention for U: Modified Multi Attention U-Net for a Better Image Colorization
- Deep Laplacian Pyramid Networks for Fast and Accurate Super-Resolution