|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
class ResidualBlock(nn.Module):
    """Basic residual unit: two 3x3 conv+BN layers with an additive skip.

    The skip path is the identity when the channel count is unchanged;
    otherwise it projects with a 1x1 convolution followed by BatchNorm so
    the tensors can be summed.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Main path: conv -> BN -> ReLU -> conv -> BN (ReLU applied after the sum).
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Skip path: project only when the channel count actually changes.
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        identity = self.shortcut(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y += identity
        return self.relu(y)
|
|
|
class ComplexUNet(nn.Module):
    """U-Net built from residual blocks.

    Four encoder stages (each followed by 2x2 max-pooling), a bottleneck,
    and four decoder stages. Each decoder stage upsamples with a transposed
    convolution, concatenates the matching encoder activation along the
    channel dimension, and refines with a residual block. The head is a
    1x1 convolution to 3 channels followed by a sigmoid.

    Note: input height/width must be divisible by 16 (four poolings).
    """

    def __init__(self, base_channels=96):
        super().__init__()
        # Channel widths for the five depth levels.
        c1, c2, c3, c4, c5 = (base_channels * m for m in (1, 2, 4, 8, 16))
        self.pool = nn.MaxPool2d(2, 2)
        # Encoder.
        self.enc1 = ResidualBlock(3, c1)
        self.enc2 = ResidualBlock(c1, c2)
        self.enc3 = ResidualBlock(c2, c3)
        self.enc4 = ResidualBlock(c3, c4)
        self.bottleneck = ResidualBlock(c4, c5)
        # Decoder: each upconv halves the channels and doubles the spatial size.
        self.upconv1 = nn.ConvTranspose2d(c5, c4, kernel_size=2, stride=2)
        self.upconv2 = nn.ConvTranspose2d(c4, c3, kernel_size=2, stride=2)
        self.upconv3 = nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2)
        self.upconv4 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2)
        # After concatenation with the skip, channels double, hence c*2 inputs.
        self.dec_conv1 = ResidualBlock(c5, c4)
        self.dec_conv2 = ResidualBlock(c4, c3)
        self.dec_conv3 = ResidualBlock(c3, c2)
        self.dec_conv4 = ResidualBlock(c2, c1)
        self.final_conv = nn.Conv2d(c1, 3, kernel_size=1)

    def forward(self, x):
        # Encoder: remember pre-pool activations for the skip connections.
        skips = []
        out = x
        for enc in (self.enc1, self.enc2, self.enc3, self.enc4):
            out = enc(out)
            skips.append(out)
            out = self.pool(out)
        out = self.bottleneck(out)
        # Decoder: upsample, concatenate the deepest remaining skip, refine.
        upconvs = (self.upconv1, self.upconv2, self.upconv3, self.upconv4)
        dec_convs = (self.dec_conv1, self.dec_conv2, self.dec_conv3, self.dec_conv4)
        for up, dec, skip in zip(upconvs, dec_convs, reversed(skips)):
            out = dec(torch.cat([up(out), skip], dim=1))
        return torch.sigmoid(self.final_conv(out))
|
|