tahamajs committed on
Commit 1fe5b3c · verified · 1 Parent(s): 847a313

Upload ComplexUNet for CIFAR-10 inpainting

README.md ADDED
@@ -0,0 +1,94 @@
---
license: mit
language: en
library_name: pytorch
tags:
- image-inpainting
- computer-vision
- pytorch
- unet
- cifar-10
datasets:
- cifar10
---

# U-Net for Image Inpainting on CIFAR-10

This repository contains a PyTorch implementation of a deep U-Net with residual blocks, trained to perform image inpainting on the CIFAR-10 dataset. The model takes an image with a masked (blacked-out) region and reconstructs the missing part.

## Model Description

The model is a `ComplexUNet` architecture, a variant of the standard U-Net. It features:
- **Deeper Architecture**: 4 downsampling and 4 upsampling stages.
- **Residual Blocks**: Each stage uses residual blocks instead of plain convolutional layers.
- **Increased Width**: The model was trained with `base_channels=96`.
- **Total Parameters**: 73,148,259 (the snippet below shows how to verify this).
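
The parameter count can be confirmed directly from `model.py`, a quick sanity check assuming the file from this repo is on your Python path:

```python
from model import ComplexUNet

# Instantiate with the trained width and count all parameters.
model = ComplexUNet(base_channels=96)
total = sum(p.numel() for p in model.parameters())
print(f"{total:,}")  # expected: 73,148,259
```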

## How to Use

First, install the required libraries:

```bash
pip install torch torchvision numpy Pillow
```

Then you can load the model and perform inpainting on an image tensor:

```python
import torch
from torchvision import transforms as T
from torchvision.transforms.functional import to_pil_image
from PIL import Image

from model import ComplexUNet  # the class defined in model.py

# --- Setup ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Download the .pth file from the 'Files and versions' tab of this repo
MODEL_PATH = "inpainting_model_larger.pth"

# --- Load Model ---
model = ComplexUNet(base_channels=96)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.to(DEVICE)
model.eval()

# --- Load and Preprocess Image ---
# image = Image.open("your_image.png").convert("RGB")
transform = T.Compose([T.Resize((32, 32)), T.ToTensor()])
# image_tensor = transform(image)
# For demonstration, use a random tensor instead of a real image:
image_tensor = torch.rand(3, 32, 32)

# --- Create a Mask ---
masked_tensor = image_tensor.clone()
masked_tensor[:, 8:24, 8:24] = 0  # black out the central 16x16 region

# --- Perform Inpainting ---
with torch.no_grad():
    input_tensor = masked_tensor.unsqueeze(0).to(DEVICE)  # add batch dimension
    reconstructed_tensor = model(input_tensor).squeeze(0).cpu()

# 'reconstructed_tensor' now holds the inpainted image.
reconstructed_image = to_pil_image(reconstructed_tensor)
reconstructed_image.save("reconstructed_image.png")
print("Saved reconstructed_image.png")
```

## Training Data

The model was trained on the **CIFAR-10** dataset.
- **Preprocessing**: Images were used at their native **32x32 pixel** resolution.
- **Augmentation**: For each training image, a random rectangular mask was applied (see the sketch after this list).
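
The exact masking routine is not included in this commit. The following is a minimal sketch of such an augmentation; `apply_random_mask` is a hypothetical helper, and the `min_size`/`max_size` ranges are illustrative assumptions, not the original hyperparameters:

```python
import torch

def apply_random_mask(img: torch.Tensor, min_size: int = 8, max_size: int = 16) -> torch.Tensor:
    """Zero out a random rectangle in a CHW image tensor (illustrative sizes)."""
    _, h, w = img.shape
    # Sample a rectangle size and a top-left corner that keeps it in bounds.
    mh = torch.randint(min_size, max_size + 1, (1,)).item()
    mw = torch.randint(min_size, max_size + 1, (1,)).item()
    top = torch.randint(0, h - mh + 1, (1,)).item()
    left = torch.randint(0, w - mw + 1, (1,)).item()
    masked = img.clone()
    masked[:, top:top + mh, left:left + mw] = 0
    return masked
```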

## Training Procedure

- **Framework**: PyTorch
- **Optimizer**: Adam
- **Learning Rate**: 0.001
- **Epochs**: 50
- **Batch Size**: 128
- **Loss Function**: Mean Squared Error (MSE)
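
The training script itself is not part of this commit. A minimal loop consistent with the hyperparameters above might look like the sketch below; the masking step uses the hypothetical `apply_random_mask` helper from the previous section, and the loop structure is an assumption:

```python
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from model import ComplexUNet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ComplexUNet(base_channels=96).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()

train_set = datasets.CIFAR10("data", train=True, download=True,
                             transform=transforms.ToTensor())
loader = DataLoader(train_set, batch_size=128, shuffle=True)

for epoch in range(50):
    for images, _ in loader:
        images = images.to(device)
        # Mask each image independently (hypothetical helper from above).
        masked = torch.stack([apply_random_mask(img) for img in images])
        optimizer.zero_grad()
        output = model(masked)
        loss = criterion(output, images)  # reconstruct the full, unmasked image
        loss.backward()
        optimizer.step()
```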

## Evaluation

Evaluation metrics were not saved by the training script. To obtain PSNR and SSIM, run the `evaluate_model` function from the training script.
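
If the training script is not at hand, PSNR is straightforward to compute directly, and SSIM is available from scikit-image (an extra dependency not listed above; `channel_axis` requires a recent release). A minimal sketch for CHW tensors scaled to [0, 1]:

```python
import math

import torch
from skimage.metrics import structural_similarity

def psnr(pred: torch.Tensor, target: torch.Tensor) -> float:
    """Peak signal-to-noise ratio in dB, assuming pixel values in [0, 1]."""
    mse = torch.mean((pred - target) ** 2).item()
    return 10.0 * math.log10(1.0 / mse)

def ssim(pred: torch.Tensor, target: torch.Tensor) -> float:
    """Structural similarity via scikit-image on HWC numpy arrays."""
    p = pred.permute(1, 2, 0).numpy()
    t = target.permute(1, 2, 0).numpy()
    return structural_similarity(t, p, channel_axis=2, data_range=1.0)
```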
inpainting_model_larger.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1f66c40cc5deafd5486dca299907058bf7ade095d9e1330b6767df1490485280
size 292031834
inpainting_model_multi_gpu.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f41891b30399ed3d9bc498b087b5c503b846a175837144ac8adc717068638408
size 30855650
model.py ADDED
@@ -0,0 +1,53 @@
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """Two 3x3 conv + BN layers with a projection shortcut when channels change."""

    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        if in_channels != out_channels:
            # 1x1 conv projects the input to match the residual branch's channels.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        residual = self.shortcut(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        return self.relu(out)


class ComplexUNet(nn.Module):
    """U-Net with residual blocks: 4 down/up stages around a bottleneck."""

    def __init__(self, base_channels=96):  # Default matches the trained checkpoint
        super(ComplexUNet, self).__init__()
        c = base_channels
        self.pool = nn.MaxPool2d(2, 2)
        # Encoder: channels double at each stage.
        self.enc1 = ResidualBlock(3, c)
        self.enc2 = ResidualBlock(c, c * 2)
        self.enc3 = ResidualBlock(c * 2, c * 4)
        self.enc4 = ResidualBlock(c * 4, c * 8)
        self.bottleneck = ResidualBlock(c * 8, c * 16)
        # Decoder: transposed convs upsample, then residual blocks fuse skip connections.
        self.upconv1 = nn.ConvTranspose2d(c * 16, c * 8, kernel_size=2, stride=2)
        self.upconv2 = nn.ConvTranspose2d(c * 8, c * 4, kernel_size=2, stride=2)
        self.upconv3 = nn.ConvTranspose2d(c * 4, c * 2, kernel_size=2, stride=2)
        self.upconv4 = nn.ConvTranspose2d(c * 2, c, kernel_size=2, stride=2)
        self.dec_conv1 = ResidualBlock(c * 16, c * 8)
        self.dec_conv2 = ResidualBlock(c * 8, c * 4)
        self.dec_conv3 = ResidualBlock(c * 4, c * 2)
        self.dec_conv4 = ResidualBlock(c * 2, c)
        self.final_conv = nn.Conv2d(c, 3, kernel_size=1)

    def forward(self, x):
        # Encoder path with 2x2 max pooling between stages.
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        b = self.bottleneck(self.pool(e4))
        # Decoder path: upsample, concatenate the matching encoder feature, refine.
        d1 = self.dec_conv1(torch.cat([self.upconv1(b), e4], dim=1))
        d2 = self.dec_conv2(torch.cat([self.upconv2(d1), e3], dim=1))
        d3 = self.dec_conv3(torch.cat([self.upconv3(d2), e2], dim=1))
        d4 = self.dec_conv4(torch.cat([self.upconv4(d3), e1], dim=1))
        # Sigmoid keeps outputs in [0, 1], matching the image pixel range.
        return torch.sigmoid(self.final_conv(d4))
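
# --- Quick shape check (illustrative sketch; not part of the original upload) ---
# With four pooling stages, spatial dimensions must be divisible by 16;
# the 32x32 CIFAR-10 resolution round-trips to the same shape.
if __name__ == "__main__":
    net = ComplexUNet(base_channels=96)
    out = net(torch.rand(1, 3, 32, 32))
    print(out.shape)  # torch.Size([1, 3, 32, 32])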