dkoshman committed · e932abd · 1 parent: 57273ba

noam lr scheduler, shared weight between embedding and generator

Files changed:
- data_preprocessing.py +6 -4
- model.py +26 -67
- train.py +7 -22
- utils.py +76 -7
data_preprocessing.py CHANGED

@@ -115,8 +115,10 @@ class RandomizeImageTransform(object):
     """Standardize image and randomly augment"""
 
     def __init__(self, width=IMAGE_WIDTH, height=IMAGE_HEIGHT, random_magnitude=5):
+        assert random_magnitude > 0
+        eps = 0.01
         self.transform = T.Compose((
-            T.ColorJitter(brightness=…
+            T.ColorJitter(brightness=((1 - eps) / (random_magnitude + eps), 1 - eps)),
             T.Resize(height),
             T.Grayscale(),
             T.functional.invert,

@@ -184,12 +186,12 @@ class LatexImageDataModule(pl.LightningDataModule):
         )
         self.val_dataset = TexImageDataset(
             root_dir=DATA_DIR,
-            image_transform=…
+            image_transform=RandomizeImageTransform(),
             tex_transform=ExtractEquationFromTexTransform()
         )
         self.test_dataset = TexImageDataset(
             root_dir=DATA_DIR,
-            image_transform=…
+            image_transform=RandomizeImageTransform(),
             tex_transform=ExtractEquationFromTexTransform()
         )
         train_indices, val_indices, test_indices = self.train_val_test_split(len(self.train_dataset))

@@ -215,7 +217,7 @@ class LatexImageDataModule(pl.LightningDataModule):
                           pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS, persistent_workers=PERSISTENT_WORKERS)
 
     def val_dataloader(self):
-        return DataLoader(self.val_dataset, batch_size=self.batch_size,
+        return DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn,
                           pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS, persistent_workers=PERSISTENT_WORKERS)
 
     def test_dataloader(self):
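Note on the new ColorJitter range: with the committed defaults (eps = 0.01, random_magnitude = 5), the brightness factor is drawn uniformly from roughly (0.198, 0.99), so augmented images are only ever darkened. A quick worked check of that arithmetic (values taken from the diff above):

    # Worked check of the brightness bounds added above.
    eps, random_magnitude = 0.01, 5
    lo = (1 - eps) / (random_magnitude + eps)  # 0.99 / 5.01 ≈ 0.1976
    hi = 1 - eps                               # 0.99
    # T.ColorJitter(brightness=(lo, hi)) multiplies pixel intensity by a factor
    # drawn uniformly from [lo, hi], so a larger random_magnitude means darker images.
    print(round(lo, 4), hi)  # 0.1976 0.99

The val_dataloader change simply adds the collate_fn that the other dataloaders presumably already pass, so validation batches are assembled the same way as training batches.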
model.py CHANGED

@@ -41,7 +41,7 @@ class AddPositionalEncoding(nn.Module):
 class ImageEmbedding(nn.Module):
     """Reshape image into patches and project into given dimension"""
 
-    def __init__(self, d_model, input_width, input_height, patch_size…
+    def __init__(self, d_model, input_width, input_height, patch_size, dropout):
         super().__init__()
         assert input_width % patch_size == 0 and input_height % patch_size == 0, \
             "Cannot split image in patches"

@@ -64,7 +64,7 @@ class ImageEmbedding(nn.Module):
 
 
 class TexEmbedding(nn.Module):
-    def __init__(self, d_model: int, vocab_size: int, dropout: float…
+    def __init__(self, d_model: int, vocab_size: int, dropout: float):
         super().__init__()
         self.embedding = nn.Embedding(vocab_size, d_model)
         self.add_positional_encoding = AddPositionalEncoding(d_model)

@@ -85,7 +85,7 @@ class ImageEncoder(nn.Module):
 
     def __init__(self, image_width, image_height, d_model, num_layers=8):
         super().__init__()
-        image_embedding = ImageEmbedding(d_model, image_width, image_height)
+        image_embedding = ImageEmbedding(d_model, image_width, image_height, patch_size=16, dropout=.1)
         encoder_layer = nn.TransformerEncoderLayer(
             d_model=d_model,
             nhead=8,

@@ -111,7 +111,6 @@ class Transformer(pl.LightningModule):
         pad_idx: int,
         dim_feedforward: int = 512,
         dropout: float = .1,
-        learning_rate: float = 1e-3
     ):
         super().__init__()

@@ -126,11 +125,13 @@ class Transformer(pl.LightningModule):
             if p.dim() > 1:
                 nn.init.xavier_uniform_(p)
 
-        self.…
-        self.src_tok_emb = ImageEmbedding(emb_size, image_width, image_height, dropout=dropout)
+        self.d_model = emb_size
+        self.src_tok_emb = ImageEmbedding(emb_size, image_width, image_height, patch_size=16, dropout=dropout)
         self.tgt_tok_emb = TexEmbedding(emb_size, tgt_vocab_size, dropout=dropout)
-        self.…
-        …
+        self.generator = nn.Linear(emb_size, tgt_vocab_size)
+        # Make embedding and generator share weight because they do the same thing
+        self.tgt_tok_emb.embedding.weight = self.generator.weight
+        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_idx, label_smoothing=.1)
         self.save_hyperparameters()
 
     def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None, src_padding_mask=None,

@@ -147,16 +148,10 @@ class Transformer(pl.LightningModule):
         tgt_input = tgt[:, :-1]
         tgt_output = tgt[:, 1:]
         src_mask = None
-        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_input.shape[1]).to(self.device,
-                                                                                           torch.ByteTensor.dtype)
+        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_input.shape[1]).to(self.device, torch.bool)
         memory_mask = None
         src_padding_mask = None
-        tgt_padding_mask = batch['tex_attention_masks'][:, :-1]
-        tgt_padding_mask = tgt_padding_mask.masked_fill(
-            tgt_padding_mask == 0, float('-inf')
-        ).masked_fill(
-            tgt_padding_mask == 1, 0
-        )
+        tgt_padding_mask = torch.logical_not(batch['tex_attention_masks'][:, :-1])
 
         outs = self(src, tgt_input, src_mask, tgt_mask, memory_mask, src_padding_mask, tgt_padding_mask)
         loss = self.loss_fn(einops.rearrange(outs, 'b n prob -> b prob n'), tgt_output.long())

@@ -179,58 +174,22 @@ class Transformer(pl.LightningModule):
         return loss
 
     def configure_optimizers(self):
-        optimizer = torch.optim.Adam(self.parameters(), lr=…
-        scheduler = torch.optim.lr_scheduler.…
+        optimizer = torch.optim.Adam(self.parameters(), lr=1, betas=(0.9, 0.98), eps=1e-9)
+        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, NoamLRLambda(self.d_model))
         return [optimizer], [scheduler]
 
-    # def configure_optimizers(self):
-    #     optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
-    #     return optimizer
-
-
-class _TransformerTuner(Transformer):
-    """
-    When using trainer.tune, batches from dataloader get passed directly to forward,
-    so this subclass takes care of that
-    """
-
-    def forward(self, batch, batch_idx):
-        src = batch['images']
-        tgt = batch['tex_ids']
-        tgt_input = tgt[:, :-1]
-        tgt_output = tgt[:, 1:]
-        src_mask = None
-        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_input.shape[1]).to(self.device,
-                                                                                           torch.ByteTensor.dtype)
-        memory_mask = None
-        src_padding_mask = None
-        tgt_padding_mask = batch['tex_attention_masks'][:, :-1]
-        tgt_padding_mask = tgt_padding_mask.masked_fill(
-            tgt_padding_mask == 0, float('-inf')
-        ).masked_fill(
-            tgt_padding_mask == 1, 0
-        )
-
-        src = self.src_tok_emb(src)
-        tgt_input = self.tgt_tok_emb(tgt_input)
-        outs = self.transformer(src, tgt_input, src_mask, tgt_mask, memory_mask, src_padding_mask, tgt_padding_mask)
-        outs = self.generator(outs)
-
-    def …
-
-
-@torch.inference_mode()
-def decode(transformer, tex_tokenizer, image):
-    tex_ids = [tex_tokenizer.token_to_id("[CLS]")]
-    while tex_ids[-1] != tex_tokenizer.token_to_id("[SEP]") and len(tex_ids) < 30:
-        src = einops.rearrange(image, "c h w -> () c h w")
-        tgt = torch.tensor([tex_ids], device=transformer.device, dtype=torch.float32)
-        outs = transformer(src, tgt)
-        next_id = outs[:, -1].argmax(dim=1).item()
-        tex_ids.append(next_id)
-    tex = tex_tokenizer.decode(tex_ids, skip_special_tokens=True)
-    return tex
+
+class NoamLRLambda:
+    def __init__(self, d_model, factor=1, warmup=4000):
+        """
+        :param d_model: size of hidden model dimension
+        :param factor: multiplicative factor
+        :param warmup: number of warmup steps
+        """
+        self.d_model = d_model
+        self.factor = factor
+        self.warmup = warmup
+
+    def __call__(self, step):
+        step += 1
+        return self.factor * self.d_model ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5))
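For reference, NoamLRLambda implements the schedule from "Attention Is All You Need": lr(step) = factor * d_model^(-1/2) * min(step^(-1/2), step * warmup^(-3/2)), rising linearly for warmup steps and then decaying as step^(-1/2). Since the Adam base lr is set to 1, the lambda's return value is the effective learning rate. A quick sanity check of the numbers (d_model=512, warmup=4000 assumed):

    # Sanity check of NoamLRLambda values under assumed defaults.
    def noam(step, d_model=512, factor=1, warmup=4000):
        step += 1  # LambdaLR counts steps from 0
        return factor * d_model ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5))

    for s in (0, 999, 3999, 15999):
        print(s + 1, f"{noam(s):.2e}")
    # Peaks at step 4000: 512**-0.5 * 4000**-0.5 ≈ 7.0e-4, then decays as step**-0.5.

One hedge: returning [optimizer], [scheduler] from configure_optimizers makes PyTorch Lightning step the scheduler once per epoch by default; a per-step Noam warmup usually needs the dict form {"scheduler": ..., "interval": "step"}, so this may be worth verifying.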
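The weight sharing works because nn.Embedding(vocab_size, d_model).weight and nn.Linear(d_model, vocab_size).weight both have shape (vocab_size, d_model). A minimal standalone sketch (sizes are illustrative, not from this repo):

    import torch.nn as nn

    # Minimal sketch of the embedding/generator weight tying used in the commit.
    d_model, vocab_size = 512, 300
    embedding = nn.Embedding(vocab_size, d_model)  # token id -> vector; weight: (vocab_size, d_model)
    generator = nn.Linear(d_model, vocab_size)     # vector -> logits;  weight: (vocab_size, d_model)
    embedding.weight = generator.weight            # one Parameter now backs both modules
    assert embedding.weight.data_ptr() == generator.weight.data_ptr()

This halves the number of vocabulary-sized matrices and often acts as a regularizer on small datasets; note the commit assigns the Linear's weight to the embedding, so the Linear's bias stays separate.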
train.py CHANGED

@@ -1,18 +1,17 @@
 from data_generator import generate_data
 from data_preprocessing import LatexImageDataModule, IMAGE_WIDTH, IMAGE_HEIGHT
-from model import Transformer…
+from model import Transformer
 from utils import LogImageTexCallback
 
 import argparse
 from pytorch_lightning.callbacks import LearningRateMonitor
-from pytorch_lightning.loggers import …
+from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning import Trainer, seed_everything
 import torch
 
 DATASET_PATH = "resources/dataset.pt"
 TRAINER_DIR = "resources/pl_trainer_checkpoints"
 TUNER_DIR = "resources/pl_tuner_checkpoints"
-TRAINER_STRATEGY = "ddp"
 BEST_MODEL_CHECKPOINT = "best_model.ckpt"

@@ -46,6 +45,10 @@ def parse_args():
     return args
 
 
+# TODO: update python, maybe model doesnt train bc of ignore special index in CrossEntropyLoss?
+# crop image, adjust brightness, lr warmup?, make tex tokens always decodable,
+# take loss that doesn't punish so much for offsets, take a look at weights,
+
 def main():
     args = parse_args()

@@ -58,8 +61,6 @@ def main():
         torch.save(datamodule, DATASET_PATH)
     else:
         datamodule = torch.load(DATASET_PATH)
-    # TODO: log images, accuracy?, update python, write own transformer, add checkpoints, lr scheduler,
-    # determine when trainer doesnt hang(when single gpu,ddp, num_workers=0)
 
     if args.log:
         logger = WandbLogger(f"img2tex", log_model=True)

@@ -75,7 +76,7 @@ def main():
         accelerator="cpu" if args.gpus is None else "gpu",
         gpus=args.gpus,
         logger=logger,
-        strategy=TRAINER_STRATEGY,
+        strategy="ddp",
         enable_progress_bar=True,
         default_root_dir=TRAINER_DIR,
         callbacks=callbacks,

@@ -91,24 +92,8 @@ def main():
         pad_idx=datamodule.tex_tokenizer.token_to_id("[PAD]"),
         dim_feedforward=512,
         dropout=0.1,
-        learning_rate=1e-3
     )
 
-    # if args.new_dataset:
-    #     datamodule.batch_size = 1
-    #     transformer_for_tuning = TransformerTuner(**transformer.hparams).cuda()
-    #     tuner = Trainer(accelerator="gpu" if args.gpus else "cpu",
-    #                     gpus=args.gpus,
-    #                     strategy=TRAINER_STRATEGY,
-    #                     enable_progress_bar=True,
-    #                     enable_checkpointing=False,
-    #                     auto_scale_batch_size=True,
-    #                     num_sanity_val_steps=0,
-    #                     logger=False
-    #                     )
-    #     tuner.tune(transformer_for_tuning, datamodule=datamodule)
-    #     torch.save(datamodule, DATASET_PATH)
-
     trainer.fit(transformer, datamodule=datamodule)
     trainer.test(datamodule=datamodule)
     trainer.save_checkpoint(BEST_MODEL_CHECKPOINT)
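Regarding the new TODO's question about ignore_index: CrossEntropyLoss(ignore_index=pad_idx) only drops padded target positions from the loss average, it does not silence real tokens, so by itself it should not prevent training. A small illustration (pad_idx and shapes are illustrative, not from this repo):

    import torch

    # ignore_index drops pad positions from the loss average.
    pad_idx = 0
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_idx, label_smoothing=0.1)
    logits = torch.randn(1, 5, 3)              # (batch, vocab, sequence), as after the rearrange in model.py
    targets = torch.tensor([[2, 4, pad_idx]])  # last position is padding
    loss = loss_fn(logits, targets)            # averaged over the two non-pad positions only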
utils.py CHANGED

@@ -1,7 +1,9 @@
-import …
-from pytorch_lightning.callbacks import Callback
-from model import decode
+from model import Transformer
 
+import einops
+import random
+from pytorch_lightning.callbacks import Callback
+import torch
 from torchvision import transforms

@@ -14,8 +16,75 @@ class LogImageTexCallback(Callback):
     def on_validation_batch_start(self, trainer, transformer, batch, batch_idx, dataloader_idx):
         if batch_idx != 0 or dataloader_idx != 0:
             return
-        …
-        …
+        sample_id = random.randint(0, len(batch['images']) - 1)
+        image = batch['images'][sample_id]
+        tex_predicted, tex_ids = decode(transformer, self.tex_tokenizer, image)
         image = self.tensor_to_PIL(image)
-        tex_true = self.tex_tokenizer.decode(list(batch['tex_ids'][…
-        …
+        tex_true = self.tex_tokenizer.decode(list(batch['tex_ids'][sample_id].to('cpu', torch.int)),
+                                             skip_special_tokens=True)
+        self.logger.log_image(key="samples", images=[image],
+                              caption=[f"True: {tex_true}\nPredicted: {tex_predicted}\nIds: {tex_ids}"])
+
+
+# if args.new_dataset:
+#     datamodule.batch_size = 1
+#     transformer_for_tuning = TransformerTuner(**transformer.hparams).cuda()
+#     tuner = Trainer(accelerator="gpu" if args.gpus else "cpu",
+#                     gpus=args.gpus,
+#                     strategy=TRAINER_STRATEGY,
+#                     enable_progress_bar=True,
+#                     enable_checkpointing=False,
+#                     auto_scale_batch_size=True,
+#                     num_sanity_val_steps=0,
+#                     logger=False
+#                     )
+#     tuner.tune(transformer_for_tuning, datamodule=datamodule)
+#     torch.save(datamodule, DATASET_PATH)
+
+
+class _TransformerTuner(Transformer):
+    """
+    When using trainer.tune, batches from dataloader get passed directly to forward,
+    so this subclass takes care of that
+    """
+
+    def forward(self, batch, batch_idx):
+        src = batch['images']
+        tgt = batch['tex_ids']
+        tgt_input = tgt[:, :-1]
+        tgt_output = tgt[:, 1:]
+        src_mask = None
+        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_input.shape[1]).to(self.device,
+                                                                                           torch.ByteTensor.dtype)
+        memory_mask = None
+        src_padding_mask = None
+        tgt_padding_mask = batch['tex_attention_masks'][:, :-1]
+        tgt_padding_mask = tgt_padding_mask.masked_fill(
+            tgt_padding_mask == 0, float('-inf')
+        ).masked_fill(
+            tgt_padding_mask == 1, 0
+        )
+
+        src = self.src_tok_emb(src)
+        tgt_input = self.tgt_tok_emb(tgt_input)
+        outs = self.transformer(src, tgt_input, src_mask, tgt_mask, memory_mask, src_padding_mask, tgt_padding_mask)
+        outs = self.generator(outs)
+
+        loss = self.loss_fn(einops.rearrange(outs, 'b n prob -> b prob n'), tgt_output.long())
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        return self(batch, batch_idx)
+
+
+@torch.inference_mode()
+def decode(transformer, tex_tokenizer, image):
+    tex_ids = [tex_tokenizer.token_to_id("[CLS]")]
+    src = einops.rearrange(image, "c h w -> () c h w")
+    while tex_ids[-1] != tex_tokenizer.token_to_id("[SEP]") and len(tex_ids) < 30:
+        tgt = torch.tensor([tex_ids], device=transformer.device, dtype=torch.float)
+        tgt_mask = transformer.transformer.generate_square_subsequent_mask(tgt.shape[1]).to(transformer.device,
+                                                                                            torch.bool)
+        outs = transformer(src, tgt, src_mask=None, tgt_mask=tgt_mask)
+        outs = einops.rearrange(outs, 'b n prob -> b prob n')
+        next_id = outs[0, :, -1].argmax().item()
+        tex_ids.append(next_id)
+    tex = tex_tokenizer.decode(tex_ids, skip_special_tokens=True)
+    return tex, tex_ids
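A note on the mask convention this commit moves to: with boolean masks, nn.Transformer treats True as "blocked", so tgt_padding_mask = logical_not(attention_mask) marks padding, and the causal mask converted with .to(device, torch.bool) is True strictly above the diagonal. A minimal sketch of both (the 4-token batch is illustrative):

    import torch

    # Boolean mask convention used above: True = position attention must NOT use.
    attention_mask = torch.tensor([[1, 1, 1, 0]])         # 1 = real token, 0 = padding
    tgt_padding_mask = torch.logical_not(attention_mask)  # True only at the pad slot
    # Causal mask equivalent to generate_square_subsequent_mask(n).to(device, torch.bool):
    n = attention_mask.shape[1]
    tgt_mask = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)

With those masks in place, decode performs greedy search: starting from "[CLS]", it re-runs the transformer each step, appends the argmax token, and stops at "[SEP]" or 30 tokens, returning both the decoded string and the raw ids.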