mhamza-007 committed (verified)
Commit c572867 · 1 Parent(s): 6fccdf6

Create model_architecture.py

Files changed (1)
  1. model_architecture.py +183 -0
model_architecture.py ADDED
@@ -0,0 +1,183 @@
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange

class Residual(nn.Module):
    """Adds a skip connection around the wrapped module."""
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

class PreNorm(nn.Module):
    """Applies LayerNorm before the wrapped module (pre-norm transformer block)."""
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    """Two-layer MLP with GELU activation."""
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim)
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    """Multi-head self-attention with an optional boolean padding mask."""
    def __init__(self, dim, heads=8):
        super().__init__()
        self.heads = heads
        self.scale = dim ** -0.5

        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x, mask=None):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x)
        # Split the fused projection into per-head queries, keys and values.
        q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=h)

        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale

        if mask is not None:
            # Prepend True for the class token, then build a pairwise attention mask.
            mask = F.pad(mask.flatten(1), (1, 0), value=True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            mask = mask[:, None, :] * mask[:, :, None]
            dots.masked_fill_(~mask, float('-inf'))
            del mask

        attn = dots.softmax(dim=-1)

        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out

class Transformer(nn.Module):
    """Stack of pre-norm attention and feed-forward blocks with residual connections."""
    def __init__(self, dim, depth, heads, mlp_dim):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads=heads))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_dim)))
            ]))

    def forward(self, x, mask=None):
        for attn, ff in self.layers:
            x = attn(x, mask=mask)
            x = ff(x)
        return x

class CViT(nn.Module):
    """Convolutional Vision Transformer: a VGG-style convolutional feature
    extractor followed by a transformer encoder and an MLP classification head."""
    def __init__(self, image_size=224, patch_size=7, num_classes=2, channels=512,
                 dim=1024, depth=6, heads=8, mlp_dim=2048):
        super().__init__()
        assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'

        # Convolutional backbone: five blocks, each ending in a stride-2 max-pool,
        # so a 224x224 input becomes a 512-channel 7x7 feature map.
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=512),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        num_patches = (image_size // patch_size) ** 2
        self.max_sequence_length = num_patches + 1
        patch_dim = channels * patch_size ** 2

        self.patch_size = patch_size

        # Learnable position embeddings, patch projection, class token and transformer encoder.
        self.pos_embedding = nn.Parameter(torch.randn(1, self.max_sequence_length, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.transformer = Transformer(dim, depth, heads, mlp_dim)

        self.to_cls_token = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, num_classes)
        )

    def forward(self, img, mask=None):
        p = self.patch_size
        # Extract convolutional features, then flatten them into patch tokens.
        x = self.features(img)
        y = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)

        y = self.patch_to_embedding(y)
        cls_tokens = self.cls_token.expand(y.shape[0], -1, -1)
        x = torch.cat((cls_tokens, y), dim=1)

        x += self.pos_embedding[:, :x.size(1)]
        x = self.transformer(x, mask)
        # Classify from the class-token representation.
        x = self.to_cls_token(x[:, 0])

        return self.mlp_head(x)
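
As a quick sanity check, the sketch below (not part of the original commit; it assumes torch and einops are installed and the classes above are importable) instantiates the model with its defaults and runs a dummy batch through it. With the default 224x224 input and patch_size=7, the backbone produces a single 7x7 patch token plus the class token before the transformer.

if __name__ == '__main__':
    # Hypothetical smoke test: two random 224x224 RGB images through the default CViT.
    model = CViT()
    dummy = torch.randn(2, 3, 224, 224)
    logits = model(dummy)
    print(logits.shape)  # expected: torch.Size([2, 2]) -- batch of 2, num_classes=2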