# Spaces: Runtime error
from data_preprocessing import TexImageDataset, RandomizeImageTransform, ExtractEquationFromTexTransform | |
import torch | |
from torch.utils.data import DataLoader | |
if __name__ == '__main__':
    # Smoke test: build the dataset, split it roughly 90/10 into train/test,
    # and print the 'texs' entry of one collated training batch.
    image_transform = RandomizeImageTransform()
    tex_transform = ExtractEquationFromTexTransform()
    dataset = TexImageDataset('data', image_transform=image_transform, tex_transform=tex_transform)

    # random_split requires the lengths to sum to exactly len(dataset).
    # Two independent floor divisions (9n // 10 and n // 10) fall short of n
    # whenever n is not a multiple of 10, so the test size is taken as the
    # remainder instead.
    train_size = len(dataset) * 9 // 10
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(
        dataset,
        [train_size, test_size],
    )

    # random_split returns Subset wrappers, which do not expose attributes of
    # the wrapped dataset, so collate_fn is taken from the dataset itself.
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=16,
        shuffle=True,
        num_workers=16,
        collate_fn=dataset.collate_fn,
    )

    batch = next(iter(train_dataloader))
    print(batch['texs'])