File size: 763 Bytes
6e82d4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from data_preprocessing import TexImageDataset, RandomizeImageTransform, ExtractEquationFromTexTransform

import torch
from torch.utils.data import DataLoader

def split_lengths(total, test_fraction=0.1):
    """Return ``[train_len, test_len]`` lengths that partition ``total`` items.

    ``test_len`` is ``floor(total * test_fraction)`` and ``train_len`` takes
    the remainder, so the two always sum to ``total`` exactly, as
    ``torch.utils.data.random_split`` requires for integer lengths.

    Args:
        total: Number of items being split (non-negative).
        test_fraction: Fraction of the items to reserve for the test split.

    Returns:
        A two-element list ``[train_len, test_len]``.

    Raises:
        ValueError: If ``test_fraction`` is outside ``[0, 1]``.
    """
    if not 0.0 <= test_fraction <= 1.0:
        raise ValueError(f"test_fraction must be in [0, 1], got {test_fraction!r}")
    test_len = int(total * test_fraction)
    return [total - test_len, test_len]


if __name__ == '__main__':
    # Smoke test: build the dataset, split it roughly 90/10, and print the
    # 'texs' entry of one shuffled training batch.
    image_transform = RandomizeImageTransform()
    tex_transform = ExtractEquationFromTexTransform()
    dataset = TexImageDataset('data', image_transform=image_transform, tex_transform=tex_transform)

    # Integer lengths must sum to len(dataset) exactly; e.g.
    # [9 * n // 10, n // 10] does not when n is not a multiple of 10, and
    # random_split then raises ValueError. split_lengths guarantees the sum.
    train_dataset, test_dataset = torch.utils.data.random_split(
        dataset,
        split_lengths(len(dataset)),
    )

    # random_split returns torch.utils.data.Subset objects, which do not
    # expose attributes of the wrapped dataset, so take collate_fn from the
    # full dataset rather than from train_dataset.
    train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=16,
                                  collate_fn=dataset.collate_fn)
    batch = next(iter(train_dataloader))
    print(batch['texs'])