# ComplexUNet for CIFAR-10 inpainting (uploaded by tahamajs, commit 1fe5b3c).
import torch
import torch.nn as nn
class ResidualBlock(nn.Module):
    """Two 3x3 conv-BN stages plus an additive shortcut, followed by ReLU.

    Spatial size is preserved: 3x3 kernels use padding 1 and stride 1, and the
    shortcut uses a 1x1 kernel. If the channel count changes, the shortcut is a
    1x1 convolution followed by BatchNorm, so its output can be added to the
    main path. Otherwise the input is passed through unchanged.

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of channels in the output feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Main path. The convs carry no bias because each one is followed by
        # BatchNorm, which re-centres the activations anyway.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # One stateless ReLU module is reused for both activations.
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Shortcut path: project only when the channel counts disagree.
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels),
            )
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        identity = self.shortcut(x)
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        return self.relu(h + identity)
class ComplexUNet(nn.Module):
    """Four-level residual U-Net that maps a 3-channel image to a 3-channel image.

    Encoder: each of four levels applies a ``ResidualBlock`` and then 2x2 max
    pooling. Channel widths are ``c, 2c, 4c, 8c``, followed by a ``16c``
    bottleneck.

    Decoder: each level upsamples with a stride-2 transposed convolution,
    concatenates the matching encoder feature map (skip connection), and fuses
    the result with another ``ResidualBlock``. A 1x1 convolution followed by a
    sigmoid produces the output.

    Input height and width must each be at least 16, so that four poolings
    leave a non-empty map. Sizes not divisible by 16 are supported: the
    upsampled decoder features are zero-padded to match their skip tensor. For
    sizes divisible by 16 no padding occurs.

    Args:
        base_channels: Channel width ``c`` of the first encoder level. The
            default (96) is noted as the trained architecture. Every layer's
            width scales with ``c``, so weights trained at one value cannot be
            loaded at another.
    """

    def __init__(self, base_channels=96):  # Default to the trained architecture
        super().__init__()
        c = base_channels
        self.pool = nn.MaxPool2d(2, 2)
        # Encoder + bottleneck. Attribute names are kept stable so saved
        # state_dict keys still load.
        self.enc1 = ResidualBlock(3, c)
        self.enc2 = ResidualBlock(c, c * 2)
        self.enc3 = ResidualBlock(c * 2, c * 4)
        self.enc4 = ResidualBlock(c * 4, c * 8)
        self.bottleneck = ResidualBlock(c * 8, c * 16)
        # Decoder upsamplers: each halves the channel count and doubles H and W.
        self.upconv1 = nn.ConvTranspose2d(c * 16, c * 8, kernel_size=2, stride=2)
        self.upconv2 = nn.ConvTranspose2d(c * 8, c * 4, kernel_size=2, stride=2)
        self.upconv3 = nn.ConvTranspose2d(c * 4, c * 2, kernel_size=2, stride=2)
        self.upconv4 = nn.ConvTranspose2d(c * 2, c, kernel_size=2, stride=2)
        # Fusion blocks. Their input is [upsampled, skip] concatenated along
        # channels, hence twice the output width.
        self.dec_conv1 = ResidualBlock(c * 16, c * 8)
        self.dec_conv2 = ResidualBlock(c * 8, c * 4)
        self.dec_conv3 = ResidualBlock(c * 4, c * 2)
        self.dec_conv4 = ResidualBlock(c * 2, c)
        self.final_conv = nn.Conv2d(c, 3, kernel_size=1)

    def forward(self, x):
        """Run the U-Net.

        Args:
            x: Image batch of shape ``(N, 3, H, W)`` with ``H, W >= 16``. The
                value range expected at training time is not visible here;
                TODO confirm.

        Returns:
            Tensor of shape ``(N, 3, H, W)`` with sigmoid-squashed values in
            ``[0, 1]``.
        """
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        b = self.bottleneck(self.pool(e4))

        d1 = self._decode_step(b, e4, self.upconv1, self.dec_conv1)
        d2 = self._decode_step(d1, e3, self.upconv2, self.dec_conv2)
        d3 = self._decode_step(d2, e2, self.upconv3, self.dec_conv3)
        d4 = self._decode_step(d3, e1, self.upconv4, self.dec_conv4)
        return torch.sigmoid(self.final_conv(d4))

    @staticmethod
    def _decode_step(x, skip, upconv, fuse):
        """Upsample ``x``, align it spatially with ``skip``, concatenate, and fuse."""
        up = upconv(x)
        # MaxPool2d floors odd sizes, so the 2x-upsampled map can be one
        # row/column smaller than the skip tensor. Zero-pad it (split between
        # both sides) so torch.cat sees matching H and W.
        dh = skip.size(-2) - up.size(-2)
        dw = skip.size(-1) - up.size(-1)
        if dh or dw:
            # F.pad takes pad sizes as (left, right, top, bottom) for the last two dims.
            up = nn.functional.pad(up, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])
        return fuse(torch.cat([up, skip], dim=1))