teamalphabmsit's picture
Upload folder using huggingface_hub
44504f7 verified
import math
from torch import nn
from torch.nn.utils.parametrizations import spectral_norm
def initialize_weights(tensor):
    """Return a scaled, non-negative uniform initialisation for ``tensor``.

    The input is overwritten in place with uniform samples (``uniform_()``
    with its default bounds), and a *new* tensor holding those samples
    multiplied by ``sqrt(0.25 / (tensor.shape[0] + tensor.shape[1]))`` is
    returned.

    Args:
        tensor: Tensor with at least two dimensions; modified in place.

    Returns:
        A new tensor of the same shape containing the scaled samples.
    """
    # Sample first so the in-place side effect precedes the shape lookup.
    samples = tensor.uniform_()
    dim_total = tensor.shape[0] + tensor.shape[1]
    scale = math.sqrt(0.25 / dim_total)
    return samples * scale
class _RRAutoencoder(nn.Module):
    """Two-layer linear autoencoder mapping 784 -> 200 -> 784 features.

    ``encoder`` and ``decoder`` initially alias the two linear layers;
    subclasses in this file replace them with activation-wrapped stacks.
    ``clamp`` is a hook that does nothing here, i.e. the parameters of this
    base class are left unconstrained.
    """

    def __init__(self):
        super().__init__()
        self.linear_1 = nn.Linear(in_features=784, out_features=200)
        self.linear_2 = nn.Linear(in_features=200, out_features=784)
        # Default pipeline: the bare linear layers, no activation.
        self.encoder, self.decoder = self.linear_1, self.linear_2

    def forward(self, x):
        """Encode ``x`` and then decode it, returning the reconstruction."""
        return self.decoder(self.encoder(x))

    def clamp(self):
        """Constraint hook; intentionally a no-op for this class."""
        return None
class _NNAutoencoder(_RRAutoencoder):
    """Autoencoder with zeroed biases and non-negative initial weights.

    On construction each linear layer gets its bias set to zero and its
    weight replaced by a fresh parameter produced by ``initialize_weights``.
    ``clamp`` projects weights and biases of both layers back onto ``>= 0``.
    """

    def __init__(self):
        super().__init__()
        # linear_1 is processed before linear_2 so random draws keep their order.
        for layer in (self.linear_1, self.linear_2):
            layer.bias.data.zero_()
            layer.weight = nn.Parameter(initialize_weights(layer.weight.data))

    def clamp(self):
        """Clamp every weight and bias of both layers to a minimum of 0."""
        for layer in (self.linear_1, self.linear_2):
            layer.weight.data.clamp_(min=0)
            layer.bias.data.clamp_(min=0)
class _PNAutoencoder(_NNAutoencoder):
    """Variant whose weights are kept at or above 1e-3 and biases at or above 0."""

    def clamp(self):
        """Clamp weights to a minimum of 1e-3 and biases to a minimum of 0."""
        layers = (self.linear_1, self.linear_2)
        for layer in layers:
            layer.weight.data.clamp_(min=1e-3)
        for layer in layers:
            layer.bias.data.clamp_(min=0)
class _NRAutoencoder(_NNAutoencoder):
    """Variant that constrains only the weights; biases are left untouched."""

    def clamp(self):
        """Clamp the weights of both layers to a minimum of 0."""
        for layer in (self.linear_1, self.linear_2):
            layer.weight.data.clamp_(min=0)
class SigmoidNNAutoencoder(_NNAutoencoder):
    """``_NNAutoencoder`` with a sigmoid after both the encoder and decoder layer."""

    def __init__(self):
        super().__init__()
        encoder_stack = (self.linear_1, nn.Sigmoid())
        decoder_stack = (self.linear_2, nn.Sigmoid())
        self.encoder = nn.Sequential(*encoder_stack)
        self.decoder = nn.Sequential(*decoder_stack)
class TanhNNAutoencoder(_NNAutoencoder):
    """``_NNAutoencoder`` with a tanh after both the encoder and decoder layer."""

    def __init__(self):
        super().__init__()
        encoder_stack = (self.linear_1, nn.Tanh())
        decoder_stack = (self.linear_2, nn.Tanh())
        self.encoder = nn.Sequential(*encoder_stack)
        self.decoder = nn.Sequential(*decoder_stack)
