|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from huggingface_hub import PyTorchModelHubMixin |
|
|
|
|
|
class UpsampleConv(nn.Module):
    """Learned spatial upsampling by an integer factor via transposed convolution.

    For an odd ``kernel`` the padding/output_padding pair below makes the
    output exactly ``scale`` times the input's height and width:
    out = (in - 1)*scale - 2*(kernel//2) + kernel + (scale - 1) = in*scale.
    """

    def __init__(self, in_ch: int, out_ch: int, kernel: int, scale: int = 2):
        super().__init__()
        # Kept as a public attribute for introspection by callers.
        self.scale: int = scale
        self.upsample_conv = nn.ConvTranspose2d(
            in_ch,
            out_ch,
            kernel_size=kernel,
            stride=scale,
            padding=kernel // 2,
            output_padding=scale - 1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample ``x`` (N, in_ch, H, W) to (N, out_ch, H*scale, W*scale)."""
        return self.upsample_conv(x)
|
|
|
|
|
class StyleTransferNet(nn.Module, PyTorchModelHubMixin):
    """Feed-forward image-to-image stylization network.

    Architecture: two strided-convolution encoder stages (3 -> 64 -> 256
    channels, each halving spatial size), five residual blocks at 1/4
    resolution, then two transposed-convolution decoder stages
    (256 -> 64 -> 32, each doubling spatial size) and a final 9x9
    convolution back to 3 RGB channels.  The HuggingFace mixin adds
    ``save_pretrained`` / ``from_pretrained`` support.
    """

    def __init__(self):
        super().__init__()

        # --- Encoder: downsample by 4x overall. ---
        self.conv1 = ConvLayer(3, 64, kernel=9, stride=2)
        self.norm1 = nn.InstanceNorm2d(64, affine=True)
        self.conv2 = ConvLayer(64, 256, kernel=3, stride=2)
        self.norm2 = nn.InstanceNorm2d(256, affine=True)

        # --- Transformation: residual stack at reduced resolution. ---
        self.res_blocks = nn.ModuleList(ResidualBlock(256) for _ in range(5))

        # --- Decoder: upsample back to the input resolution. ---
        self.up1 = UpsampleConv(256, 64, kernel=3, scale=2)
        self.norm3 = nn.InstanceNorm2d(64, affine=True)
        self.up2 = UpsampleConv(64, 32, kernel=3, scale=2)
        self.norm4 = nn.InstanceNorm2d(32, affine=True)

        # --- Projection back to RGB (no activation: raw output). ---
        self.final_conv = ConvLayer(32, 3, kernel=9, stride=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Stylize a batch of images (N, 3, H, W) -> (N, 3, H, W).

        Assumes H and W are divisible by 4 so the two 2x downsamples are
        exactly undone by the two 2x upsamples — TODO confirm with callers.
        """
        out = F.relu(self.norm1(self.conv1(x)))
        out = F.relu(self.norm2(self.conv2(out)))

        for block in self.res_blocks:
            out = block(out)

        out = F.relu(self.norm3(self.up1(out)))
        out = F.relu(self.norm4(self.up2(out)))

        return self.final_conv(out)
|
|
|
|
|
class ConvLayer(nn.Module):
    """2D convolution preceded by reflection padding.

    With ``pad = kernel // 2`` (exact "same" padding for odd kernels) the
    spatial size is reduced only by ``stride``; reflection padding avoids
    the dark-border artifacts that zero padding produces in image models.
    """

    def __init__(self, in_ch: int, out_ch: int, kernel: int, stride=1):
        super().__init__()
        self.reflection_pad = nn.ReflectionPad2d(kernel // 2)
        # No built-in padding here: it is handled by the reflection pad above.
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply reflection padding, then the convolution."""
        padded = self.reflection_pad(x)
        return self.conv(padded)
|
|
|
|
|
class ResidualBlock(nn.Module):
    """Residual block: two 3x3 conv + instance-norm stages with identity skip.

    Channel count and spatial size are preserved (ConvLayer pads to "same"),
    so the input can be added directly to the transformed branch.  A 2D
    dropout (p=0.1) sits between the two stages; it is active only in
    training mode.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = ConvLayer(channels, channels, kernel=3)
        self.in1 = nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = ConvLayer(channels, channels, kernel=3)
        self.in2 = nn.InstanceNorm2d(channels, affine=True)
        self.dropout = nn.Dropout2d(0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x + F(x)`` where F is conv-norm-relu-dropout-conv-norm."""
        residual = x
        out = F.relu(self.in1(self.conv1(x)))
        out = self.dropout(out)
        out = self.in2(self.conv2(out))
        # No activation on the sum itself (pre-activation style residual).
        return residual + out