Spaces:
Runtime error
Runtime error
dkoshman
commited on
Commit
·
fb8db0f
1
Parent(s):
8ab1767
big changes: transformer, pytorch lightning, argparser
Browse files- data_generator.py +54 -53
- data_preprocessing.py +68 -32
- model.py +54 -26
- train.py +44 -22
data_generator.py
CHANGED
@@ -1,10 +1,13 @@
|
|
|
|
|
|
1 |
import json
|
2 |
from multiprocessing import Pool
|
3 |
import os
|
|
|
4 |
import string
|
5 |
import subprocess
|
6 |
import random
|
7 |
-
|
8 |
|
9 |
|
10 |
class DotDict(dict):
|
@@ -102,7 +105,7 @@ def generate_equation(latex: DotDict, size, depth=3):
|
|
102 |
return equation
|
103 |
|
104 |
|
105 |
-
def generate_image(directory: str, latex:
|
106 |
pdflatex: str = "/external2/dkkoshman/venv/texlive/2022/bin/x86_64-linux/pdflatex",
|
107 |
ghostscript: str = "/external2/dkkoshman/venv/local/gs/bin/gs"
|
108 |
):
|
@@ -111,18 +114,16 @@ def generate_image(directory: str, latex: DotDict, filename: str, max_length=20,
|
|
111 |
-------
|
112 |
params:
|
113 |
:directory: -- dir where to save files
|
114 |
-
:latex: --
|
115 |
:filename: -- absolute filename for the generated files
|
116 |
:max_length: -- max size of equation
|
117 |
:equation_depth: -- max nested level of tex scopes
|
118 |
:pdflatex: -- path to pdflatex
|
119 |
:ghostscript: -- path to ghostscript
|
120 |
"""
|
121 |
-
# TODO ARGPARSE
|
122 |
filepath = os.path.join(directory, filename)
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
template = string.Template(latex.template)
|
127 |
font, font_options = random.choice(latex.fonts)
|
128 |
font_option = random.choice([''] + font_options)
|
@@ -130,70 +131,70 @@ def generate_image(directory: str, latex: DotDict, filename: str, max_length=20,
|
|
130 |
equation = generate_equation(latex, equation_length, depth=equation_depth)
|
131 |
tex = template.substitute(font=font, font_option=font_option, fontsize=fontsize, equation=equation)
|
132 |
|
133 |
-
files_before = set(os.listdir(directory))
|
134 |
with open(f"{filepath}.tex", mode='w') as file:
|
135 |
file.write(tex)
|
136 |
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
|
142 |
-
|
143 |
-
|
144 |
-
files_to_delete = files_after - files_before
|
145 |
-
if files_to_delete:
|
146 |
-
subprocess.run(['rm'] + [os.path.join(directory, file) for file in files_to_delete])
|
147 |
-
print(pr1.stderr.decode(), tex)
|
148 |
return
|
149 |
|
150 |
-
|
151 |
f"{ghostscript} -sDEVICE=png16m -dTextAlphaBits=4 -r200 -dSAFER -dBATCH -dNOPAUSE -o {filepath}.png {filepath}.pdf".split(),
|
152 |
-
stderr=subprocess.
|
|
|
153 |
)
|
154 |
|
155 |
-
files_to_delete = files_after - files_before - {filename + '.png', filename + '.tex'}
|
156 |
-
if files_to_delete:
|
157 |
-
subprocess.run(['rm'] + [os.path.join(directory, file) for file in files_to_delete])
|
158 |
-
assert (pr2.returncode == 0)
|
159 |
|
|
|
|
|
160 |
|
161 |
-
|
162 |
-
|
163 |
-
directory: str,
|
164 |
-
latex_path: str,
|
165 |
-
overwrite: bool = False
|
166 |
-
) -> None:
|
167 |
"""
|
168 |
-
|
169 |
-------
|
170 |
params:
|
171 |
-
:
|
172 |
-
:directory: - where to create
|
173 |
-
:latex_path: - full path to latex json
|
174 |
-
:overwrite: - whether to overwrite existing files
|
175 |
"""
|
176 |
-
subprocess.run(". /external2/dkkoshman/venv/bin/activate")
|
177 |
-
if not os.path.isabs(directory):
|
178 |
-
directory = os.path.join(os.getcwd(), directory)
|
179 |
-
if not os.path.isabs(latex_path):
|
180 |
-
latex_path = os.path.join(os.getcwd(), latex_path)
|
181 |
-
|
182 |
-
filenames = set(filenames)
|
183 |
-
if not overwrite:
|
184 |
-
existing = set(
|
185 |
-
filename for file in os.listdir(directory) for filename, ext in os.path.splitext(file) if ext == '.png'
|
186 |
-
)
|
187 |
-
filenames -= existing
|
188 |
|
|
|
|
|
|
|
189 |
with open(latex_path) as file:
|
190 |
latex = json.load(file)
|
191 |
-
latex = DotDict(latex)
|
192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
while filenames:
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
|
|
|
|
199 |
filenames -= existing
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from train import DATA_DIR, LATEX_PATH
|
2 |
+
|
3 |
import json
|
4 |
from multiprocessing import Pool
|
5 |
import os
|
6 |
+
import shutil
|
7 |
import string
|
8 |
import subprocess
|
9 |
import random
|
10 |
+
import tqdm
|
11 |
|
12 |
|
13 |
class DotDict(dict):
|
|
|
105 |
return equation
|
106 |
|
107 |
|
108 |
+
def generate_image(directory: str, latex: dict, filename: str, max_length=20, equation_depth=3,
|
109 |
pdflatex: str = "/external2/dkkoshman/venv/texlive/2022/bin/x86_64-linux/pdflatex",
|
110 |
ghostscript: str = "/external2/dkkoshman/venv/local/gs/bin/gs"
|
111 |
):
|
|
|
114 |
-------
|
115 |
params:
|
116 |
:directory: -- dir where to save files
|
117 |
+
:latex: -- dict with parameters to generate tex
|
118 |
:filename: -- absolute filename for the generated files
|
119 |
:max_length: -- max size of equation
|
120 |
:equation_depth: -- max nested level of tex scopes
|
121 |
:pdflatex: -- path to pdflatex
|
122 |
:ghostscript: -- path to ghostscript
|
123 |
"""
|
|
|
124 |
filepath = os.path.join(directory, filename)
|
125 |
+
equation_length = random.randint(max_length // 2, max_length)
|
126 |
+
latex = DotDict(latex)
|
|
|
127 |
template = string.Template(latex.template)
|
128 |
font, font_options = random.choice(latex.fonts)
|
129 |
font_option = random.choice([''] + font_options)
|
|
|
131 |
equation = generate_equation(latex, equation_length, depth=equation_depth)
|
132 |
tex = template.substitute(font=font, font_option=font_option, fontsize=fontsize, equation=equation)
|
133 |
|
|
|
134 |
with open(f"{filepath}.tex", mode='w') as file:
|
135 |
file.write(tex)
|
136 |
|
137 |
+
try:
|
138 |
+
pdflatex_process = subprocess.run(
|
139 |
+
f"{pdflatex} -output-directory={directory} {filepath}.tex".split(),
|
140 |
+
stderr=subprocess.DEVNULL,
|
141 |
+
stdout=subprocess.DEVNULL,
|
142 |
+
timeout=1
|
143 |
+
)
|
144 |
+
except subprocess.TimeoutExpired:
|
145 |
+
subprocess.run(f'rm {filepath}.tex'.split())
|
146 |
+
return
|
147 |
|
148 |
+
if pdflatex_process.returncode != 0:
|
149 |
+
subprocess.run(f'rm {filepath}.tex'.split())
|
|
|
|
|
|
|
|
|
150 |
return
|
151 |
|
152 |
+
subprocess.run(
|
153 |
f"{ghostscript} -sDEVICE=png16m -dTextAlphaBits=4 -r200 -dSAFER -dBATCH -dNOPAUSE -o {filepath}.png {filepath}.pdf".split(),
|
154 |
+
stderr=subprocess.DEVNULL,
|
155 |
+
stdout=subprocess.DEVNULL,
|
156 |
)
|
157 |
|
|
|
|
|
|
|
|
|
158 |
|
159 |
+
def _generate_image_wrapper(args):
|
160 |
+
return generate_image(*args)
|
161 |
|
162 |
+
|
163 |
+
def generate_data(examples_count) -> None:
|
|
|
|
|
|
|
|
|
164 |
"""
|
165 |
+
Clears a directory and generates a latex dataset in given directory
|
166 |
-------
|
167 |
params:
|
168 |
+
:examples_count: - how many latex - image examples to generate
|
|
|
|
|
|
|
169 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
170 |
|
171 |
+
filenames = set(f"{i:0{len(str(examples_count - 1))}d}" for i in range(examples_count)),
|
172 |
+
directory = os.path.abspath(DATA_DIR)
|
173 |
+
latex_path = os.path.abspath(LATEX_PATH)
|
174 |
with open(latex_path) as file:
|
175 |
latex = json.load(file)
|
|
|
176 |
|
177 |
+
shutil.rmtree(directory)
|
178 |
+
os.mkdir(directory)
|
179 |
+
|
180 |
+
def _get_current_relevant_files():
|
181 |
+
return set(os.path.join(directory, file) for file in os.listdir(directory)) | set(
|
182 |
+
os.path.abspath(file) for file in os.listdir(os.getcwd()))
|
183 |
+
|
184 |
+
files_before = _get_current_relevant_files()
|
185 |
while filenames:
|
186 |
+
with Pool() as pool:
|
187 |
+
list(tqdm.tqdm(
|
188 |
+
pool.imap(_generate_image_wrapper, ((directory, latex, filename) for filename in sorted(filenames))),
|
189 |
+
"Generating images",
|
190 |
+
total=len(filenames)
|
191 |
+
))
|
192 |
+
existing = set(os.path.splitext(filename)[0] for filename in os.listdir(directory) if filename.endswith('.png'))
|
193 |
filenames -= existing
|
194 |
+
|
195 |
+
files_after = _get_current_relevant_files()
|
196 |
+
files_to_delete = files_after - files_before
|
197 |
+
files_to_delete = list(os.path.join(directory, file) for file in files_to_delete if
|
198 |
+
not file.endswith('.png') and not file.endswith('.tex'))
|
199 |
+
if files_to_delete:
|
200 |
+
subprocess.run(['rm'] + files_to_delete)
|
data_preprocessing.py
CHANGED
@@ -1,11 +1,15 @@
|
|
|
|
|
|
1 |
import einops
|
2 |
import os
|
|
|
3 |
import tokenizers
|
4 |
import torch
|
5 |
import torchvision
|
6 |
import torchvision.transforms as T
|
7 |
-
from torch.utils.data import Dataset
|
8 |
import tqdm
|
|
|
9 |
import re
|
10 |
|
11 |
|
@@ -23,19 +27,10 @@ class TexImageDataset(Dataset):
|
|
23 |
torch.multiprocessing.set_sharing_strategy('file_system')
|
24 |
self.root_dir = root_dir
|
25 |
self.filenames = sorted(set(
|
26 |
-
filename for
|
27 |
))
|
28 |
self.image_transform = image_transform
|
29 |
self.tex_transform = tex_transform
|
30 |
-
self.tex_tokenizer = None
|
31 |
-
self.texs = []
|
32 |
-
for filename in tqdm.tqdm(self.filenames, "Preloading tex files"):
|
33 |
-
tex_path = os.path.join(self.root_dir, filename + '.tex')
|
34 |
-
with open(tex_path) as file:
|
35 |
-
tex = file.read()
|
36 |
-
if self.tex_transform:
|
37 |
-
tex = self.tex_transform(tex)
|
38 |
-
self.texs.append(tex)
|
39 |
|
40 |
def __len__(self):
|
41 |
return len(self.filenames)
|
@@ -43,29 +38,34 @@ class TexImageDataset(Dataset):
|
|
43 |
def __getitem__(self, idx):
|
44 |
filename = self.filenames[idx]
|
45 |
image_path = os.path.join(self.root_dir, filename + '.png')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
image = torchvision.io.read_image(image_path)
|
47 |
if self.image_transform:
|
48 |
image = self.image_transform(image)
|
49 |
-
|
50 |
return {"image": image, "tex": tex}
|
51 |
|
52 |
-
def subjoin_image_normalize_transform(self):
|
53 |
-
"""Appends a normalize layer with mean and std computed after iterating over dataset"""
|
54 |
-
mean = 0
|
55 |
-
std = 0
|
56 |
-
for item in tqdm.tqdm(self):
|
57 |
-
image = item['image']
|
58 |
-
mean += image.mean()
|
59 |
-
std += image.std()
|
60 |
|
61 |
-
|
62 |
-
|
63 |
-
normalize = T.Normalize(mean, std)
|
64 |
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
|
71 |
class BatchCollator(object):
|
@@ -138,10 +138,12 @@ class ExtractEquationFromTexTransform(object):
|
|
138 |
return equation
|
139 |
|
140 |
|
141 |
-
def generate_tex_tokenizer(
|
142 |
-
"""Returns a tokenizer trained on given
|
143 |
|
144 |
-
|
|
|
|
|
145 |
tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token="[UNK]"))
|
146 |
tokenizer_trainer = tokenizers.trainers.BpeTrainer(
|
147 |
vocab_size=vocab_size,
|
@@ -150,9 +152,43 @@ def generate_tex_tokenizer(texs, vocab_size=300):
|
|
150 |
tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Whitespace()
|
151 |
tokenizer.train_from_iterator(texs, trainer=tokenizer_trainer)
|
152 |
tokenizer.post_processor = tokenizers.processors.TemplateProcessing(
|
153 |
-
single="$A [SEP]",
|
154 |
-
special_tokens=[
|
|
|
|
|
|
|
155 |
)
|
156 |
tokenizer.enable_padding(pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]")
|
157 |
|
158 |
return tokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from train import DATASET_PATH, DATA_DIR, BATCH_SIZE, TEX_VOCAB_SIZE
|
2 |
+
|
3 |
import einops
|
4 |
import os
|
5 |
+
import pytorch_lightning as pl
|
6 |
import tokenizers
|
7 |
import torch
|
8 |
import torchvision
|
9 |
import torchvision.transforms as T
|
10 |
+
from torch.utils.data import Dataset, DataLoader
|
11 |
import tqdm
|
12 |
+
from typing import Optional
|
13 |
import re
|
14 |
|
15 |
|
|
|
27 |
torch.multiprocessing.set_sharing_strategy('file_system')
|
28 |
self.root_dir = root_dir
|
29 |
self.filenames = sorted(set(
|
30 |
+
os.path.splitext(filename)[0] for filename in os.listdir(root_dir) if filename.endswith('.png')
|
31 |
))
|
32 |
self.image_transform = image_transform
|
33 |
self.tex_transform = tex_transform
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
def __len__(self):
|
36 |
return len(self.filenames)
|
|
|
38 |
def __getitem__(self, idx):
|
39 |
filename = self.filenames[idx]
|
40 |
image_path = os.path.join(self.root_dir, filename + '.png')
|
41 |
+
tex_path = os.path.join(self.root_dir, filename + '.tex')
|
42 |
+
|
43 |
+
with open(tex_path) as file:
|
44 |
+
tex = file.read()
|
45 |
+
if self.tex_transform:
|
46 |
+
tex = self.tex_transform(tex)
|
47 |
+
|
48 |
image = torchvision.io.read_image(image_path)
|
49 |
if self.image_transform:
|
50 |
image = self.image_transform(image)
|
51 |
+
|
52 |
return {"image": image, "tex": tex}
|
53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
+
def generate_normalize_transform(dataset: TexImageDataset):
|
56 |
+
"""Returns a normalize layer with mean and std computed after iterating over dataset"""
|
|
|
57 |
|
58 |
+
mean = 0
|
59 |
+
std = 0
|
60 |
+
for item in tqdm.tqdm(dataset, "Computing dataset image stats"):
|
61 |
+
image = item['image']
|
62 |
+
mean += image.mean()
|
63 |
+
std += image.std()
|
64 |
+
|
65 |
+
mean /= len(dataset)
|
66 |
+
std /= len(dataset)
|
67 |
+
normalize = T.Normalize(mean, std)
|
68 |
+
return normalize
|
69 |
|
70 |
|
71 |
class BatchCollator(object):
|
|
|
138 |
return equation
|
139 |
|
140 |
|
141 |
+
def generate_tex_tokenizer(dataset: TexImageDataset, vocab_size=300):
|
142 |
+
"""Returns a tokenizer trained on texs from given dataset"""
|
143 |
|
144 |
+
texs = list(tqdm.tqdm((item['tex'] for item in dataset), "Training tokenizer"))
|
145 |
+
|
146 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
147 |
tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token="[UNK]"))
|
148 |
tokenizer_trainer = tokenizers.trainers.BpeTrainer(
|
149 |
vocab_size=vocab_size,
|
|
|
152 |
tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Whitespace()
|
153 |
tokenizer.train_from_iterator(texs, trainer=tokenizer_trainer)
|
154 |
tokenizer.post_processor = tokenizers.processors.TemplateProcessing(
|
155 |
+
single="[CLS] $A [SEP]",
|
156 |
+
special_tokens=[
|
157 |
+
("[CLS]", tokenizer.token_to_id("[CLS]")),
|
158 |
+
("[SEP]", tokenizer.token_to_id("[SEP]")),
|
159 |
+
]
|
160 |
)
|
161 |
tokenizer.enable_padding(pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]")
|
162 |
|
163 |
return tokenizer
|
164 |
+
|
165 |
+
|
166 |
+
class LatexImageDataModule(pl.LightningDataModule):
|
167 |
+
def prepare_data(self) -> None:
|
168 |
+
# download or something
|
169 |
+
...
|
170 |
+
|
171 |
+
def setup(self, stage: Optional[str] = None) -> None:
|
172 |
+
tex_transform = ExtractEquationFromTexTransform()
|
173 |
+
dataset = TexImageDataset(DATA_DIR, tex_transform=tex_transform)
|
174 |
+
|
175 |
+
self.train_dataset, self.val_dataset, self.test_dataset = torch.utils.data.random_split(
|
176 |
+
dataset,
|
177 |
+
[len(dataset) - 2 * len(dataset) // 10, len(dataset) // 10, len(dataset) // 10]
|
178 |
+
)
|
179 |
+
self.train_dataset.image_transform = RandomizeImageTransform()
|
180 |
+
self.val_dataset.image_transform = StandardizeImageTransform()
|
181 |
+
self.test_dataset.image_transform = StandardizeImageTransform()
|
182 |
+
# image_normalize = generate_normalize_transform(self.train_dataset), compose?
|
183 |
+
|
184 |
+
self.tex_tokenizer = generate_tex_tokenizer(self.train_dataset, vocab_size=TEX_VOCAB_SIZE)
|
185 |
+
self.collate_fn = BatchCollator(self.tex_tokenizer)
|
186 |
+
|
187 |
+
def train_dataloader(self):
|
188 |
+
return DataLoader(self.train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=self.collate_fn)
|
189 |
+
|
190 |
+
def val_dataloader(self):
|
191 |
+
return DataLoader(self.val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=self.collate_fn)
|
192 |
+
|
193 |
+
def test_dataloader(self):
|
194 |
+
return DataLoader(self.test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=self.collate_fn)
|
model.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
from einops.layers.torch import Rearrange
|
2 |
import einops
|
3 |
import math
|
|
|
|
|
4 |
import torch.nn as nn
|
5 |
import torch
|
6 |
|
@@ -62,6 +64,21 @@ class ImageEmbedding(nn.Module):
|
|
62 |
return image_batch
|
63 |
|
64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
class ImageEncoder(nn.Module):
|
66 |
"""
|
67 |
Given an image, returns its vector representation.
|
@@ -83,7 +100,10 @@ class ImageEncoder(nn.Module):
|
|
83 |
return self.encode(batch)
|
84 |
|
85 |
|
86 |
-
class
|
|
|
|
|
|
|
87 |
def __init__(self,
|
88 |
num_encoder_layers: int,
|
89 |
num_decoder_layers: int,
|
@@ -92,39 +112,47 @@ class Seq2SeqTransformer(nn.Module):
|
|
92 |
image_width: int,
|
93 |
image_height: int,
|
94 |
tgt_vocab_size: int,
|
|
|
95 |
dim_feedforward: int = 512,
|
96 |
-
dropout: float =
|
97 |
-
super(
|
98 |
self.transformer = nn.Transformer(d_model=emb_size,
|
99 |
nhead=nhead,
|
100 |
num_encoder_layers=num_encoder_layers,
|
101 |
num_decoder_layers=num_decoder_layers,
|
102 |
dim_feedforward=dim_feedforward,
|
103 |
-
dropout=dropout
|
104 |
-
|
|
|
|
|
|
|
|
|
105 |
self.generator = nn.Linear(emb_size, tgt_vocab_size)
|
106 |
self.src_tok_emb = ImageEmbedding(emb_size, image_width, image_height, dropout=dropout)
|
107 |
-
self.tgt_tok_emb =
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
tgt_mask: Tensor,
|
114 |
-
src_padding_mask: Tensor,
|
115 |
-
tgt_padding_mask: Tensor,
|
116 |
-
memory_key_padding_mask: Tensor):
|
117 |
-
src_emb = self.positional_encoding(self.src_tok_emb(src))
|
118 |
-
tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
|
119 |
outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
|
120 |
-
src_padding_mask, tgt_padding_mask
|
121 |
return self.generator(outs)
|
122 |
|
123 |
-
def
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from einops.layers.torch import Rearrange
|
2 |
import einops
|
3 |
import math
|
4 |
+
import pytorch_lightning as pl
|
5 |
+
from pytorch_lightning.utilities.types import TRAIN_DATALOADERS
|
6 |
import torch.nn as nn
|
7 |
import torch
|
8 |
|
|
|
64 |
return image_batch
|
65 |
|
66 |
|
67 |
+
class TexEmbedding(nn.Module):
|
68 |
+
def __init__(self, d_model: int, vocab_size: int, dropout: float = .1):
|
69 |
+
super().__init__()
|
70 |
+
self.embedding = nn.Embedding(vocab_size, d_model)
|
71 |
+
self.add_positional_encoding = AddPositionalEncoding(d_model)
|
72 |
+
self.dropout = nn.Dropout(p=dropout)
|
73 |
+
self.d_model = d_model
|
74 |
+
|
75 |
+
def forward(self, tex_ids_batch):
|
76 |
+
tex_ids_batch = self.embedding(tex_ids_batch.long()) * math.sqrt(self.d_model)
|
77 |
+
tex_ids_batch = self.add_positional_encoding(tex_ids_batch)
|
78 |
+
tex_ids_batch = self.dropout(tex_ids_batch)
|
79 |
+
return tex_ids_batch
|
80 |
+
|
81 |
+
|
82 |
class ImageEncoder(nn.Module):
|
83 |
"""
|
84 |
Given an image, returns its vector representation.
|
|
|
100 |
return self.encode(batch)
|
101 |
|
102 |
|
103 |
+
class Transformer(pl.LightningModule):
|
104 |
+
def train_dataloader(self) -> TRAIN_DATALOADERS:
|
105 |
+
pass
|
106 |
+
|
107 |
def __init__(self,
|
108 |
num_encoder_layers: int,
|
109 |
num_decoder_layers: int,
|
|
|
112 |
image_width: int,
|
113 |
image_height: int,
|
114 |
tgt_vocab_size: int,
|
115 |
+
pad_idx: int,
|
116 |
dim_feedforward: int = 512,
|
117 |
+
dropout: float = .1):
|
118 |
+
super().__init__()
|
119 |
self.transformer = nn.Transformer(d_model=emb_size,
|
120 |
nhead=nhead,
|
121 |
num_encoder_layers=num_encoder_layers,
|
122 |
num_decoder_layers=num_decoder_layers,
|
123 |
dim_feedforward=dim_feedforward,
|
124 |
+
dropout=dropout,
|
125 |
+
batch_first=True)
|
126 |
+
for p in self.transformer.parameters():
|
127 |
+
if p.dim() > 1:
|
128 |
+
nn.init.xavier_uniform_(p)
|
129 |
+
|
130 |
self.generator = nn.Linear(emb_size, tgt_vocab_size)
|
131 |
self.src_tok_emb = ImageEmbedding(emb_size, image_width, image_height, dropout=dropout)
|
132 |
+
self.tgt_tok_emb = TexEmbedding(emb_size, tgt_vocab_size, dropout=dropout)
|
133 |
+
self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_idx)
|
134 |
+
|
135 |
+
def forward(self, src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask):
|
136 |
+
src_emb = self.src_tok_emb(src)
|
137 |
+
tgt_emb = self.tgt_tok_emb(tgt)
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
|
139 |
+
src_padding_mask, tgt_padding_mask)
|
140 |
return self.generator(outs)
|
141 |
|
142 |
+
def training_step(self, batch, batch_idx):
|
143 |
+
src = batch['images']
|
144 |
+
tgt = batch['tex_ids']
|
145 |
+
tgt_input = tgt[:, :-1]
|
146 |
+
tgt_output = tgt[:, 1:]
|
147 |
+
src_mask = None
|
148 |
+
tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_input.shape[1]).to(self.device,
|
149 |
+
torch.ByteTensor.dtype)
|
150 |
+
src_padding_mask = None
|
151 |
+
tgt_padding_mask = batch['tex_attention_masks'][:, :-1]
|
152 |
+
outs = self(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)
|
153 |
+
loss = self.loss_fn(einops.rearrange(outs, 'b n prob -> b prob n'), tgt_output.long())
|
154 |
+
self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
|
155 |
+
return loss
|
156 |
+
|
157 |
+
def configure_optimizers(self):
|
158 |
+
return torch.optim.Adam(self.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
|
train.py
CHANGED
@@ -1,32 +1,54 @@
|
|
1 |
from data_generator import generate_data
|
2 |
-
from data_preprocessing import
|
3 |
-
|
4 |
|
|
|
|
|
|
|
5 |
import torch
|
6 |
-
from torch.utils.data import DataLoader
|
7 |
|
8 |
DATA_DIR = 'data'
|
9 |
LATEX_PATH = 'resources/latex.json'
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
-
if __name__ == '__main__':
|
12 |
-
generate_data(
|
13 |
-
filenames=map(str, range(1000)),
|
14 |
-
directory=DATA_DIR,
|
15 |
-
latex_path=LATEX_PATH,
|
16 |
-
)
|
17 |
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
)
|
26 |
-
tex_tokenizer = generate_tex_tokenizer(dataset.texs)
|
27 |
-
collate_fn = BatchCollator(tex_tokenizer)
|
28 |
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
|
|
|
|
|
|
|
1 |
from data_generator import generate_data
|
2 |
+
from data_preprocessing import LatexImageDataModule
|
3 |
+
from model import Transformer
|
4 |
|
5 |
+
import argparse
|
6 |
+
import pytorch_lightning as pl
|
7 |
+
from pytorch_lightning.loggers import WandbLogger
|
8 |
import torch
|
|
|
9 |
|
10 |
DATA_DIR = 'data'
|
11 |
LATEX_PATH = 'resources/latex.json'
|
12 |
+
DATASET_PATH = 'resources/dataset'
|
13 |
+
IMAGE_WIDTH = 1024
|
14 |
+
IMAGE_HEIGHT = 128
|
15 |
+
TEX_VOCAB_SIZE = 300
|
16 |
+
BATCH_SIZE = 16
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
+
def main():
|
20 |
+
torch.manual_seed(0)
|
21 |
+
|
22 |
+
parser = argparse.ArgumentParser("Trainer")
|
23 |
+
parser.add_argument("-generate-new", help="number of new files to generate", type=int)
|
24 |
+
args = parser.parse_args()
|
25 |
+
|
26 |
+
if args.generate_new is not None:
|
27 |
+
generate_data(args.generate_new)
|
28 |
+
datamodule = LatexImageDataModule()
|
29 |
+
torch.save(datamodule, DATASET_PATH)
|
30 |
+
else:
|
31 |
+
datamodule = torch.load(DATASET_PATH)
|
32 |
+
|
33 |
+
wandb_logger = WandbLogger()
|
34 |
+
trainer = pl.Trainer(max_epochs=2, accelerator='gpu', gpus=1, logger=wandb_logger)
|
35 |
+
transformer = Transformer(
|
36 |
+
num_encoder_layers=3,
|
37 |
+
num_decoder_layers=3,
|
38 |
+
emb_size=512,
|
39 |
+
nhead=8,
|
40 |
+
image_width=IMAGE_WIDTH,
|
41 |
+
image_height=IMAGE_HEIGHT,
|
42 |
+
tgt_vocab_size=datamodule.tex_tokenizer.get_vocab_size(),
|
43 |
+
pad_idx=datamodule.tex_tokenizer.token_to_id("[PAD]"),
|
44 |
+
dim_feedforward=512,
|
45 |
+
dropout=0.1
|
46 |
)
|
|
|
|
|
47 |
|
48 |
+
trainer.fit(transformer, datamodule=datamodule)
|
49 |
+
trainer.validate(datamodule=datamodule)
|
50 |
+
trainer.test(datamodule=datamodule)
|
51 |
+
|
52 |
+
|
53 |
+
if __name__ == '__main__':
|
54 |
+
main()
|