Upload ComplexUNet for CIFAR-10 inpainting
- README.md +94 -0
- inpainting_model_larger.pth +3 -0
- inpainting_model_multi_gpu.pth +3 -0
- model.py +53 -0
README.md
ADDED
---
license: mit
language: en
library_name: pytorch
tags:
- image-inpainting
- computer-vision
- pytorch
- unet
- cifar-10
datasets:
- cifar10
---

# U-Net for Image Inpainting on CIFAR-10

This repository contains a PyTorch implementation of a deep U-Net with residual blocks, trained to perform image inpainting on the CIFAR-10 dataset. The model takes an image with a masked (blacked-out) region and reconstructs the missing part.

## Model Description

The model is a `ComplexUNet` architecture, a variant of the standard U-Net. It features:
- **Deeper Architecture**: 4 downsampling and 4 upsampling stages.
- **Residual Blocks**: Each stage uses residual blocks instead of simple convolutional layers.
- **Increased Width**: The model was trained with `base_channels=96`.
- **Total Parameters**: 73,148,259 (a quick way to verify this count is sketched below).
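
The parameter count can be checked directly against the class in `model.py`; this is a quick sanity check, not part of the uploaded code:

```python
from model import ComplexUNet

model = ComplexUNet(base_channels=96)
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params:,} parameters")  # the model card above states 73,148,259
```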

## How to Use

First, install the required libraries:

```bash
pip install torch torchvision numpy Pillow
```

Then load the model and run inpainting on an image tensor:

```python
import torch
from torchvision import transforms as T
from torchvision.transforms.functional import to_pil_image
from PIL import Image

from model import ComplexUNet  # class defined in model.py

# --- Setup ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Download the .pth file from the 'Files and versions' tab of this repo.
MODEL_PATH = "inpainting_model_larger.pth"

# --- Load model ---
model = ComplexUNet(base_channels=96)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.to(DEVICE)
model.eval()

# --- Load and preprocess an image ---
# For a real image:
#   image = Image.open("your_image.png").convert("RGB")
#   image_tensor = transform(image)
transform = T.Compose([T.Resize((32, 32)), T.ToTensor()])
image_tensor = torch.rand(3, 32, 32)  # dummy tensor for demonstration

# --- Create a mask ---
masked_tensor = image_tensor.clone()
masked_tensor[:, 8:24, 8:24] = 0  # black out a 16x16 square in the center

# --- Perform inpainting ---
with torch.no_grad():
    input_tensor = masked_tensor.unsqueeze(0).to(DEVICE)
    reconstructed_tensor = model(input_tensor).squeeze(0).cpu()

# 'reconstructed_tensor' now holds the inpainted image.
reconstructed_image = to_pil_image(reconstructed_tensor)
reconstructed_image.save("reconstructed_image.png")
print("Saved reconstructed_image.png")
```
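
The repository also ships a second checkpoint, `inpainting_model_multi_gpu.pth`. Its much smaller file size (about 31 MB vs. 292 MB) suggests a different, undocumented configuration, and checkpoints saved from a `torch.nn.DataParallel`-wrapped model typically store their keys under a `module.` prefix. A minimal loading sketch, with the caveat that the matching `base_channels` value is not documented:

```python
import torch

# Strip the "module." prefix that torch.nn.DataParallel adds to state-dict keys.
state_dict = torch.load("inpainting_model_multi_gpu.pth", map_location="cpu")
state_dict = {k.removeprefix("module."): v for k, v in state_dict.items()}

# The base_channels for this checkpoint is not documented; pick the value
# whose layer shapes match the checkpoint before calling load_state_dict.
# model = ComplexUNet(base_channels=...)
# model.load_state_dict(state_dict)
```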
## Training Data

The model was trained on the **CIFAR-10** dataset.
- **Preprocessing**: Images were used at their original **32x32 pixel** resolution.
- **Augmentation**: For each training image, a random rectangular mask was applied (see the masking sketch below).
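
The masking routine itself was not uploaded; the following is a minimal sketch of a random rectangular mask as described above, with the size range as an illustrative assumption:

```python
import torch

def apply_random_mask(image: torch.Tensor, min_size: int = 8, max_size: int = 16) -> torch.Tensor:
    """Black out a random rectangle in a (C, H, W) image tensor.

    The size range is an assumption; the original training script
    was not published alongside the model.
    """
    _, h, w = image.shape
    mh = int(torch.randint(min_size, max_size + 1, (1,)))
    mw = int(torch.randint(min_size, max_size + 1, (1,)))
    top = int(torch.randint(0, h - mh + 1, (1,)))
    left = int(torch.randint(0, w - mw + 1, (1,)))
    masked = image.clone()
    masked[:, top:top + mh, left:left + mw] = 0
    return masked
```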
## Training Procedure

- **Framework**: PyTorch
- **Optimizer**: Adam
- **Learning Rate**: 0.001
- **Epochs**: 50
- **Batch Size**: 128
- **Loss Function**: Mean Squared Error (MSE)
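
A training loop consistent with these settings is sketched below; it is a reconstruction, not the original script, and assumes the `apply_random_mask` helper sketched earlier:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from model import ComplexUNet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_set = datasets.CIFAR10(root="data", train=True, download=True,
                             transform=transforms.ToTensor())
loader = DataLoader(train_set, batch_size=128, shuffle=True)

model = ComplexUNet(base_channels=96).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()

for epoch in range(50):
    for images, _ in loader:  # CIFAR-10 labels are unused for inpainting
        images = images.to(device)
        masked = torch.stack([apply_random_mask(img) for img in images])
        optimizer.zero_grad()
        loss = criterion(model(masked), images)  # reconstruct the full image
        loss.backward()
        optimizer.step()
    print(f"epoch {epoch + 1}: last-batch loss {loss.item():.4f}")
```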
## Evaluation

Evaluation metrics were not saved by the training script. To obtain PSNR and SSIM, run the `evaluate_model` function from the training script (not included in this repository).
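
As a stopgap, PSNR can be computed from MSE directly; a minimal sketch (SSIM needs an extra dependency such as `scikit-image`, so it is omitted here):

```python
import torch

def psnr(prediction: torch.Tensor, target: torch.Tensor, max_val: float = 1.0) -> float:
    """Peak signal-to-noise ratio for tensors scaled to [0, max_val]."""
    mse = torch.mean((prediction - target) ** 2)
    if mse == 0:
        return float("inf")
    return (10 * torch.log10(max_val ** 2 / mse)).item()
```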
inpainting_model_larger.pth
ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:1f66c40cc5deafd5486dca299907058bf7ade095d9e1330b6767df1490485280
size 292031834
inpainting_model_multi_gpu.pth
ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:f41891b30399ed3d9bc498b087b5c503b846a175837144ac8adc717068638408
size 30855650
model.py
ADDED
```python
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Project the input when the channel count changes so it can be added back.
        if in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        residual = self.shortcut(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        return self.relu(out)


class ComplexUNet(nn.Module):
    def __init__(self, base_channels=96):  # default matches the trained checkpoint
        super(ComplexUNet, self).__init__()
        c = base_channels
        self.pool = nn.MaxPool2d(2, 2)
        # Encoder: 4 downsampling stages plus a bottleneck.
        self.enc1 = ResidualBlock(3, c)
        self.enc2 = ResidualBlock(c, c * 2)
        self.enc3 = ResidualBlock(c * 2, c * 4)
        self.enc4 = ResidualBlock(c * 4, c * 8)
        self.bottleneck = ResidualBlock(c * 8, c * 16)
        # Decoder: 4 upsampling stages; each dec_conv consumes the upsampled
        # features concatenated with the matching encoder skip connection.
        self.upconv1 = nn.ConvTranspose2d(c * 16, c * 8, kernel_size=2, stride=2)
        self.upconv2 = nn.ConvTranspose2d(c * 8, c * 4, kernel_size=2, stride=2)
        self.upconv3 = nn.ConvTranspose2d(c * 4, c * 2, kernel_size=2, stride=2)
        self.upconv4 = nn.ConvTranspose2d(c * 2, c, kernel_size=2, stride=2)
        self.dec_conv1 = ResidualBlock(c * 16, c * 8)
        self.dec_conv2 = ResidualBlock(c * 8, c * 4)
        self.dec_conv3 = ResidualBlock(c * 4, c * 2)
        self.dec_conv4 = ResidualBlock(c * 2, c)
        self.final_conv = nn.Conv2d(c, 3, kernel_size=1)

    def forward(self, x):
        # Encoder path, keeping each stage's output for the skip connections.
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        b = self.bottleneck(self.pool(e4))
        # Decoder path: upsample, concatenate the skip, then refine.
        d1 = self.dec_conv1(torch.cat([self.upconv1(b), e4], dim=1))
        d2 = self.dec_conv2(torch.cat([self.upconv2(d1), e3], dim=1))
        d3 = self.dec_conv3(torch.cat([self.upconv3(d2), e2], dim=1))
        d4 = self.dec_conv4(torch.cat([self.upconv4(d3), e1], dim=1))
        # Sigmoid keeps outputs in [0, 1] to match ToTensor-scaled images.
        return torch.sigmoid(self.final_conv(d4))
```
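
A quick shape sanity check for the module above (not part of the uploaded files):

```python
import torch
from model import ComplexUNet

model = ComplexUNet(base_channels=96)
with torch.no_grad():
    out = model(torch.rand(1, 3, 32, 32))
print(out.shape)  # torch.Size([1, 3, 32, 32]); sigmoid keeps values in [0, 1]
```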