class TanhPNAutoencoder(_PNAutoencoder):
    """``_PNAutoencoder`` with a tanh after both the encoder and decoder layer."""

    def __init__(self):
        super().__init__()
        encoder_stack = (self.linear_1, nn.Tanh())
        decoder_stack = (self.linear_2, nn.Tanh())
        self.encoder = nn.Sequential(*encoder_stack)
        self.decoder = nn.Sequential(*decoder_stack)
class ReLUNNAutoencoder(_NNAutoencoder):
    """``_NNAutoencoder`` with spectral-normalised layers and ReLU activations.

    Both linear layers are wrapped with ``spectral_norm`` before being placed
    in the encoder/decoder stacks. ``clamp`` therefore acts on the tensor
    stored under ``parametrizations.weight.original`` rather than on
    ``weight`` directly.
    """

    def __init__(self):
        super().__init__()
        # Apply spectral norm to linear_1 first, then linear_2, and re-assign.
        for attr in ("linear_1", "linear_2"):
            setattr(self, attr, spectral_norm(getattr(self, attr)))
        self.encoder = nn.Sequential(self.linear_1, nn.ReLU())
        self.decoder = nn.Sequential(self.linear_2, nn.ReLU())

    def clamp(self):
        """Clamp the original (pre-parametrization) weights and the biases to >= 0."""
        for layer in (self.linear_1, self.linear_2):
            layer.parametrizations.weight.original.data.clamp_(min=0)
            layer.bias.data.clamp_(min=0)
class ReLUPNAutoencoder(_PNAutoencoder):
    """``_PNAutoencoder`` with spectral-normalised layers and ReLU activations.

    Both linear layers are wrapped with ``spectral_norm`` before being placed
    in the encoder/decoder stacks. ``clamp`` acts on the tensor stored under
    ``parametrizations.weight.original`` rather than on ``weight`` directly.
    """

    def __init__(self):
        super().__init__()
        # Apply spectral norm to linear_1 first, then linear_2, and re-assign.
        for attr in ("linear_1", "linear_2"):
            setattr(self, attr, spectral_norm(getattr(self, attr)))
        self.encoder = nn.Sequential(self.linear_1, nn.ReLU())
        self.decoder = nn.Sequential(self.linear_2, nn.ReLU())

    def clamp(self):
        """Clamp the original weights to >= 1e-3 and the biases to >= 0."""
        layers = (self.linear_1, self.linear_2)
        for layer in layers:
            layer.parametrizations.weight.original.data.clamp_(min=1e-3)
        for layer in layers:
            layer.bias.data.clamp_(min=0)
class TanhSwishNNAutoencoder(_NNAutoencoder):
    """``_NNAutoencoder`` with a tanh encoder and a SiLU decoder activation."""

    def __init__(self):
        super().__init__()
        encoder_stack = (self.linear_1, nn.Tanh())
        decoder_stack = (self.linear_2, nn.SiLU())
        self.encoder = nn.Sequential(*encoder_stack)
        self.decoder = nn.Sequential(*decoder_stack)
class ReLUSigmoidNRAutoencoder(_NRAutoencoder):
    """``_NRAutoencoder`` with a ReLU encoder and a sigmoid decoder activation."""

    def __init__(self):
        super().__init__()
        encoder_stack = (self.linear_1, nn.ReLU())
        decoder_stack = (self.linear_2, nn.Sigmoid())
        self.encoder = nn.Sequential(*encoder_stack)
        self.decoder = nn.Sequential(*decoder_stack)
class ReLUSigmoidRRAutoencoder(_RRAutoencoder):
    """``_RRAutoencoder`` with a ReLU encoder and a sigmoid decoder activation."""

    def __init__(self):
        super().__init__()
        encoder_stack = (self.linear_1, nn.ReLU())
        decoder_stack = (self.linear_2, nn.Sigmoid())
        self.encoder = nn.Sequential(*encoder_stack)
        self.decoder = nn.Sequential(*decoder_stack)