# ML2TransformerApp / train.py
# Author: dkoshman
# Commit 6e82d4a: "data_preprocessing, base train script" (763 bytes)
from data_preprocessing import TexImageDataset, RandomizeImageTransform, ExtractEquationFromTexTransform
import torch
from torch.utils.data import DataLoader
def split_lengths(total, test_denominator=10):
    """Return ``[train_len, test_len]`` that partition ``total`` items.

    The test split receives ``total // test_denominator`` items and the
    training split receives the remainder, so the two lengths always sum to
    ``total``, as ``torch.utils.data.random_split`` requires for integer
    lengths. (Computing both parts with floor division independently, e.g.
    ``[total * 9 // 10, total // 10]``, drops items whenever ``total`` is not
    a multiple of 10 and makes ``random_split`` raise ``ValueError``.)

    Args:
        total: Number of items in the dataset being split.
        test_denominator: The test split gets ``1 / test_denominator`` of the
            items (rounded down). Defaults to 10, i.e. a 90/10 split.

    Returns:
        A two-element list ``[train_len, test_len]``.
    """
    test_len = total // test_denominator
    return [total - test_len, test_len]


def main():
    """Build the TeX/image dataset, split it 90/10 and print one training batch's TeX."""
    image_transform = RandomizeImageTransform()
    tex_transform = ExtractEquationFromTexTransform()
    dataset = TexImageDataset(
        'data',
        image_transform=image_transform,
        tex_transform=tex_transform,
    )
    # test_dataset is not consumed yet by this base script.
    train_dataset, test_dataset = torch.utils.data.random_split(
        dataset,
        split_lengths(len(dataset)),
    )
    # random_split returns torch Subset objects, which do not forward
    # attribute access to the wrapped dataset, so collate_fn must be taken
    # from the full dataset (assumes TexImageDataset defines collate_fn).
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=16,
        shuffle=True,
        num_workers=16,
        collate_fn=dataset.collate_fn,
    )
    batch = next(iter(train_dataloader))
    # presumably collate_fn returns a dict with a 'texs' entry — verify
    # against data_preprocessing.
    print(batch['texs'])


# Keep the entry point behind the guard: DataLoader worker processes may
# re-import this module (e.g. with the 'spawn' start method).
if __name__ == '__main__':
    main()