File size: 4,149 Bytes
44504f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import math
from torch import nn
from torch.nn.utils.parametrizations import spectral_norm


def initialize_weights(tensor):
    """Draw uniform samples into ``tensor`` and return a scaled copy of them.

    ``tensor`` is overwritten in place by ``uniform_()`` called with its
    default arguments. The return value is a *new* tensor holding those
    samples multiplied by ``sqrt(0.25 / (tensor.shape[0] + tensor.shape[1]))``;
    the argument itself keeps the unscaled samples.

    Args:
        tensor: tensor with at least two dimensions; only the sizes of the
            first two dimensions enter the scale factor.

    Returns:
        A new tensor with the same shape as ``tensor``.
    """
    # Sample first so the in-place fill happens before the scale is computed.
    samples = tensor.uniform_()
    fan_total = tensor.shape[0] + tensor.shape[1]
    scale = math.sqrt(0.25 / fan_total)
    return samples * scale


class _RRAutoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_1 = nn.Linear(784, 200)
        self.linear_2 = nn.Linear(200, 784)
        self.encoder = self.linear_1
        self.decoder = self.linear_2

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)

        return x

    def clamp(self):
        pass


class _NNAutoencoder(_RRAutoencoder):
    """Autoencoder whose weights and biases are floored at zero by ``clamp``.

    On construction both biases are zeroed and both weight matrices are
    replaced by new parameters built with ``initialize_weights``.
    ``clamp`` is not invoked automatically anywhere in this class; the
    caller decides when to apply it.
    """

    def __init__(self):
        super().__init__()
        layers = (self.linear_1, self.linear_2)
        for layer in layers:
            layer.bias.data.zero_()
        # linear_1 is re-initialised before linear_2, so random numbers are
        # consumed in that order.
        for layer in layers:
            fresh_weight = initialize_weights(layer.weight.data)
            layer.weight = nn.Parameter(fresh_weight)

    def clamp(self):
        """Clamp every weight and bias entry in place to be at least 0."""
        for layer in (self.linear_1, self.linear_2):
            layer.weight.data.clamp_(min=0)
            layer.bias.data.clamp_(min=0)


class _PNAutoencoder(_NNAutoencoder):
    """Variant whose ``clamp`` floors weights at 1e-3 and biases at 0.

    Construction is inherited unchanged from ``_NNAutoencoder``; only the
    lower bound applied to the weights differs.
    """

    def clamp(self):
        """Clamp weights in place to >= 1e-3 and biases in place to >= 0."""
        for layer in (self.linear_1, self.linear_2):
            layer.weight.data.clamp_(min=1e-3)
            layer.bias.data.clamp_(min=0)


class _NRAutoencoder(_NNAutoencoder):
    """Variant whose ``clamp`` floors only the weights at 0.

    Biases are left untouched by ``clamp``. Construction is inherited
    unchanged from ``_NNAutoencoder``.
    """

    def clamp(self):
        """Clamp both weight matrices in place to be at least 0."""
        for weight in (self.linear_1.weight, self.linear_2.weight):
            weight.data.clamp_(min=0)


class SigmoidNNAutoencoder(_NNAutoencoder):
    """``_NNAutoencoder`` with a sigmoid after each linear layer.

    Initialisation of the linear layers and ``clamp`` come from the base
    class; this class only replaces ``encoder`` and ``decoder``.
    """

    def __init__(self):
        super().__init__()
        hidden_activation, output_activation = nn.Sigmoid(), nn.Sigmoid()
        self.encoder = nn.Sequential(self.linear_1, hidden_activation)
        self.decoder = nn.Sequential(self.linear_2, output_activation)


class TanhNNAutoencoder(_NNAutoencoder):
    """``_NNAutoencoder`` with a tanh after each linear layer.

    Initialisation of the linear layers and ``clamp`` come from the base
    class; this class only replaces ``encoder`` and ``decoder``.
    """

    def __init__(self):
        super().__init__()
        hidden_activation, output_activation = nn.Tanh(), nn.Tanh()
        self.encoder = nn.Sequential(self.linear_1, hidden_activation)
        self.decoder = nn.Sequential(self.linear_2, output_activation)


