Create model_architecture.py
model_architecture.py
ADDED
@@ -0,0 +1,183 @@
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange

class Residual(nn.Module):
    """Skip connection: adds the wrapped module's input back to its output."""
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

class PreNorm(nn.Module):
    """Applies LayerNorm before the wrapped module (pre-norm transformer style)."""
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

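# Composition note (illustrative, not part of the committed file): the
# Transformer below builds every sub-block as Residual(PreNorm(dim, fn)),
# which computes fn(LayerNorm(x)) + x, i.e. pre-norm residual blocks.
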
class FeedForward(nn.Module):
    """Two-layer MLP with GELU, the transformer's position-wise feed-forward block."""
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim)
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads=8):
        super().__init__()
        self.heads = heads
        # NOTE: scales logits by the full embedding dim, not the per-head dim
        self.scale = dim ** -0.5

        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x, mask=None):
        b, n, _, h = *x.shape, self.heads
        # project to q, k, v and split out the heads
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=h)

        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale

        if mask is not None:
            # pad with True for the cls token prepended in CViT.forward
            mask = F.pad(mask.flatten(1), (1, 0), value=True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            # outer product: a position may attend only to other unmasked positions
            mask = mask[:, None, :] * mask[:, :, None]
            dots.masked_fill_(~mask, float('-inf'))
            del mask

        attn = dots.softmax(dim=-1)

        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out

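# Shape sketch for Attention (illustrative, not part of the committed file):
#   x:    (b, n, dim)           e.g. (2, 50, 1024)
#   qkv:  (b, n, 3 * dim)  ->  q, k, v each (b, heads, n, dim // heads)
#   dots: (b, heads, n, n)      attention logits
#   out:  (b, n, dim)           same shape as the input
# An optional boolean `mask` carries one entry per patch token; the cls
# position is padded in as True before the mask is applied.
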
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, mlp_dim):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads=heads))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_dim)))
            ]))

    def forward(self, x, mask=None):
        for attn, ff in self.layers:
            x = attn(x, mask=mask)
            x = ff(x)
        return x

class CViT(nn.Module):
    def __init__(self, image_size=224, patch_size=7, num_classes=2, channels=512,
                 dim=1024, depth=6, heads=8, mlp_dim=2048):
        super().__init__()
        assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'

        # VGG-style convolutional feature extractor: five blocks of 3x3 convs
        # with BatchNorm and ReLU, each followed by 2x2 max pooling
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=512),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

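        # Spatial arithmetic (illustrative, not part of the committed file):
        # five 2x2 pools shrink 224 -> 112 -> 56 -> 28 -> 14 -> 7, so
        # self.features maps (b, 3, 224, 224) to (b, 512, 7, 7); with the
        # default patch_size=7 the feature map forms a single 7x7 patch token.
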
        num_patches = (image_size // patch_size) ** 2
        self.max_sequence_length = num_patches + 1
        patch_dim = channels * patch_size ** 2

        self.patch_size = patch_size

        # learned position embeddings; forward() slices this to the actual
        # sequence length (cls token + patch tokens)
        self.pos_embedding = nn.Parameter(torch.randn(1, self.max_sequence_length, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.transformer = Transformer(dim, depth, heads, mlp_dim)

        self.to_cls_token = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, num_classes)
        )

    def forward(self, img, mask=None):
        p = self.patch_size
        x = self.features(img)  # (b, 512, H/32, W/32)
        # flatten non-overlapping p x p patches of the feature map into tokens
        y = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)

        y = self.patch_to_embedding(y)
        cls_tokens = self.cls_token.expand(y.shape[0], -1, -1)
        x = torch.cat((cls_tokens, y), dim=1)

        x += self.pos_embedding[:, :x.size(1)]
        x = self.transformer(x, mask)
        x = self.to_cls_token(x[:, 0])  # classify from the cls token

        return self.mlp_head(x)
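
A minimal smoke test of the file above (my sketch, not part of this commit; it assumes torch and einops are installed and that the file is importable as model_architecture; the names model, img, and logits are hypothetical):

# smoke_test.py (hypothetical file name)
import torch
from model_architecture import CViT

model = CViT()                      # defaults: 224px input, 2 classes
img = torch.randn(2, 3, 224, 224)   # dummy batch of two RGB images
logits = model(img)
print(logits.shape)                 # torch.Size([2, 2]), one logit per class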