dkoshman committed on
Commit fb8db0f · 1 Parent(s): 8ab1767

big changes: transformer, pytorch lightning, argparser

Files changed (4)
  1. data_generator.py +54 -53
  2. data_preprocessing.py +68 -32
  3. model.py +54 -26
  4. train.py +44 -22
data_generator.py CHANGED
@@ -1,10 +1,13 @@
+from train import DATA_DIR, LATEX_PATH
+
 import json
 from multiprocessing import Pool
 import os
+import shutil
 import string
 import subprocess
 import random
-from typing import Iterable
+import tqdm


 class DotDict(dict):
@@ -102,7 +105,7 @@ def generate_equation(latex: DotDict, size, depth=3):
     return equation


-def generate_image(directory: str, latex: DotDict, filename: str, max_length=20, equation_depth=3,
+def generate_image(directory: str, latex: dict, filename: str, max_length=20, equation_depth=3,
                    pdflatex: str = "/external2/dkkoshman/venv/texlive/2022/bin/x86_64-linux/pdflatex",
                    ghostscript: str = "/external2/dkkoshman/venv/local/gs/bin/gs"
                    ):
@@ -111,18 +114,16 @@ def generate_image(directory: str, latex: DotDict, filename: str, max_length=20,
     -------
     params:
     :directory: -- dir where to save files
-    :latex: -- DotDict with parameters to generate tex
+    :latex: -- dict with parameters to generate tex
     :filename: -- absolute filename for the generated files
     :max_length: -- max size of equation
    :equation_depth: -- max nested level of tex scopes
     :pdflatex: -- path to pdflatex
     :ghostscript: -- path to ghostscript
     """
-    # TODO ARGPARSE
     filepath = os.path.join(directory, filename)
-
-    equation_length = random.randint(1, max_length)
-
+    equation_length = random.randint(max_length // 2, max_length)
+    latex = DotDict(latex)
     template = string.Template(latex.template)
     font, font_options = random.choice(latex.fonts)
     font_option = random.choice([''] + font_options)
@@ -130,70 +131,70 @@ def generate_image(directory: str, latex: DotDict, filename: str, max_length=20,
     equation = generate_equation(latex, equation_length, depth=equation_depth)
     tex = template.substitute(font=font, font_option=font_option, fontsize=fontsize, equation=equation)

-    files_before = set(os.listdir(directory))
     with open(f"{filepath}.tex", mode='w') as file:
         file.write(tex)

-    pr1 = subprocess.run(
-        f"{pdflatex} -output-directory={directory} {filepath}.tex".split(),
-        stderr=subprocess.PIPE,
-    )
+    try:
+        pdflatex_process = subprocess.run(
+            f"{pdflatex} -output-directory={directory} {filepath}.tex".split(),
+            stderr=subprocess.DEVNULL,
+            stdout=subprocess.DEVNULL,
+            timeout=1
+        )
+    except subprocess.TimeoutExpired:
+        subprocess.run(f'rm {filepath}.tex'.split())
+        return

-    files_after = set(os.listdir(directory))
-    if pr1.returncode != 0:
-        files_to_delete = files_after - files_before
-        if files_to_delete:
-            subprocess.run(['rm'] + [os.path.join(directory, file) for file in files_to_delete])
-        print(pr1.stderr.decode(), tex)
+    if pdflatex_process.returncode != 0:
+        subprocess.run(f'rm {filepath}.tex'.split())
         return

-    pr2 = subprocess.run(
+    subprocess.run(
         f"{ghostscript} -sDEVICE=png16m -dTextAlphaBits=4 -r200 -dSAFER -dBATCH -dNOPAUSE -o {filepath}.png {filepath}.pdf".split(),
-        stderr=subprocess.PIPE,
+        stderr=subprocess.DEVNULL,
+        stdout=subprocess.DEVNULL,
     )

-    files_to_delete = files_after - files_before - {filename + '.png', filename + '.tex'}
-    if files_to_delete:
-        subprocess.run(['rm'] + [os.path.join(directory, file) for file in files_to_delete])
-    assert (pr2.returncode == 0)


-def generate_data(
-        filenames: Iterable[str],
-        directory: str,
-        latex_path: str,
-        overwrite: bool = False
-) -> None:
+def _generate_image_wrapper(args):
+    return generate_image(*args)
+
+
+def generate_data(examples_count) -> None:
     """
-    Generates a latex dataset in given directory
+    Clears a directory and generates a latex dataset in given directory
     -------
     params:
-    :filenames: - iterable of filenames to create, without extension
-    :directory: - where to create
-    :latex_path: - full path to latex json
-    :overwrite: - whether to overwrite existing files
+    :examples_count: - how many latex - image examples to generate
     """
-    subprocess.run(". /external2/dkkoshman/venv/bin/activate")
-    if not os.path.isabs(directory):
-        directory = os.path.join(os.getcwd(), directory)
-    if not os.path.isabs(latex_path):
-        latex_path = os.path.join(os.getcwd(), latex_path)
-
-    filenames = set(filenames)
-    if not overwrite:
-        existing = set(
-            filename for file in os.listdir(directory) for filename, ext in os.path.splitext(file) if ext == '.png'
-        )
-        filenames -= existing

+    filenames = set(f"{i:0{len(str(examples_count - 1))}d}" for i in range(examples_count)),
+    directory = os.path.abspath(DATA_DIR)
+    latex_path = os.path.abspath(LATEX_PATH)
     with open(latex_path) as file:
         latex = json.load(file)
-    latex = DotDict(latex)

+    shutil.rmtree(directory)
+    os.mkdir(directory)
+
+    def _get_current_relevant_files():
+        return set(os.path.join(directory, file) for file in os.listdir(directory)) | set(
+            os.path.abspath(file) for file in os.listdir(os.getcwd()))
+
+    files_before = _get_current_relevant_files()
     while filenames:
-        for name in filenames:
-            generate_image(directory, latex, name)
-        # with Pool() as pool:
-        #     pool.starmap(generate_image, ((directory, latex, name) for name in filenames))
-        existing = set(file.split('.')[0] for file in os.listdir(directory) if file.endswith('.png'))
+        with Pool() as pool:
+            list(tqdm.tqdm(
+                pool.imap(_generate_image_wrapper, ((directory, latex, filename) for filename in sorted(filenames))),
+                "Generating images",
+                total=len(filenames)
+            ))
+        existing = set(os.path.splitext(filename)[0] for filename in os.listdir(directory) if filename.endswith('.png'))
         filenames -= existing
+
+    files_after = _get_current_relevant_files()
+    files_to_delete = files_after - files_before
+    files_to_delete = list(os.path.join(directory, file) for file in files_to_delete if
+                           not file.endswith('.png') and not file.endswith('.tex'))
+    if files_to_delete:
+        subprocess.run(['rm'] + files_to_delete)
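
Note (not part of the commit): generate_data now derives zero-padded filenames from examples_count instead of taking an explicit filename iterable. A minimal sketch of that naming scheme, assuming examples_count = 1000:

    # Sketch only: reproduces the zero-padded naming used by generate_data above.
    examples_count = 1000
    width = len(str(examples_count - 1))                      # 3 digits for 1000 examples
    filenames = [f"{i:0{width}d}" for i in range(examples_count)]
    print(filenames[0], filenames[-1])                        # prints: 000 999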
data_preprocessing.py CHANGED
@@ -1,11 +1,15 @@
+from train import DATASET_PATH, DATA_DIR, BATCH_SIZE, TEX_VOCAB_SIZE
+
 import einops
 import os
+import pytorch_lightning as pl
 import tokenizers
 import torch
 import torchvision
 import torchvision.transforms as T
-from torch.utils.data import Dataset
+from torch.utils.data import Dataset, DataLoader
 import tqdm
+from typing import Optional
 import re

@@ -23,19 +27,10 @@ class TexImageDataset(Dataset):
         torch.multiprocessing.set_sharing_strategy('file_system')
         self.root_dir = root_dir
         self.filenames = sorted(set(
-            filename for file in os.listdir(root_dir) for filename, ext in os.path.splitext(file) if ext == '.png'
+            os.path.splitext(filename)[0] for filename in os.listdir(root_dir) if filename.endswith('.png')
         ))
         self.image_transform = image_transform
         self.tex_transform = tex_transform
-        self.tex_tokenizer = None
-        self.texs = []
-        for filename in tqdm.tqdm(self.filenames, "Preloading tex files"):
-            tex_path = os.path.join(self.root_dir, filename + '.tex')
-            with open(tex_path) as file:
-                tex = file.read()
-            if self.tex_transform:
-                tex = self.tex_transform(tex)
-            self.texs.append(tex)

     def __len__(self):
         return len(self.filenames)
@@ -43,29 +38,34 @@ class TexImageDataset(Dataset):
     def __getitem__(self, idx):
         filename = self.filenames[idx]
         image_path = os.path.join(self.root_dir, filename + '.png')
+        tex_path = os.path.join(self.root_dir, filename + '.tex')
+
+        with open(tex_path) as file:
+            tex = file.read()
+        if self.tex_transform:
+            tex = self.tex_transform(tex)
+
         image = torchvision.io.read_image(image_path)
         if self.image_transform:
             image = self.image_transform(image)
-        tex = self.texs[idx]
+
         return {"image": image, "tex": tex}

-    def subjoin_image_normalize_transform(self):
-        """Appends a normalize layer with mean and std computed after iterating over dataset"""
-        mean = 0
-        std = 0
-        for item in tqdm.tqdm(self):
-            image = item['image']
-            mean += image.mean()
-            std += image.std()
-
-        mean /= len(self)
-        std /= len(self)
-        normalize = T.Normalize(mean, std)
-
-        if self.image_transform:
-            self.image_transform = T.Compose((self.image_transform, normalize))
-        else:
-            self.image_transform = normalize
+
+def generate_normalize_transform(dataset: TexImageDataset):
+    """Returns a normalize layer with mean and std computed after iterating over dataset"""
+
+    mean = 0
+    std = 0
+    for item in tqdm.tqdm(dataset, "Computing dataset image stats"):
+        image = item['image']
+        mean += image.mean()
+        std += image.std()
+
+    mean /= len(dataset)
+    std /= len(dataset)
+    normalize = T.Normalize(mean, std)
+    return normalize


 class BatchCollator(object):
@@ -138,10 +138,12 @@ class ExtractEquationFromTexTransform(object):
         return equation


-def generate_tex_tokenizer(texs, vocab_size=300):
-    """Returns a tokenizer trained on given tex strings"""
+def generate_tex_tokenizer(dataset: TexImageDataset, vocab_size=300):
+    """Returns a tokenizer trained on texs from given dataset"""

-    # os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+    texs = list(tqdm.tqdm((item['tex'] for item in dataset), "Training tokenizer"))
+
+    os.environ['TOKENIZERS_PARALLELISM'] = 'false'
     tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token="[UNK]"))
     tokenizer_trainer = tokenizers.trainers.BpeTrainer(
         vocab_size=vocab_size,
@@ -150,9 +152,43 @@ def generate_tex_tokenizer(texs, vocab_size=300):
     tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Whitespace()
     tokenizer.train_from_iterator(texs, trainer=tokenizer_trainer)
     tokenizer.post_processor = tokenizers.processors.TemplateProcessing(
-        single="$A [SEP]",
-        special_tokens=[("[SEP]", tokenizer.token_to_id("[SEP]"))]
+        single="[CLS] $A [SEP]",
+        special_tokens=[
+            ("[CLS]", tokenizer.token_to_id("[CLS]")),
+            ("[SEP]", tokenizer.token_to_id("[SEP]")),
+        ]
     )
     tokenizer.enable_padding(pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]")

     return tokenizer
+
+
+class LatexImageDataModule(pl.LightningDataModule):
+    def prepare_data(self) -> None:
+        # download or something
+        ...
+
+    def setup(self, stage: Optional[str] = None) -> None:
+        tex_transform = ExtractEquationFromTexTransform()
+        dataset = TexImageDataset(DATA_DIR, tex_transform=tex_transform)
+
+        self.train_dataset, self.val_dataset, self.test_dataset = torch.utils.data.random_split(
+            dataset,
+            [len(dataset) - 2 * len(dataset) // 10, len(dataset) // 10, len(dataset) // 10]
+        )
+        self.train_dataset.image_transform = RandomizeImageTransform()
+        self.val_dataset.image_transform = StandardizeImageTransform()
+        self.test_dataset.image_transform = StandardizeImageTransform()
+        # image_normalize = generate_normalize_transform(self.train_dataset), compose?
+
+        self.tex_tokenizer = generate_tex_tokenizer(self.train_dataset, vocab_size=TEX_VOCAB_SIZE)
+        self.collate_fn = BatchCollator(self.tex_tokenizer)
+
+    def train_dataloader(self):
+        return DataLoader(self.train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=self.collate_fn)
+
+    def val_dataloader(self):
+        return DataLoader(self.val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=self.collate_fn)

+    def test_dataloader(self):
+        return DataLoader(self.test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=self.collate_fn)
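
Note (not part of the commit): a minimal sketch of driving the new LatexImageDataModule by hand, assuming the data directory already contains generated .png/.tex pairs; the batch keys shown are the ones read in model.py below, the exact output of BatchCollator is an assumption:

    # Sketch only: exercising the DataModule outside a Trainer.
    from data_preprocessing import LatexImageDataModule

    datamodule = LatexImageDataModule()
    datamodule.setup()                                  # builds splits, trains the tokenizer, sets the collator
    batch = next(iter(datamodule.train_dataloader()))
    print(batch['images'].shape, batch['tex_ids'].shape)   # keys assumed from Transformer.training_step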
model.py CHANGED
@@ -1,6 +1,8 @@
 from einops.layers.torch import Rearrange
 import einops
 import math
+import pytorch_lightning as pl
+from pytorch_lightning.utilities.types import TRAIN_DATALOADERS
 import torch.nn as nn
 import torch

@@ -62,6 +64,21 @@ class ImageEmbedding(nn.Module):
         return image_batch


+class TexEmbedding(nn.Module):
+    def __init__(self, d_model: int, vocab_size: int, dropout: float = .1):
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, d_model)
+        self.add_positional_encoding = AddPositionalEncoding(d_model)
+        self.dropout = nn.Dropout(p=dropout)
+        self.d_model = d_model
+
+    def forward(self, tex_ids_batch):
+        tex_ids_batch = self.embedding(tex_ids_batch.long()) * math.sqrt(self.d_model)
+        tex_ids_batch = self.add_positional_encoding(tex_ids_batch)
+        tex_ids_batch = self.dropout(tex_ids_batch)
+        return tex_ids_batch
+
+
 class ImageEncoder(nn.Module):
     """
     Given an image, returns its vector representation.
@@ -83,7 +100,10 @@ class ImageEncoder(nn.Module):
         return self.encode(batch)


-class Seq2SeqTransformer(nn.Module):
+class Transformer(pl.LightningModule):
+    def train_dataloader(self) -> TRAIN_DATALOADERS:
+        pass
+
     def __init__(self,
                  num_encoder_layers: int,
                  num_decoder_layers: int,
@@ -92,39 +112,47 @@ class Seq2SeqTransformer(nn.Module):
                  image_width: int,
                  image_height: int,
                  tgt_vocab_size: int,
+                 pad_idx: int,
                  dim_feedforward: int = 512,
-                 dropout: float = 0.1):
-        super(Seq2SeqTransformer, self).__init__()
+                 dropout: float = .1):
+        super().__init__()
         self.transformer = nn.Transformer(d_model=emb_size,
                                           nhead=nhead,
                                           num_encoder_layers=num_encoder_layers,
                                           num_decoder_layers=num_decoder_layers,
                                           dim_feedforward=dim_feedforward,
-                                          dropout=dropout)
-        # TODO: share weights between generator and embedding
+                                          dropout=dropout,
+                                          batch_first=True)
+        for p in self.transformer.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
         self.generator = nn.Linear(emb_size, tgt_vocab_size)
         self.src_tok_emb = ImageEmbedding(emb_size, image_width, image_height, dropout=dropout)
-        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
-
-    def forward(self,
-                src: Tensor,
-                trg: Tensor,
-                src_mask: Tensor,
-                tgt_mask: Tensor,
-                src_padding_mask: Tensor,
-                tgt_padding_mask: Tensor,
-                memory_key_padding_mask: Tensor):
-        src_emb = self.positional_encoding(self.src_tok_emb(src))
-        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
+        self.tgt_tok_emb = TexEmbedding(emb_size, tgt_vocab_size, dropout=dropout)
+        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_idx)
+
+    def forward(self, src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask):
+        src_emb = self.src_tok_emb(src)
+        tgt_emb = self.tgt_tok_emb(tgt)
         outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
-                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
+                                src_padding_mask, tgt_padding_mask)
         return self.generator(outs)

-    def encode(self, src: Tensor, src_mask: Tensor):
-        return self.transformer.encoder(self.positional_encoding(
-            self.src_tok_emb(src)), src_mask)
-
-    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
-        return self.transformer.decoder(self.positional_encoding(
-            self.tgt_tok_emb(tgt)), memory,
-                                        tgt_mask)
+    def training_step(self, batch, batch_idx):
+        src = batch['images']
+        tgt = batch['tex_ids']
+        tgt_input = tgt[:, :-1]
+        tgt_output = tgt[:, 1:]
+        src_mask = None
+        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_input.shape[1]).to(self.device,
+                                                                                           torch.ByteTensor.dtype)
+        src_padding_mask = None
+        tgt_padding_mask = batch['tex_attention_masks'][:, :-1]
+        outs = self(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)
+        loss = self.loss_fn(einops.rearrange(outs, 'b n prob -> b prob n'), tgt_output.long())
+        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
+        return loss
+
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
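
Note (not part of the commit): training_step feeds the decoder with teacher forcing, shifting the target ids by one position. A minimal sketch of that shift on a toy batch:

    # Sketch only: the input/target shift used in Transformer.training_step above.
    import torch

    tex_ids = torch.tensor([[1, 17, 42, 2, 0]])   # e.g. [CLS] tok tok [SEP] [PAD]
    tgt_input = tex_ids[:, :-1]                   # decoder sees everything but the last token
    tgt_output = tex_ids[:, 1:]                   # loss is computed against the next token
    print(tgt_input.tolist(), tgt_output.tolist())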
train.py CHANGED
@@ -1,32 +1,54 @@
 from data_generator import generate_data
-from data_preprocessing import TexImageDataset, RandomizeImageTransform, ExtractEquationFromTexTransform, \
-    BatchCollator, generate_tex_tokenizer
+from data_preprocessing import LatexImageDataModule
+from model import Transformer

+import argparse
+import pytorch_lightning as pl
+from pytorch_lightning.loggers import WandbLogger
 import torch
-from torch.utils.data import DataLoader

 DATA_DIR = 'data'
 LATEX_PATH = 'resources/latex.json'
+DATASET_PATH = 'resources/dataset'
+IMAGE_WIDTH = 1024
+IMAGE_HEIGHT = 128
+TEX_VOCAB_SIZE = 300
+BATCH_SIZE = 16

-if __name__ == '__main__':
-    generate_data(
-        filenames=map(str, range(1000)),
-        directory=DATA_DIR,
-        latex_path=LATEX_PATH,
-    )

-    image_transform = RandomizeImageTransform()
-    tex_transform = ExtractEquationFromTexTransform()
-    dataset = TexImageDataset(DATA_DIR, image_transform=image_transform, tex_transform=tex_transform)
-    dataset.subjoin_image_normalize_transform()
-    train_dataset, test_dataset = torch.utils.data.random_split(
-        dataset,
-        [len(dataset) * 9 // 10, len(dataset) // 10]
+def main():
+    torch.manual_seed(0)
+
+    parser = argparse.ArgumentParser("Trainer")
+    parser.add_argument("-generate-new", help="number of new files to generate", type=int)
+    args = parser.parse_args()
+
+    if args.generate_new is not None:
+        generate_data(args.generate_new)
+        datamodule = LatexImageDataModule()
+        torch.save(datamodule, DATASET_PATH)
+    else:
+        datamodule = torch.load(DATASET_PATH)
+
+    wandb_logger = WandbLogger()
+    trainer = pl.Trainer(max_epochs=2, accelerator='gpu', gpus=1, logger=wandb_logger)
+    transformer = Transformer(
+        num_encoder_layers=3,
+        num_decoder_layers=3,
+        emb_size=512,
+        nhead=8,
+        image_width=IMAGE_WIDTH,
+        image_height=IMAGE_HEIGHT,
+        tgt_vocab_size=datamodule.tex_tokenizer.get_vocab_size(),
+        pad_idx=datamodule.tex_tokenizer.token_to_id("[PAD]"),
+        dim_feedforward=512,
+        dropout=0.1
     )
-    tex_tokenizer = generate_tex_tokenizer(dataset.texs)
-    collate_fn = BatchCollator(tex_tokenizer)

-    train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=16,
-                                  collate_fn=collate_fn)
-    batch = next(iter(train_dataloader))
-    print(batch['texs'])
+    trainer.fit(transformer, datamodule=datamodule)
+    trainer.validate(datamodule=datamodule)
+    trainer.test(datamodule=datamodule)
+
+
+if __name__ == '__main__':
+    main()
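
Note (not part of the commit): with the new argparse entry point, a run that regenerates the dataset would presumably be invoked as python train.py -generate-new 1000, which rebuilds the data, constructs the LatexImageDataModule, and caches it at resources/dataset; omitting the flag loads the cached DataModule and goes straight to training.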