class TanhPNAutoencoder(_PNAutoencoder):
    """``_PNAutoencoder`` with a tanh after each linear layer.

    Initialisation of the linear layers and ``clamp`` come from the base
    class; this class only replaces ``encoder`` and ``decoder``.
    """

    def __init__(self):
        super().__init__()
        hidden_activation, output_activation = nn.Tanh(), nn.Tanh()
        self.encoder = nn.Sequential(self.linear_1, hidden_activation)
        self.decoder = nn.Sequential(self.linear_2, output_activation)


class ReLUNNAutoencoder(_NNAutoencoder):
    """``_NNAutoencoder`` with spectral-normalised layers followed by ReLU.

    Both linear layers are passed through ``spectral_norm`` and the result is
    stored back on the same attribute. ``clamp`` is overridden because the
    weight tensor to constrain is then reached through
    ``parametrizations.weight.original`` rather than ``weight``.
    """

    def __init__(self):
        super().__init__()
        # linear_1 is wrapped before linear_2, as in the straight-line form.
        for attr in ("linear_1", "linear_2"):
            setattr(self, attr, spectral_norm(getattr(self, attr)))
        hidden_activation, output_activation = nn.ReLU(), nn.ReLU()
        self.encoder = nn.Sequential(self.linear_1, hidden_activation)
        self.decoder = nn.Sequential(self.linear_2, output_activation)

    def clamp(self):
        """Clamp the stored ``original`` weights and the biases to >= 0."""
        for layer in (self.linear_1, self.linear_2):
            layer.parametrizations.weight.original.data.clamp_(min=0)
            layer.bias.data.clamp_(min=0)


class ReLUPNAutoencoder(_PNAutoencoder):
    """``_PNAutoencoder`` with spectral-normalised layers followed by ReLU.

    Both linear layers are passed through ``spectral_norm`` and the result is
    stored back on the same attribute. ``clamp`` is overridden because the
    weight tensor to constrain is then reached through
    ``parametrizations.weight.original`` rather than ``weight``.
    """

    def __init__(self):
        super().__init__()
        # linear_1 is wrapped before linear_2, as in the straight-line form.
        for attr in ("linear_1", "linear_2"):
            setattr(self, attr, spectral_norm(getattr(self, attr)))
        hidden_activation, output_activation = nn.ReLU(), nn.ReLU()
        self.encoder = nn.Sequential(self.linear_1, hidden_activation)
        self.decoder = nn.Sequential(self.linear_2, output_activation)

    def clamp(self):
        """Clamp the stored ``original`` weights to >= 1e-3, biases to >= 0."""
        # NOTE(review): the 1e-3 floor is applied to the `original` tensor;
        # the weight seen in forward() is presumably a rescaled version of
        # it, so its minimum may differ from 1e-3 - confirm this is intended.
        for layer in (self.linear_1, self.linear_2):
            layer.parametrizations.weight.original.data.clamp_(min=1e-3)
            layer.bias.data.clamp_(min=0)


class TanhSwishNNAutoencoder(_NNAutoencoder):
    """``_NNAutoencoder`` with tanh after the encoder and SiLU after the decoder.

    Initialisation of the linear layers and ``clamp`` come from the base
    class; this class only replaces ``encoder`` and ``decoder``.
    """

    def __init__(self):
        super().__init__()
        hidden_activation = nn.Tanh()
        output_activation = nn.SiLU()
        self.encoder = nn.Sequential(self.linear_1, hidden_activation)
        self.decoder = nn.Sequential(self.linear_2, output_activation)


class ReLUSigmoidNRAutoencoder(_NRAutoencoder):
    """``_NRAutoencoder`` with ReLU after the encoder and sigmoid after the decoder.

    Initialisation of the linear layers and ``clamp`` come from the base
    class; this class only replaces ``encoder`` and ``decoder``.
    """

    def __init__(self):
        super().__init__()
        hidden_activation = nn.ReLU()
        output_activation = nn.Sigmoid()
        self.encoder = nn.Sequential(self.linear_1, hidden_activation)
        self.decoder = nn.Sequential(self.linear_2, output_activation)


class ReLUSigmoidRRAutoencoder(_RRAutoencoder):
    """``_RRAutoencoder`` with ReLU after the encoder and sigmoid after the decoder.

    The linear layers and the no-op ``clamp`` come from the base class; this
    class only replaces ``encoder`` and ``decoder``.
    """

    def __init__(self):
        super().__init__()
        hidden_activation = nn.ReLU()
        output_activation = nn.Sigmoid()
        self.encoder = nn.Sequential(self.linear_1, hidden_activation)
        self.decoder = nn.Sequential(self.linear_2, output_activation)