Spaces:
Runtime error
Runtime error
File size: 763 Bytes
6e82d4a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
from data_preprocessing import TexImageDataset, RandomizeImageTransform, ExtractEquationFromTexTransform
import torch
from torch.utils.data import DataLoader
if __name__ == '__main__':
image_transform = RandomizeImageTransform()
tex_transform = ExtractEquationFromTexTransform()
dataset = TexImageDataset('data', image_transform=image_transform, tex_transform=tex_transform)
train_dataset, test_dataset = torch.utils.data.random_split(
dataset,
[len(dataset) * 9 // 10, len(dataset) // 10]
)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=16,
collate_fn=train_dataset.collate_fn)
batch = next(iter(train_dataloader))
print(batch['texs'])
|