Spaces:
Runtime error
Runtime error
dkoshman
commited on
Commit
·
c2ef1c6
1
Parent(s):
c7f2652
tuned cli, added tuner
Browse files- data_preprocessing.py +12 -12
- model.py +5 -2
- train.py +45 -22
- utils.py +0 -0
data_preprocessing.py
CHANGED
@@ -15,8 +15,10 @@ import re
|
|
15 |
TEX_VOCAB_SIZE = 300
|
16 |
IMAGE_WIDTH = 1024
|
17 |
IMAGE_HEIGHT = 128
|
18 |
-
BATCH_SIZE =
|
19 |
-
NUM_WORKERS =
|
|
|
|
|
20 |
|
21 |
|
22 |
class TexImageDataset(Dataset):
|
@@ -170,9 +172,10 @@ def generate_tex_tokenizer(dataset, vocab_size):
|
|
170 |
|
171 |
|
172 |
class LatexImageDataModule(pl.LightningDataModule):
|
173 |
-
def __init__(self):
|
174 |
super().__init__()
|
175 |
torch.manual_seed(0)
|
|
|
176 |
|
177 |
self.train_dataset = TexImageDataset(
|
178 |
root_dir=DATA_DIR,
|
@@ -206,16 +209,13 @@ class LatexImageDataModule(pl.LightningDataModule):
|
|
206 |
return indices[:train_split], indices[train_split: val_split], indices[val_split:]
|
207 |
|
208 |
def train_dataloader(self):
|
209 |
-
return DataLoader(self.train_dataset, batch_size=
|
210 |
-
num_workers=NUM_WORKERS, )
|
211 |
-
# pin_memory=True, persistent_workers=True)
|
212 |
|
213 |
def val_dataloader(self):
|
214 |
-
return DataLoader(self.val_dataset, batch_size=
|
215 |
-
num_workers=NUM_WORKERS, )
|
216 |
-
# pin_memory=True, persistent_workers=True)
|
217 |
|
218 |
def test_dataloader(self):
|
219 |
-
return DataLoader(self.test_dataset, batch_size=
|
220 |
-
num_workers=NUM_WORKERS, )
|
221 |
-
# pin_memory=True, persistent_workers=True)
|
|
|
15 |
TEX_VOCAB_SIZE = 300
|
16 |
IMAGE_WIDTH = 1024
|
17 |
IMAGE_HEIGHT = 128
|
18 |
+
BATCH_SIZE = 8
|
19 |
+
NUM_WORKERS = 4
|
20 |
+
PERSISTENT_WORKERS = True # whether to shut down workers at the end of epoch
|
21 |
+
PIN_MEMORY = False # probably causes cuda oom error if True
|
22 |
|
23 |
|
24 |
class TexImageDataset(Dataset):
|
|
|
172 |
|
173 |
|
174 |
class LatexImageDataModule(pl.LightningDataModule):
|
175 |
+
def __init__(self, batch_size=BATCH_SIZE):
|
176 |
super().__init__()
|
177 |
torch.manual_seed(0)
|
178 |
+
self.batch_size = batch_size
|
179 |
|
180 |
self.train_dataset = TexImageDataset(
|
181 |
root_dir=DATA_DIR,
|
|
|
209 |
return indices[:train_split], indices[train_split: val_split], indices[val_split:]
|
210 |
|
211 |
def train_dataloader(self):
|
212 |
+
return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=self.collate_fn,
|
213 |
+
pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS, persistent_workers=PERSISTENT_WORKERS)
|
|
|
214 |
|
215 |
def val_dataloader(self):
|
216 |
+
return DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn,
|
217 |
+
pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS, persistent_workers=PERSISTENT_WORKERS)
|
|
|
218 |
|
219 |
def test_dataloader(self):
|
220 |
+
return DataLoader(self.test_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn,
|
221 |
+
pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS, persistent_workers=PERSISTENT_WORKERS)
|
|
model.py
CHANGED
@@ -110,8 +110,10 @@ class Transformer(pl.LightningModule):
|
|
110 |
tgt_vocab_size: int,
|
111 |
pad_idx: int,
|
112 |
dim_feedforward: int = 512,
|
113 |
-
dropout: float = .1
|
|
|
114 |
super().__init__()
|
|
|
115 |
self.transformer = nn.Transformer(d_model=emb_size,
|
116 |
nhead=nhead,
|
117 |
num_encoder_layers=num_encoder_layers,
|
@@ -127,6 +129,7 @@ class Transformer(pl.LightningModule):
|
|
127 |
self.src_tok_emb = ImageEmbedding(emb_size, image_width, image_height, dropout=dropout)
|
128 |
self.tgt_tok_emb = TexEmbedding(emb_size, tgt_vocab_size, dropout=dropout)
|
129 |
self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_idx)
|
|
|
130 |
|
131 |
def forward(self, src, tgt, src_mask, tgt_mask, memory_mask, src_padding_mask, tgt_padding_mask):
|
132 |
src = self.src_tok_emb(src)
|
@@ -174,4 +177,4 @@ class Transformer(pl.LightningModule):
|
|
174 |
|
175 |
def configure_optimizers(self):
|
176 |
# TODO write scheduler
|
177 |
-
return torch.optim.Adam(self.parameters(), lr=
|
|
|
110 |
tgt_vocab_size: int,
|
111 |
pad_idx: int,
|
112 |
dim_feedforward: int = 512,
|
113 |
+
dropout: float = .1,
|
114 |
+
learning_rate=1e-4):
|
115 |
super().__init__()
|
116 |
+
|
117 |
self.transformer = nn.Transformer(d_model=emb_size,
|
118 |
nhead=nhead,
|
119 |
num_encoder_layers=num_encoder_layers,
|
|
|
129 |
self.src_tok_emb = ImageEmbedding(emb_size, image_width, image_height, dropout=dropout)
|
130 |
self.tgt_tok_emb = TexEmbedding(emb_size, tgt_vocab_size, dropout=dropout)
|
131 |
self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_idx)
|
132 |
+
self.learning_rate = learning_rate
|
133 |
|
134 |
def forward(self, src, tgt, src_mask, tgt_mask, memory_mask, src_padding_mask, tgt_padding_mask):
|
135 |
src = self.src_tok_emb(src)
|
|
|
177 |
|
178 |
def configure_optimizers(self):
|
179 |
# TODO write scheduler
|
180 |
+
return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
|
train.py
CHANGED
@@ -3,32 +3,44 @@ from data_preprocessing import LatexImageDataModule, IMAGE_WIDTH, IMAGE_HEIGHT
|
|
3 |
from model import Transformer
|
4 |
|
5 |
import argparse
|
6 |
-
import pytorch_lightning as pl
|
7 |
from pytorch_lightning.loggers import WandbLogger
|
|
|
8 |
import torch
|
9 |
|
10 |
DATASET_PATH = 'resources/dataset.pt'
|
11 |
|
12 |
|
13 |
-
def
|
14 |
-
parser = argparse.ArgumentParser(
|
15 |
parser.add_argument(
|
16 |
-
"epochs", help="number of epochs
|
17 |
)
|
18 |
parser.add_argument(
|
19 |
"-n", "-new-dataset", help="clear old dataset and generate provided number of new examples", type=int,
|
20 |
dest="new_dataset"
|
21 |
)
|
22 |
parser.add_argument(
|
23 |
-
"-g", "-gpus",
|
24 |
-
type=int,
|
25 |
)
|
26 |
parser.add_argument(
|
27 |
"-l", "-log", help="whether to save logs of run to w&b logger, default False", default=False,
|
28 |
action="store_true", dest="log"
|
29 |
)
|
|
|
|
|
|
|
|
|
30 |
|
31 |
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
if args.new_dataset is not None:
|
34 |
generate_data(args.new_dataset)
|
@@ -39,25 +51,36 @@ def main():
|
|
39 |
|
40 |
# TODO: log images, accuracy?, update python, write own transformer, add checkpoints, lr scheduler,
|
41 |
# determine when trainer doesnt hang(when single gpu,ddp, num_workers=0)
|
42 |
-
logger = WandbLogger(f"img2tex_epochs{args.epochs}_size{len(datamodule)}_gpus{args.gpus}_v0") if args.log else None
|
43 |
-
trainer = pl.Trainer(max_epochs=args.epochs, accelerator='gpu' if args.gpus else 'cpu', gpus=args.gpus,
|
44 |
-
logger=logger, strategy='ddp_spawn')
|
45 |
-
transformer = Transformer(
|
46 |
-
num_encoder_layers=3,
|
47 |
-
num_decoder_layers=3,
|
48 |
-
emb_size=512,
|
49 |
-
nhead=8,
|
50 |
-
image_width=IMAGE_WIDTH,
|
51 |
-
image_height=IMAGE_HEIGHT,
|
52 |
-
tgt_vocab_size=datamodule.tex_tokenizer.get_vocab_size(),
|
53 |
-
pad_idx=datamodule.tex_tokenizer.token_to_id("[PAD]"),
|
54 |
-
dim_feedforward=512,
|
55 |
-
dropout=0.1
|
56 |
-
)
|
57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
trainer.fit(transformer, datamodule=datamodule)
|
59 |
-
# trainer.validate(datamodule=datamodule)
|
60 |
trainer.test(datamodule=datamodule)
|
|
|
61 |
|
62 |
|
63 |
if __name__ == '__main__':
|
|
|
3 |
from model import Transformer
|
4 |
|
5 |
import argparse
|
|
|
6 |
from pytorch_lightning.loggers import WandbLogger
|
7 |
+
from pytorch_lightning import Trainer, seed_everything
|
8 |
import torch
|
9 |
|
10 |
DATASET_PATH = 'resources/dataset.pt'
|
11 |
|
12 |
|
13 |
+
def parse_args():
|
14 |
+
parser = argparse.ArgumentParser()
|
15 |
parser.add_argument(
|
16 |
+
"-m", "-max-epochs", help="limit the number of training epochs", type=int, dest='max_epochs'
|
17 |
)
|
18 |
parser.add_argument(
|
19 |
"-n", "-new-dataset", help="clear old dataset and generate provided number of new examples", type=int,
|
20 |
dest="new_dataset"
|
21 |
)
|
22 |
parser.add_argument(
|
23 |
+
"-g", "-gpus", help=f"number of gpus to train on in range 0..{torch.cuda.device_count()}",
|
24 |
+
type=int, dest="gpus", choices=list(range(torch.cuda.device_count())),
|
25 |
)
|
26 |
parser.add_argument(
|
27 |
"-l", "-log", help="whether to save logs of run to w&b logger, default False", default=False,
|
28 |
action="store_true", dest="log"
|
29 |
)
|
30 |
+
parser.add_argument(
|
31 |
+
"-d", "-deterministic", help="whether to seed all rngs for reproducibility, default False", default=False,
|
32 |
+
action="store_true", dest="deterministic"
|
33 |
+
)
|
34 |
|
35 |
args = parser.parse_args()
|
36 |
+
return args
|
37 |
+
|
38 |
+
|
39 |
+
def main():
|
40 |
+
args = parse_args()
|
41 |
+
|
42 |
+
if args.deterministic:
|
43 |
+
seed_everything(42, workers=True)
|
44 |
|
45 |
if args.new_dataset is not None:
|
46 |
generate_data(args.new_dataset)
|
|
|
51 |
|
52 |
# TODO: log images, accuracy?, update python, write own transformer, add checkpoints, lr scheduler,
|
53 |
# determine when trainer doesnt hang(when single gpu,ddp, num_workers=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
+
logger = WandbLogger(f"img2tex", version='0') if args.log else False
|
56 |
+
|
57 |
+
trainer = Trainer(max_epochs=args.max_epochs,
|
58 |
+
accelerator='gpu' if args.gpus else 'cpu',
|
59 |
+
gpus=args.gpus,
|
60 |
+
logger=logger,
|
61 |
+
strategy='ddp',
|
62 |
+
auto_scale_batch_size="power",
|
63 |
+
auto_lr_find=True,
|
64 |
+
auto_select_gpus=True,
|
65 |
+
enable_progress_bar=True
|
66 |
+
)
|
67 |
+
|
68 |
+
transformer = Transformer(num_encoder_layers=3,
|
69 |
+
num_decoder_layers=3,
|
70 |
+
emb_size=512,
|
71 |
+
nhead=8,
|
72 |
+
image_width=IMAGE_WIDTH,
|
73 |
+
image_height=IMAGE_HEIGHT,
|
74 |
+
tgt_vocab_size=datamodule.tex_tokenizer.get_vocab_size(),
|
75 |
+
pad_idx=datamodule.tex_tokenizer.token_to_id("[PAD]"),
|
76 |
+
dim_feedforward=512,
|
77 |
+
dropout=0.1
|
78 |
+
)
|
79 |
+
|
80 |
+
trainer.tune(transformer, datamodule=datamodule)
|
81 |
trainer.fit(transformer, datamodule=datamodule)
|
|
|
82 |
trainer.test(datamodule=datamodule)
|
83 |
+
trainer.save_checkpoint("best_model.ckpt")
|
84 |
|
85 |
|
86 |
if __name__ == '__main__':
|
utils.py
ADDED
File without changes
|