import math

from torch import nn
from torch.nn.utils.parametrizations import spectral_norm


def initialize_weights(tensor):
    # Non-negative scaled-uniform init: fill with U[0, 1) in place, then rescale
    # by sqrt(0.25 / (out_features + in_features)) so all entries are small and >= 0.
    return tensor.uniform_() * math.sqrt(0.25 / (tensor.shape[0] + tensor.shape[1]))

class _RRAutoencoder(nn.Module):
    """Fully connected 784 -> 200 -> 784 autoencoder with unconstrained (real-valued) parameters."""

    def __init__(self):
        super().__init__()
        self.linear_1 = nn.Linear(784, 200)
        self.linear_2 = nn.Linear(200, 784)
        self.encoder = self.linear_1
        self.decoder = self.linear_2

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

    def clamp(self):
        # No constraints to enforce for the unconstrained baseline.
        pass

class _NNAutoencoder(_RRAutoencoder):
    """Variant with non-negative weights and biases.

    Biases start at zero, weights use the non-negative scaled-uniform init,
    and clamp() projects all parameters back onto [0, inf).
    """

    def __init__(self):
        super().__init__()
        self.linear_1.bias.data.zero_()
        self.linear_2.bias.data.zero_()
        self.linear_1.weight = nn.Parameter(
            initialize_weights(self.linear_1.weight.data)
        )
        self.linear_2.weight = nn.Parameter(
            initialize_weights(self.linear_2.weight.data)
        )

    def clamp(self):
        self.linear_1.weight.data.clamp_(min=0)
        self.linear_2.weight.data.clamp_(min=0)
        self.linear_1.bias.data.clamp_(min=0)
        self.linear_2.bias.data.clamp_(min=0)

class _PNAutoencoder(_NNAutoencoder):
    """Variant with strictly positive weights (clamped to >= 1e-3) and non-negative biases."""

    def clamp(self):
        self.linear_1.weight.data.clamp_(min=1e-3)
        self.linear_2.weight.data.clamp_(min=1e-3)
        self.linear_1.bias.data.clamp_(min=0)
        self.linear_2.bias.data.clamp_(min=0)


class _NRAutoencoder(_NNAutoencoder):
    """Variant with non-negative weights; biases are left unconstrained."""

    def clamp(self):
        self.linear_1.weight.data.clamp_(min=0)
        self.linear_2.weight.data.clamp_(min=0)

class SigmoidNNAutoencoder(_NNAutoencoder):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(self.linear_1, nn.Sigmoid())
        self.decoder = nn.Sequential(self.linear_2, nn.Sigmoid())


class TanhNNAutoencoder(_NNAutoencoder):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(self.linear_1, nn.Tanh())
        self.decoder = nn.Sequential(self.linear_2, nn.Tanh())


class TanhPNAutoencoder(_PNAutoencoder):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(self.linear_1, nn.Tanh())
        self.decoder = nn.Sequential(self.linear_2, nn.Tanh())

class ReLUNNAutoencoder(_NNAutoencoder):
    def __init__(self):
        super().__init__()
        # spectral_norm (parametrization API) stores the unnormalized weight at
        # parametrizations.weight.original, so clamp() targets that tensor
        # rather than the derived .weight.
        self.linear_1 = spectral_norm(self.linear_1)
        self.linear_2 = spectral_norm(self.linear_2)
        self.encoder = nn.Sequential(self.linear_1, nn.ReLU())
        self.decoder = nn.Sequential(self.linear_2, nn.ReLU())

    def clamp(self):
        self.linear_1.parametrizations.weight.original.data.clamp_(min=0)
        self.linear_2.parametrizations.weight.original.data.clamp_(min=0)
        self.linear_1.bias.data.clamp_(min=0)
        self.linear_2.bias.data.clamp_(min=0)


class ReLUPNAutoencoder(_PNAutoencoder):
    def __init__(self):
        super().__init__()
        self.linear_1 = spectral_norm(self.linear_1)
        self.linear_2 = spectral_norm(self.linear_2)
        self.encoder = nn.Sequential(self.linear_1, nn.ReLU())
        self.decoder = nn.Sequential(self.linear_2, nn.ReLU())

    def clamp(self):
        self.linear_1.parametrizations.weight.original.data.clamp_(min=1e-3)
        self.linear_2.parametrizations.weight.original.data.clamp_(min=1e-3)
        self.linear_1.bias.data.clamp_(min=0)
        self.linear_2.bias.data.clamp_(min=0)

class TanhSwishNNAutoencoder(_NNAutoencoder):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(self.linear_1, nn.Tanh())
        self.decoder = nn.Sequential(self.linear_2, nn.SiLU())


class ReLUSigmoidNRAutoencoder(_NRAutoencoder):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(self.linear_1, nn.ReLU())
        self.decoder = nn.Sequential(self.linear_2, nn.Sigmoid())


class ReLUSigmoidRRAutoencoder(_RRAutoencoder):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(self.linear_1, nn.ReLU())
        self.decoder = nn.Sequential(self.linear_2, nn.Sigmoid())
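

# Usage sketch (illustrative, not part of the original module): these models are
# trained like any nn.Module, with clamp() called after each optimizer step to
# project the parameters back onto their constraint set. The 784-dimensional
# input suggests flattened 28x28 images, but the data source is an assumption.
if __name__ == "__main__":
    import torch

    model = TanhNNAutoencoder()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    x = torch.rand(16, 784)          # dummy batch of flattened inputs
    reconstruction = model(x)
    loss = nn.functional.mse_loss(reconstruction, x)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    model.clamp()                    # re-impose the non-negativity constraints

    print(reconstruction.shape)      # torch.Size([16, 784])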