dkoshman committed
Commit c2ef1c6
1 Parent(s): c7f2652

tuned cli, added tuner

Files changed (4)
  1. data_preprocessing.py +12 -12
  2. model.py +5 -2
  3. train.py +45 -22
  4. utils.py +0 -0
data_preprocessing.py CHANGED
@@ -15,8 +15,10 @@ import re
 TEX_VOCAB_SIZE = 300
 IMAGE_WIDTH = 1024
 IMAGE_HEIGHT = 128
-BATCH_SIZE = 16
-NUM_WORKERS = 0
+BATCH_SIZE = 8
+NUM_WORKERS = 4
+PERSISTENT_WORKERS = True  # whether to shut down workers at the end of epoch
+PIN_MEMORY = False  # probably causes cuda oom error if True
 
 
 class TexImageDataset(Dataset):
@@ -170,9 +172,10 @@ def generate_tex_tokenizer(dataset, vocab_size):
 
 
 class LatexImageDataModule(pl.LightningDataModule):
-    def __init__(self):
+    def __init__(self, batch_size=BATCH_SIZE):
         super().__init__()
         torch.manual_seed(0)
+        self.batch_size = batch_size
 
         self.train_dataset = TexImageDataset(
             root_dir=DATA_DIR,
@@ -206,16 +209,13 @@ class LatexImageDataModule(pl.LightningDataModule):
         return indices[:train_split], indices[train_split: val_split], indices[val_split:]
 
     def train_dataloader(self):
-        return DataLoader(self.train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=self.collate_fn,
-                          num_workers=NUM_WORKERS, )
-        # pin_memory=True, persistent_workers=True)
+        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=self.collate_fn,
+                          pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS, persistent_workers=PERSISTENT_WORKERS)
 
     def val_dataloader(self):
-        return DataLoader(self.val_dataset, batch_size=BATCH_SIZE, collate_fn=self.collate_fn,
-                          num_workers=NUM_WORKERS, )
-        # pin_memory=True, persistent_workers=True)
+        return DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn,
+                          pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS, persistent_workers=PERSISTENT_WORKERS)
 
     def test_dataloader(self):
-        return DataLoader(self.test_dataset, batch_size=BATCH_SIZE, collate_fn=self.collate_fn,
-                          num_workers=NUM_WORKERS, )
-        # pin_memory=True, persistent_workers=True)
+        return DataLoader(self.test_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn,
+                          pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS, persistent_workers=PERSISTENT_WORKERS)
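Moving the batch size from a module-level constant onto the datamodule instance is what makes the batch-size tuner in train.py possible: Lightning's `auto_scale_batch_size` looks for a `batch_size` attribute on the model or datamodule and overwrites it during `trainer.tune()`. A minimal usage sketch, assuming the dataset files under `DATA_DIR` are already generated (exact tuner behaviour depends on the installed pytorch_lightning version):

```python
from data_preprocessing import LatexImageDataModule

# Any starting value works; Trainer(auto_scale_batch_size="power") will
# rewrite `datamodule.batch_size` with the largest size that fits in memory.
datamodule = LatexImageDataModule(batch_size=32)
train_loader = datamodule.train_dataloader()  # built from self.batch_size, not the constant
```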
 
model.py CHANGED
@@ -110,8 +110,10 @@ class Transformer(pl.LightningModule):
                  tgt_vocab_size: int,
                  pad_idx: int,
                  dim_feedforward: int = 512,
-                 dropout: float = .1):
+                 dropout: float = .1,
+                 learning_rate=1e-4):
         super().__init__()
+
         self.transformer = nn.Transformer(d_model=emb_size,
                                           nhead=nhead,
                                           num_encoder_layers=num_encoder_layers,
@@ -127,6 +129,7 @@ class Transformer(pl.LightningModule):
         self.src_tok_emb = ImageEmbedding(emb_size, image_width, image_height, dropout=dropout)
         self.tgt_tok_emb = TexEmbedding(emb_size, tgt_vocab_size, dropout=dropout)
         self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_idx)
+        self.learning_rate = learning_rate
 
     def forward(self, src, tgt, src_mask, tgt_mask, memory_mask, src_padding_mask, tgt_padding_mask):
         src = self.src_tok_emb(src)
@@ -174,4 +177,4 @@ class Transformer(pl.LightningModule):
 
     def configure_optimizers(self):
         # TODO write scheduler
-        return torch.optim.Adam(self.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
+        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
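The same pattern applies to the learning rate: `auto_lr_find` only works when the LightningModule exposes a `learning_rate` (or `lr`) attribute that `configure_optimizers` reads back, which is why the hard-coded `lr=0.0001` (and the explicit `betas`/`eps`, now left at Adam's defaults) were dropped. A stripped-down sketch of that pattern on a hypothetical toy module, not the repository's Transformer:

```python
import torch
import torch.nn as nn
import pytorch_lightning as pl


class ToyModule(pl.LightningModule):  # hypothetical stand-in for Transformer
    def __init__(self, learning_rate=1e-4):
        super().__init__()
        self.learning_rate = learning_rate  # attribute the LR finder mutates
        self.layer = nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        # Read the (possibly tuned) attribute instead of a hard-coded value.
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
```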
train.py CHANGED
@@ -3,32 +3,44 @@ from data_preprocessing import LatexImageDataModule, IMAGE_WIDTH, IMAGE_HEIGHT
 from model import Transformer
 
 import argparse
-import pytorch_lightning as pl
 from pytorch_lightning.loggers import WandbLogger
+from pytorch_lightning import Trainer, seed_everything
 import torch
 
 DATASET_PATH = 'resources/dataset.pt'
 
 
-def main():
-    parser = argparse.ArgumentParser("Trainer", usage="run trainer")
+def parse_args():
+    parser = argparse.ArgumentParser()
     parser.add_argument(
-        "epochs", help="number of epochs to train", type=int
+        "-m", "-max-epochs", help="limit the number of training epochs", type=int, dest='max_epochs'
     )
     parser.add_argument(
         "-n", "-new-dataset", help="clear old dataset and generate provided number of new examples", type=int,
         dest="new_dataset"
     )
     parser.add_argument(
-        "-g", "-gpus", metavar="GPUS", help=f"list of gpu ids to train on in range 0..{torch.cuda.device_count()}",
-        type=int, nargs='+', dest="gpus", choices=list(range(torch.cuda.device_count())),
+        "-g", "-gpus", help=f"number of gpus to train on in range 0..{torch.cuda.device_count()}",
+        type=int, dest="gpus", choices=list(range(torch.cuda.device_count())),
     )
     parser.add_argument(
         "-l", "-log", help="whether to save logs of run to w&b logger, default False", default=False,
         action="store_true", dest="log"
     )
+    parser.add_argument(
+        "-d", "-deterministic", help="whether to seed all rngs for reproducibility, default False", default=False,
+        action="store_true", dest="deterministic"
+    )
 
     args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+
+    if args.deterministic:
+        seed_everything(42, workers=True)
 
     if args.new_dataset is not None:
         generate_data(args.new_dataset)
@@ -39,25 +51,36 @@ def main():
 
     # TODO: log images, accuracy?, update python, write own transformer, add checkpoints, lr scheduler,
     # determine when trainer doesnt hang(when single gpu,ddp, num_workers=0)
-    logger = WandbLogger(f"img2tex_epochs{args.epochs}_size{len(datamodule)}_gpus{args.gpus}_v0") if args.log else None
-    trainer = pl.Trainer(max_epochs=args.epochs, accelerator='gpu' if args.gpus else 'cpu', gpus=args.gpus,
-                         logger=logger, strategy='ddp_spawn')
-    transformer = Transformer(
-        num_encoder_layers=3,
-        num_decoder_layers=3,
-        emb_size=512,
-        nhead=8,
-        image_width=IMAGE_WIDTH,
-        image_height=IMAGE_HEIGHT,
-        tgt_vocab_size=datamodule.tex_tokenizer.get_vocab_size(),
-        pad_idx=datamodule.tex_tokenizer.token_to_id("[PAD]"),
-        dim_feedforward=512,
-        dropout=0.1
-    )
 
+    logger = WandbLogger(f"img2tex", version='0') if args.log else False
+
+    trainer = Trainer(max_epochs=args.max_epochs,
+                      accelerator='gpu' if args.gpus else 'cpu',
+                      gpus=args.gpus,
+                      logger=logger,
+                      strategy='ddp',
+                      auto_scale_batch_size="power",
+                      auto_lr_find=True,
+                      auto_select_gpus=True,
+                      enable_progress_bar=True
+                      )
+
+    transformer = Transformer(num_encoder_layers=3,
+                              num_decoder_layers=3,
+                              emb_size=512,
+                              nhead=8,
+                              image_width=IMAGE_WIDTH,
+                              image_height=IMAGE_HEIGHT,
+                              tgt_vocab_size=datamodule.tex_tokenizer.get_vocab_size(),
+                              pad_idx=datamodule.tex_tokenizer.token_to_id("[PAD]"),
+                              dim_feedforward=512,
+                              dropout=0.1
+                              )
+
+    trainer.tune(transformer, datamodule=datamodule)
     trainer.fit(transformer, datamodule=datamodule)
-    # trainer.validate(datamodule=datamodule)
     trainer.test(datamodule=datamodule)
+    trainer.save_checkpoint("best_model.ckpt")
 
 
 if __name__ == '__main__':
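With `auto_scale_batch_size="power"` and `auto_lr_find=True` set on the Trainer, the new `trainer.tune(transformer, datamodule=datamodule)` call is what actually runs the batch-size and learning-rate finders before `fit`. The CLI also switched from a positional `epochs` argument to single-dash long options; a small stand-alone sketch of how those options parse (a hypothetical stand-in for the parser in this file, with the GPU `choices` check omitted):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-m", "-max-epochs", type=int, dest="max_epochs")
parser.add_argument("-n", "-new-dataset", type=int, dest="new_dataset")
parser.add_argument("-g", "-gpus", type=int, dest="gpus")
parser.add_argument("-l", "-log", default=False, action="store_true", dest="log")
parser.add_argument("-d", "-deterministic", default=False, action="store_true", dest="deterministic")

# e.g. `python train.py -m 20 -g 2 -l -d`
args = parser.parse_args(["-m", "20", "-g", "2", "-l", "-d"])
assert args.max_epochs == 20 and args.gpus == 2 and args.log and args.deterministic
assert args.new_dataset is None  # dataset regeneration only happens when -n is given
```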
utils.py ADDED
File without changes