SkillForge45 committed
Commit 9745119 · verified · 1 Parent(s): 92f686f

Create dataloader.py

Files changed (1)
  1. dataloader.py +56 -0
dataloader.py ADDED
@@ -0,0 +1,56 @@
+ import torch
+ import torch.nn.functional as F
+ from datasets import load_dataset
+ from torch.utils.data import DataLoader
+ from transformers import BertTokenizer
+ import decord
+ import numpy as np
+ from tqdm import tqdm
+
+
+ # Sampling / preprocessing configuration
+ FRAMES = 400
+ H, W = 780, 780
+ BATCH_SIZE = 8
+ TEXT_MAX_LEN = 32
+
+
+ dataset = load_dataset("minh132/pexels-videos", split="train")
+
+
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+
+
+ class VideoDataset(torch.utils.data.Dataset):
+     def __init__(self, dataset):
+         self.dataset = dataset
+         self.decord_ctx = decord.cpu(0)  # CPU decoding
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         item = self.dataset[idx]
+
+         # Decode the video and sample FRAMES evenly spaced frames
+         vr = decord.VideoReader(item["video_path"], ctx=self.decord_ctx)
+         frame_indices = np.linspace(0, len(vr) - 1, FRAMES, dtype=int)
+         video = vr.get_batch(frame_indices).asnumpy()  # (FRAMES, orig_H, orig_W, 3)
+         video = torch.from_numpy(video).permute(3, 0, 1, 2).float()  # (3, FRAMES, orig_H, orig_W)
+
+         # Resize spatially to (H, W) and rescale pixel values to [-1, 1]
+         video = F.interpolate(video, size=(H, W), mode="bilinear", align_corners=False)
+         video = (video / 255.0) * 2 - 1
+
+         # Tokenize the caption to a fixed-length sequence of input ids
+         text = tokenizer(
+             item["caption"],
+             padding="max_length",
+             truncation=True,
+             max_length=TEXT_MAX_LEN,
+             return_tensors="pt",
+         ).input_ids.squeeze(0)
+
+         return {"video": video, "text": text}
+
+
+ dataset = VideoDataset(dataset)
+ dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)