Spaces:
Runtime error
Runtime error
import json | |
from multiprocessing import Pool | |
import os | |
import shutil | |
import string | |
import subprocess | |
import random | |
import tqdm | |
# Directory (resolved against the working directory) that generate_data wipes and refills.
DATA_DIR = "data"
# JSON file that generate_data loads and passes to generate_image as the `latex` dict.
LATEX_PATH = "resources/latex.json"
class DotDict(dict):
    """
    dot.notation access to dictionary attributes
    -------
    Reading a missing key through attribute access returns None (same as dict.get).
    Nested dict values are wrapped into DotDict recursively, whichever way the
    instance was constructed (mapping, iterable of pairs or keyword arguments).
    """
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Snapshot the items so values can be replaced while looping.
        for key, value in list(self.items()):
            if isinstance(value, dict):
                self[key] = DotDict(value)

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.
        # Unknown dunder names must raise so that protocol probes (pickle, copy,
        # hasattr checks) see "not implemented" instead of a None "attribute".
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        return self.get(name)
def _generate_equation(size_left, depth_left, latex, tokens): | |
if size_left <= 0: | |
return "" | |
equation = "" | |
pairs, scopes, special = latex.pairs, latex.scopes, latex.special | |
weights = [3, depth_left > 0, depth_left > 0] | |
group, = random.choices([tokens, pairs, scopes], weights=weights) | |
if group is tokens: | |
equation += ' '.join([ | |
random.choice(tokens), | |
_generate_equation(size_left - 1, depth_left, latex, tokens) | |
]) | |
return equation | |
post_scope_size = round(abs(random.gauss(0, size_left / 2))) | |
size_left -= post_scope_size + 1 | |
if group is pairs: | |
pair = random.choice(pairs) | |
equation += ' '.join([ | |
pair[0], | |
_generate_equation(size_left, depth_left - 1, latex, tokens), | |
pair[1], | |
_generate_equation(post_scope_size, depth_left, latex, tokens) | |
]) | |
return equation | |
elif group is scopes: | |
scope_type, scope_group = random.choice(list(scopes.items())) | |
scope_operator = random.choice(scope_group) | |
equation += scope_operator | |
if scope_type == 'single': | |
equation += ' '.join([ | |
special.left_bracket, | |
_generate_equation(size_left, depth_left - 1, latex, tokens) | |
]) | |
elif scope_type == 'double_no_delimiters': | |
equation += ' '.join([ | |
special.left_bracket, | |
_generate_equation(size_left // 2, depth_left - 1, latex, tokens), | |
special.right_bracket + special.left_bracket, | |
_generate_equation(size_left // 2, depth_left - 1, latex, tokens) | |
]) | |
elif scope_type == 'double_with_delimiters': | |
equation += ' '.join([ | |
special.caret, | |
special.left_bracket, | |
_generate_equation(size_left // 2, depth_left - 1, latex, tokens), | |
special.right_bracket, | |
special.underscore, | |
special.left_bracket, | |
_generate_equation(size_left // 2, depth_left - 1, latex, tokens) | |
]) | |
equation += ' '.join([ | |
special.right_bracket, | |
_generate_equation(post_scope_size, depth_left, latex, tokens) | |
]) | |
return equation | |
def generate_equation(latex: DotDict, size, depth=3):
    """
    Builds one random latex equation
    -------
    params:
    :latex: -- dict with tokens to generate equation from
    :size: -- approximate size of equation
    :depth: -- max brackets and scope depth
    """
    # Flatten every plain-token group into a single pool to sample from.
    tokens = []
    for group_name in ('chars', 'greek', 'functions', 'operators', 'spaces'):
        tokens.extend(latex[group_name])
    return _generate_equation(size, depth, latex, tokens)
def generate_image(directory: str, latex: dict, filename: str, max_length=20, equation_depth=3,
                   pdflatex: str = "/external2/dkkoshman/venv/texlive/2022/bin/x86_64-linux/pdflatex",
                   ghostscript: str = "/external2/dkkoshman/venv/local/gs/bin/gs"
                   ):
    """
    Generates a random tex file and corresponding image
    -------
    params:
    :directory: -- dir where to save files
    :latex: -- dict with parameters to generate tex
    :filename: -- base name (no extension) of the generated files inside :directory:
    :max_length: -- max size of equation
    :equation_depth: -- max nested level of tex scopes
    :pdflatex: -- path to pdflatex, falls back to `pdflatex` found on PATH
    :ghostscript: -- path to ghostscript, falls back to `gs` found on PATH

    Returns None. If pdflatex times out or exits with a non-zero code,
    the written .tex file is removed and no image is produced.
    """
    # The defaults point to one specific machine; use whatever is on PATH if they are absent.
    pdflatex = shutil.which(pdflatex) or shutil.which('pdflatex') or pdflatex
    ghostscript = shutil.which(ghostscript) or shutil.which('gs') or ghostscript

    filepath = os.path.join(directory, filename)
    tex_path = f"{filepath}.tex"
    equation_length = random.randint(max_length // 2, max_length)
    latex = DotDict(latex)
    template = string.Template(latex.template)
    font, font_options = random.choice(latex.fonts)
    font_option = random.choice([''] + font_options)
    fontsize = random.choice(latex.fontsizes)
    equation = generate_equation(latex, equation_length, depth=equation_depth)
    tex = template.substitute(font=font, font_option=font_option, fontsize=fontsize, equation=equation)

    with open(tex_path, mode='w') as file:
        file.write(tex)

    # Arguments are passed as a list so paths containing whitespace stay intact.
    try:
        pdflatex_process = subprocess.run(
            [pdflatex, f"-output-directory={directory}", tex_path],
            stderr=subprocess.DEVNULL,
            stdout=subprocess.DEVNULL,
            timeout=1
        )
        compiled = pdflatex_process.returncode == 0
    except subprocess.TimeoutExpired:
        compiled = False

    if not compiled:
        # Drop the source of a failed example so only valid tex/png pairs remain.
        try:
            os.remove(tex_path)
        except FileNotFoundError:
            pass
        return

    subprocess.run(
        [
            ghostscript,
            "-sDEVICE=png16m",
            "-dTextAlphaBits=4",
            "-r200",
            "-dSAFER",
            "-dBATCH",
            "-dNOPAUSE",
            "-o", f"{filepath}.png",
            f"{filepath}.pdf",
        ],
        stderr=subprocess.DEVNULL,
        stdout=subprocess.DEVNULL,
    )
def _generate_image_wrapper(args):
    """Unpacks a single tuple of arguments into a generate_image call (for Pool.imap)."""
    result = generate_image(*args)
    return result
def generate_data(examples_count) -> None:
    """
    Clears the data directory and generates a latex dataset in it
    -------
    params:
    :examples_count: - how many latex - image examples to generate

    Generation is repeated for the names that still have no .png until every
    example exists. Afterwards files that appeared in the data directory or in
    the working directory during the run are removed, except .png and .tex files.
    """
    width = len(str(examples_count - 1))
    filenames = {f"{i:0{width}d}" for i in range(examples_count)}
    directory = os.path.abspath(DATA_DIR)
    latex_path = os.path.abspath(LATEX_PATH)
    with open(latex_path) as file:
        latex = json.load(file)

    # Start from an empty directory; it may not exist yet on the first run.
    if os.path.isdir(directory):
        shutil.rmtree(directory)
    os.makedirs(directory, exist_ok=True)

    def _get_current_relevant_files():
        in_data_dir = {os.path.join(directory, file) for file in os.listdir(directory)}
        in_cwd = {os.path.abspath(file) for file in os.listdir(os.getcwd())}
        return in_data_dir | in_cwd

    files_before = _get_current_relevant_files()

    while filenames:
        with Pool() as pool:
            list(tqdm.tqdm(
                pool.imap(_generate_image_wrapper, ((directory, latex, filename) for filename in sorted(filenames))),
                "Generating images",
                total=len(filenames)
            ))
        # Names that now have an image are done, the rest is retried.
        existing = {os.path.splitext(filename)[0] for filename in os.listdir(directory) if filename.endswith('.png')}
        filenames -= existing

    files_after = _get_current_relevant_files()
    for path in files_after - files_before:
        if path.endswith(('.png', '.tex')):
            continue
        # Remove one by one: a single `rm` with every path can exceed the OS
        # argument-length limit on large datasets. Best effort, like `rm` was.
        try:
            os.remove(path)
        except OSError:
            pass