dkoshman committed
Commit e932abd · Parent: 57273ba

Noam LR scheduler, shared weights between embedding and generator

Files changed (4):
  1. data_preprocessing.py +6 -4
  2. model.py +26 -67
  3. train.py +7 -22
  4. utils.py +76 -7
data_preprocessing.py CHANGED
@@ -115,8 +115,10 @@ class RandomizeImageTransform(object):
     """Standardize image and randomly augment"""
 
     def __init__(self, width=IMAGE_WIDTH, height=IMAGE_HEIGHT, random_magnitude=5):
+        assert random_magnitude > 0
+        eps = 0.01
         self.transform = T.Compose((
-            T.ColorJitter(brightness=random_magnitude / 10),
+            T.ColorJitter(brightness=((1 - eps) / (random_magnitude + eps), 1 - eps)),
             T.Resize(height),
             T.Grayscale(),
             T.functional.invert,
@@ -184,12 +186,12 @@ class LatexImageDataModule(pl.LightningDataModule):
         )
         self.val_dataset = TexImageDataset(
             root_dir=DATA_DIR,
-            image_transform=StandardizeImageTransform(),
+            image_transform=RandomizeImageTransform(),
             tex_transform=ExtractEquationFromTexTransform()
         )
         self.test_dataset = TexImageDataset(
             root_dir=DATA_DIR,
-            image_transform=StandardizeImageTransform(),
+            image_transform=RandomizeImageTransform(),
             tex_transform=ExtractEquationFromTexTransform()
         )
         train_indices, val_indices, test_indices = self.train_val_test_split(len(self.train_dataset))
@@ -215,7 +217,7 @@
                           pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS, persistent_workers=PERSISTENT_WORKERS)
 
     def val_dataloader(self):
-        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=self.collate_fn,
+        return DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn,
                           pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS, persistent_workers=PERSISTENT_WORKERS)
 
     def test_dataloader(self):
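
Note on the ColorJitter change: passing a (min, max) tuple makes torchvision sample the brightness factor uniformly from that range, and since both bounds are below 1 the augmentation can now only darken the image, unlike the old single-magnitude form that could also brighten it. A minimal standalone sketch of that behaviour (editor's illustration, assuming a recent torchvision; the random tensor is a placeholder, not the project's data):

import torch
from torchvision import transforms as T

random_magnitude, eps = 5, 0.01
# ColorJitter(brightness=(low, high)) samples a factor uniformly from [low, high].
# Here low = (1 - eps) / (random_magnitude + eps) ≈ 0.198 and high = 1 - eps = 0.99,
# so the factor is always < 1, i.e. darkening only.
jitter = T.ColorJitter(brightness=((1 - eps) / (random_magnitude + eps), 1 - eps))

image = torch.rand(3, 64, 256)   # placeholder image tensor (C, H, W)
darkened = jitter(image)
assert darkened.max() <= image.max() + 1e-6
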
model.py CHANGED
@@ -41,7 +41,7 @@ class AddPositionalEncoding(nn.Module):
 class ImageEmbedding(nn.Module):
     """Reshape image into patches and project into given dimension"""
 
-    def __init__(self, d_model, input_width, input_height, patch_size=16, dropout=.1):
+    def __init__(self, d_model, input_width, input_height, patch_size, dropout):
         super().__init__()
         assert input_width % patch_size == 0 and input_height % patch_size == 0, \
             "Cannot split image in patches"
@@ -64,7 +64,7 @@ class ImageEmbedding(nn.Module):
 
 
 class TexEmbedding(nn.Module):
-    def __init__(self, d_model: int, vocab_size: int, dropout: float = .1):
+    def __init__(self, d_model: int, vocab_size: int, dropout: float):
         super().__init__()
         self.embedding = nn.Embedding(vocab_size, d_model)
         self.add_positional_encoding = AddPositionalEncoding(d_model)
@@ -85,7 +85,7 @@ class ImageEncoder(nn.Module):
 
     def __init__(self, image_width, image_height, d_model, num_layers=8):
         super().__init__()
-        image_embedding = ImageEmbedding(d_model, image_width, image_height)
+        image_embedding = ImageEmbedding(d_model, image_width, image_height, patch_size=16, dropout=.1)
         encoder_layer = nn.TransformerEncoderLayer(
             d_model=d_model,
             nhead=8,
@@ -111,7 +111,6 @@ class Transformer(pl.LightningModule):
             pad_idx: int,
             dim_feedforward: int = 512,
             dropout: float = .1,
-            learning_rate: float = 1e-3
     ):
         super().__init__()
 
@@ -126,11 +125,13 @@
             if p.dim() > 1:
                 nn.init.xavier_uniform_(p)
 
-        self.generator = nn.Linear(emb_size, tgt_vocab_size)
-        self.src_tok_emb = ImageEmbedding(emb_size, image_width, image_height, dropout=dropout)
+        self.d_model = emb_size
+        self.src_tok_emb = ImageEmbedding(emb_size, image_width, image_height, patch_size=16, dropout=dropout)
         self.tgt_tok_emb = TexEmbedding(emb_size, tgt_vocab_size, dropout=dropout)
-        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_idx)
-        self.learning_rate = learning_rate
+        self.generator = nn.Linear(emb_size, tgt_vocab_size)
+        # Make embedding and generator share weight because they do the same thing
+        self.tgt_tok_emb.embedding.weight = self.generator.weight
+        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_idx, label_smoothing=.1)
         self.save_hyperparameters()
 
     def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None, src_padding_mask=None,
@@ -147,16 +148,10 @@
         tgt_input = tgt[:, :-1]
         tgt_output = tgt[:, 1:]
         src_mask = None
-        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_input.shape[1]).to(self.device,
-                                                                                           torch.ByteTensor.dtype)
+        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_input.shape[1]).to(self.device, torch.bool)
         memory_mask = None
         src_padding_mask = None
-        tgt_padding_mask = batch['tex_attention_masks'][:, :-1]
-        tgt_padding_mask = tgt_padding_mask.masked_fill(
-            tgt_padding_mask == 0, float('-inf')
-        ).masked_fill(
-            tgt_padding_mask == 1, 0
-        )
+        tgt_padding_mask = torch.logical_not(batch['tex_attention_masks'][:, :-1])
 
         outs = self(src, tgt_input, src_mask, tgt_mask, memory_mask, src_padding_mask, tgt_padding_mask)
         loss = self.loss_fn(einops.rearrange(outs, 'b n prob -> b prob n'), tgt_output.long())
@@ -179,58 +174,22 @@
         return loss
 
     def configure_optimizers(self):
-        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
-        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=1)
+        optimizer = torch.optim.Adam(self.parameters(), lr=1, betas=(0.9, 0.98), eps=1e-9)
+        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, NoamLRLambda(self.d_model))
         return [optimizer], [scheduler]
 
-    # def configure_optimizers(self):
-    #     optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
-    #     return optimizer
-
-
-class _TransformerTuner(Transformer):
-    """
-    When using trainer.tune, batches from dataloader get passed directly to forward,
-    so this subclass takes care of that
-    """
-
-    def forward(self, batch, batch_idx):
-        src = batch['images']
-        tgt = batch['tex_ids']
-        tgt_input = tgt[:, :-1]
-        tgt_output = tgt[:, 1:]
-        src_mask = None
-        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_input.shape[1]).to(self.device,
-                                                                                           torch.ByteTensor.dtype)
-        memory_mask = None
-        src_padding_mask = None
-        tgt_padding_mask = batch['tex_attention_masks'][:, :-1]
-        tgt_padding_mask = tgt_padding_mask.masked_fill(
-            tgt_padding_mask == 0, float('-inf')
-        ).masked_fill(
-            tgt_padding_mask == 1, 0
-        )
-
-        src = self.src_tok_emb(src)
-        tgt_input = self.tgt_tok_emb(tgt_input)
-        outs = self.transformer(src, tgt_input, src_mask, tgt_mask, memory_mask, src_padding_mask, tgt_padding_mask)
-        outs = self.generator(outs)
-
-        loss = self.loss_fn(einops.rearrange(outs, 'b n prob -> b prob n'), tgt_output.long())
-        return loss
-
-    def validation_step(self, batch, batch_idx):
-        return self(batch, batch_idx)
-
-
-@torch.inference_mode()
-def decode(transformer, tex_tokenizer, image):
-    tex_ids = [tex_tokenizer.token_to_id("[CLS]")]
-    while tex_ids[-1] != tex_tokenizer.token_to_id("[SEP]") and len(tex_ids) < 30:
-        src = einops.rearrange(image, "c h w -> () c h w")
-        tgt = torch.tensor([tex_ids], device=transformer.device, dtype=torch.float32)
-        outs = transformer(src, tgt)
-        next_id = outs[:, -1].argmax(dim=1).item()
-        tex_ids.append(next_id)
-    tex = tex_tokenizer.decode(tex_ids, skip_special_tokens=True)
-    return tex
+
+class NoamLRLambda:
+    def __init__(self, d_model, factor=1, warmup=4000):
+        """
+        :param d_model: size of hidden model dimension
+        :param factor: multiplicative factor
+        :param warmup: number of warmup steps
+        """
+        self.d_model = d_model
+        self.factor = factor
+        self.warmup = warmup
+
+    def __call__(self, step):
+        step += 1
+        return self.factor * self.d_model ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5))
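
Two of the changes above are easy to miss in the raw diff. First, the weight sharing: nn.Embedding(vocab, d_model).weight and nn.Linear(d_model, vocab).weight both have shape (vocab, d_model), so assigning one to the other leaves a single shared parameter updated by gradients from both the input embedding and the output projection. A small self-contained sketch (editor's illustration; the sizes are arbitrary, not the project's hyperparameters):

import torch
import torch.nn as nn

vocab_size, d_model = 1000, 512
tok_emb = nn.Embedding(vocab_size, d_model)    # weight: (vocab_size, d_model)
generator = nn.Linear(d_model, vocab_size)     # weight: (vocab_size, d_model)

tok_emb.weight = generator.weight              # one shared parameter tensor
assert tok_emb.weight.data_ptr() == generator.weight.data_ptr()

tokens = torch.randint(0, vocab_size, (2, 7))
logits = generator(tok_emb(tokens))            # (2, 7, vocab_size)
logits.sum().backward()                        # gradients from both uses accumulate in the shared tensor

Second, configure_optimizers now sets a base lr of 1 and delegates the actual rate to NoamLRLambda through LambdaLR, which multiplies the base lr by the lambda's return value each step. That reproduces the "Attention Is All You Need" schedule factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5): linear warmup for `warmup` steps, then inverse-square-root decay. A rough sketch of the resulting curve, restating the same formula (values are approximate):

import torch
import torch.nn as nn

d_model, warmup = 512, 4000

def noam_lambda(step):
    # Same formula as NoamLRLambda.__call__ above; LambdaLR passes its internal counter, starting at 0.
    step += 1
    return d_model ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5))

model = nn.Linear(d_model, d_model)            # placeholder module
optimizer = torch.optim.Adam(model.parameters(), lr=1, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, noam_lambda)

print(noam_lambda(0))        # ≈ 1.7e-7, start of linear warmup
print(noam_lambda(3999))     # ≈ 7.0e-4, peak at step == warmup
print(noam_lambda(99999))    # ≈ 1.4e-4, inverse-sqrt decay afterwards
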
 
 
 
 
 
 
 
 
 
 
 
 
train.py CHANGED
@@ -1,18 +1,17 @@
 from data_generator import generate_data
 from data_preprocessing import LatexImageDataModule, IMAGE_WIDTH, IMAGE_HEIGHT
-from model import Transformer, _TransformerTuner
+from model import Transformer
 from utils import LogImageTexCallback
 
 import argparse
 from pytorch_lightning.callbacks import LearningRateMonitor
-from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
+from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning import Trainer, seed_everything
 import torch
 
 DATASET_PATH = "resources/dataset.pt"
 TRAINER_DIR = "resources/pl_trainer_checkpoints"
 TUNER_DIR = "resources/pl_tuner_checkpoints"
-TRAINER_STRATEGY = "ddp"
 BEST_MODEL_CHECKPOINT = "best_model.ckpt"
 
 
@@ -46,6 +45,10 @@ def parse_args():
     return args
 
 
+# TODO: update python; maybe the model doesn't train because of the ignored special index in CrossEntropyLoss?
+#  crop image, adjust brightness, lr warmup?, make tex tokens always decodable,
+#  use a loss that doesn't punish offsets so much, take a look at the weights
+
 def main():
     args = parse_args()
 
@@ -58,8 +61,6 @@ def main():
         torch.save(datamodule, DATASET_PATH)
     else:
         datamodule = torch.load(DATASET_PATH)
-    # TODO: log images, accuracy?, update python, write own transformer, add checkpoints, lr scheduler,
-    #  determine when trainer doesnt hang(when single gpu,ddp, num_workers=0)
 
     if args.log:
         logger = WandbLogger(f"img2tex", log_model=True)
@@ -75,7 +76,7 @@ def main():
         accelerator="cpu" if args.gpus is None else "gpu",
         gpus=args.gpus,
         logger=logger,
-        strategy=TRAINER_STRATEGY,
+        strategy="ddp",
         enable_progress_bar=True,
         default_root_dir=TRAINER_DIR,
         callbacks=callbacks,
@@ -91,24 +92,8 @@ def main():
         pad_idx=datamodule.tex_tokenizer.token_to_id("[PAD]"),
         dim_feedforward=512,
         dropout=0.1,
-        learning_rate=1e-3
     )
 
-    # if args.new_dataset:
-    #     datamodule.batch_size = 1
-    #     transformer_for_tuning = TransformerTuner(**transformer.hparams).cuda()
-    #     tuner = Trainer(accelerator="gpu" if args.gpus else "cpu",
-    #                     gpus=args.gpus,
-    #                     strategy=TRAINER_STRATEGY,
-    #                     enable_progress_bar=True,
-    #                     enable_checkpointing=False,
-    #                     auto_scale_batch_size=True,
-    #                     num_sanity_val_steps=0,
-    #                     logger=False
-    #                     )
-    #     tuner.tune(transformer_for_tuning, datamodule=datamodule)
-    #     torch.save(datamodule, DATASET_PATH)
-
     trainer.fit(transformer, datamodule=datamodule)
     trainer.test(datamodule=datamodule)
     trainer.save_checkpoint(BEST_MODEL_CHECKPOINT)
utils.py CHANGED
@@ -1,7 +1,9 @@
-import torch
-from pytorch_lightning.callbacks import Callback
-from model import decode
+from model import Transformer
 
+import einops
+import random
+from pytorch_lightning.callbacks import Callback
+import torch
 from torchvision import transforms
 
 
@@ -14,8 +16,75 @@ class LogImageTexCallback(Callback):
     def on_validation_batch_start(self, trainer, transformer, batch, batch_idx, dataloader_idx):
         if batch_idx != 0 or dataloader_idx != 0:
             return
-        image = batch['images'][0]
-        tex_predicted = decode(transformer, self.tex_tokenizer, image)
+        sample_id = random.randint(0, len(batch['images']) - 1)
+        image = batch['images'][sample_id]
+        tex_predicted, tex_ids = decode(transformer, self.tex_tokenizer, image)
         image = self.tensor_to_PIL(image)
-        tex_true = self.tex_tokenizer.decode(list(batch['tex_ids'][0].to('cpu', torch.int)), skip_special_tokens=True)
-        self.logger.log_image(key="samples", images=[image], caption=[f"True: {tex_true}\n Predicted: {tex_predicted}"])
+        tex_true = self.tex_tokenizer.decode(list(batch['tex_ids'][sample_id].to('cpu', torch.int)),
+                                             skip_special_tokens=True)
+        self.logger.log_image(key="samples", images=[image],
+                              caption=[f"True: {tex_true}\nPredicted: {tex_predicted}\nIds: {tex_ids}"])
+
+# if args.new_dataset:
+#     datamodule.batch_size = 1
+#     transformer_for_tuning = TransformerTuner(**transformer.hparams).cuda()
+#     tuner = Trainer(accelerator="gpu" if args.gpus else "cpu",
+#                     gpus=args.gpus,
+#                     strategy=TRAINER_STRATEGY,
+#                     enable_progress_bar=True,
+#                     enable_checkpointing=False,
+#                     auto_scale_batch_size=True,
+#                     num_sanity_val_steps=0,
+#                     logger=False
+#                     )
+#     tuner.tune(transformer_for_tuning, datamodule=datamodule)
+#     torch.save(datamodule, DATASET_PATH)
+
+class _TransformerTuner(Transformer):
+    """
+    When using trainer.tune, batches from dataloader get passed directly to forward,
+    so this subclass takes care of that
+    """
+
+    def forward(self, batch, batch_idx):
+        src = batch['images']
+        tgt = batch['tex_ids']
+        tgt_input = tgt[:, :-1]
+        tgt_output = tgt[:, 1:]
+        src_mask = None
+        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_input.shape[1]).to(self.device,
+                                                                                           torch.ByteTensor.dtype)
+        memory_mask = None
+        src_padding_mask = None
+        tgt_padding_mask = batch['tex_attention_masks'][:, :-1]
+        tgt_padding_mask = tgt_padding_mask.masked_fill(
+            tgt_padding_mask == 0, float('-inf')
+        ).masked_fill(
+            tgt_padding_mask == 1, 0
+        )
+
+        src = self.src_tok_emb(src)
+        tgt_input = self.tgt_tok_emb(tgt_input)
+        outs = self.transformer(src, tgt_input, src_mask, tgt_mask, memory_mask, src_padding_mask, tgt_padding_mask)
+        outs = self.generator(outs)
+
+        loss = self.loss_fn(einops.rearrange(outs, 'b n prob -> b prob n'), tgt_output.long())
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        return self(batch, batch_idx)
+
+
+@torch.inference_mode()
+def decode(transformer, tex_tokenizer, image):
+    tex_ids = [tex_tokenizer.token_to_id("[CLS]")]
+    src = einops.rearrange(image, "c h w -> () c h w")
+    while tex_ids[-1] != tex_tokenizer.token_to_id("[SEP]") and len(tex_ids) < 30:
+        tgt = torch.tensor([tex_ids], device=transformer.device, dtype=torch.float)
+        tgt_mask = transformer.transformer.generate_square_subsequent_mask(tgt.shape[1]).to(transformer.device,
+                                                                                            torch.bool)
+        outs = transformer(src, tgt, src_mask=None, tgt_mask=tgt_mask)
+        outs = einops.rearrange(outs, 'b n prob -> b prob n')
+        next_id = outs[0, :, -1].argmax().item()
+        tex_ids.append(next_id)
+    tex = tex_tokenizer.decode(tex_ids, skip_special_tokens=True)
+    return tex, tex_ids
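
For reference, a hypothetical way to call the new decode helper outside the callback, assuming training has already produced best_model.ckpt and resources/dataset.pt (the paths come from train.py; the random tensor is a placeholder image, and the single-channel shape assumes the Grayscale transform):

import torch
from data_preprocessing import IMAGE_WIDTH, IMAGE_HEIGHT
from model import Transformer
from utils import decode

datamodule = torch.load("resources/dataset.pt")                    # saved by train.py
transformer = Transformer.load_from_checkpoint("best_model.ckpt")  # hyperparameters restored via save_hyperparameters()
transformer.eval()

image = torch.rand(1, IMAGE_HEIGHT, IMAGE_WIDTH)                   # placeholder (C, H, W) grayscale image
tex, tex_ids = decode(transformer, datamodule.tex_tokenizer, image)
print(tex)       # decoded LaTeX with special tokens stripped
print(tex_ids)   # raw token ids, starting with [CLS]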