hantupocong commited on
Commit
0af5be2
·
verified ·
1 Parent(s): 5dc61d6

Update data.py

Browse files
Files changed (1) hide show
  1. data.py +225 -226
data.py CHANGED
@@ -1,227 +1,226 @@
1
- import torch
2
- import torch.nn as nn
3
- from torch.utils.data import Dataset, DataLoader, random_split
4
- import pandas as pd
5
- import numpy as np
6
- from transformers import BertTokenizer
7
- import os
8
- from pose_format import Pose
9
- import matplotlib.pyplot as plt
10
- from matplotlib import animation
11
- from IPython.display import HTML
12
- from fastdtw import fastdtw # Keep this import
13
- from scipy.spatial.distance import cosine
14
- from config import MAX_TEXT_LEN, TARGET_NUM_FRAMES, BATCH_SIZE, TEACHER_FORCING_RATIO, SMOOTHING_ENABLED
15
-
16
- # ===== KEYPOINT SELECTION =====
17
- selected_keypoint_indices = list(np.r_[0:25, 501:522, 522:543])
18
- NUM_KEYPOINTS = len(selected_keypoint_indices)
19
- POSE_DIM = NUM_KEYPOINTS * 3
20
-
21
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
-
23
- # ===== SMOOTHING =====
24
- def selective_smoothing(preds):
25
- smoothed = preds.clone()
26
- body_indices = slice(0, 25 * 3)
27
- for t in range(1, preds.shape[1] - 1):
28
- smoothed[:, t, body_indices] = (
29
- 0.25 * preds[:, t - 1, body_indices] +
30
- 0.5 * preds[:, t, body_indices] +
31
- 0.25 * preds[:, t + 1, body_indices]
32
- )
33
- return smoothed
34
-
35
-
36
- # ===== HAND INDEX SETUP =====
37
- hand_indices = list(range(15 * 3, POSE_DIM)) # hand joints = after body joints in flattened 3D vector
38
- joint_weights = torch.ones(POSE_DIM).to(device)
39
- joint_weights[hand_indices] *= 3.0
40
-
41
- # ===== GLOBAL NORMALIZATION =====
42
- def compute_global_mean_std(pose_folder, csv_file):
43
- data = pd.read_csv(csv_file)
44
- all_poses = []
45
- # Store valid masks separately to ensure correct normalization
46
- all_masks = []
47
-
48
- for filename in data["filename"]:
49
- pose_path = os.path.join(pose_folder, filename)
50
- with open(pose_path, "rb") as f:
51
- pose = Pose.read(f.read())
52
-
53
- keypoints = np.array(selected_keypoint_indices)
54
- # (T, 1, K, 3) -> (T, K, 3)
55
- pose_data = np.squeeze(pose.body.data, axis=1)[:, keypoints, :]
56
- # (T, 1, K) -> (T, K)
57
- confidence = np.squeeze(pose.body.confidence, axis=1)[:, keypoints]
58
-
59
- # Reshape to (T, K*3)
60
- pose_data_flat = pose_data.reshape(pose_data.shape[0], -1)
61
- # Reshape confidence to (T, K*3) - repeat confidence for each coordinate
62
- confidence_flat = np.repeat(confidence, 3, axis=1)
63
-
64
- # Create a mask based on confidence for the flattened data
65
- mask_flat = (confidence_flat > 0.5).astype(np.float32)
66
-
67
- # Append the full pose data and mask for interpolation later
68
- all_poses.append(pose_data_flat)
69
- all_masks.append(mask_flat)
70
-
71
-
72
- # Pad or interpolate all poses and masks to a fixed length (TARGET_NUM_FRAMES)
73
- padded_poses = []
74
- padded_masks = []
75
- for pose_data_flat, mask_flat in zip(all_poses, all_masks):
76
- current_frames = pose_data_flat.shape[0]
77
- if current_frames < TARGET_NUM_FRAMES:
78
- pad_len = TARGET_NUM_FRAMES - current_frames
79
- pose_pad = np.zeros((pad_len, POSE_DIM))
80
- mask_pad = np.zeros((pad_len, POSE_DIM)) # Pad mask with zeros
81
- padded_pose = np.concatenate([pose_data_flat, pose_pad], axis=0)
82
- padded_mask = np.concatenate([mask_flat, mask_pad], axis=0)
83
- else:
84
- indices = np.linspace(0, current_frames - 1, TARGET_NUM_FRAMES).astype(int)
85
- padded_pose = pose_data_flat[indices]
86
- padded_mask = mask_flat[indices]
87
-
88
- padded_poses.append(padded_pose)
89
- padded_masks.append(padded_mask)
90
-
91
- # Stack all padded poses and masks: (Total_Samples * TARGET_NUM_FRAMES, POSE_DIM)
92
- stacked_poses = np.vstack(padded_poses)
93
- stacked_masks = np.vstack(padded_masks)
94
-
95
- # Compute mean and std using the masks to only include valid points
96
- # Weighted average using mask as weights
97
- mean = np.sum(stacked_poses * stacked_masks, axis=0) / (np.sum(stacked_masks, axis=0) + 1e-8) # Add epsilon for stability
98
- # Compute variance, then sqrt for std
99
- variance = np.sum(stacked_masks * (stacked_poses - mean)**2, axis=0) / (np.sum(stacked_masks, axis=0) + 1e-8)
100
- std = np.sqrt(variance)
101
-
102
- std[std == 0] = 1e-8 # Avoid division by zero
103
- return mean, std
104
-
105
- POSE_FOLDER = "/content/drive/MyDrive/pose/words/ase"
106
- CSV_FILE = "/content/drive/MyDrive/pose/words/annotated.csv"
107
- mean_path = "/content/drive/MyDrive/pose/global_mean.npy"
108
- std_path = "/content/drive/MyDrive/pose/global_std.npy"
109
-
110
- if os.path.exists(mean_path) and os.path.exists(std_path):
111
- print("Loading global mean and std from file.")
112
- GLOBAL_MEAN = np.load(mean_path)
113
- GLOBAL_STD = np.load(std_path)
114
- else:
115
- print("Computing global mean and std from dataset.")
116
- GLOBAL_MEAN, GLOBAL_STD = compute_global_mean_std(POSE_FOLDER, CSV_FILE)
117
- # Save the computed mean and std
118
- # Ensure they are not MaskedArrays if the computation somehow produced them
119
- # If compute_global_mean_std is modified to return standard arrays, this is redundant but safe
120
- if isinstance(GLOBAL_MEAN, np.ma.MaskedArray):
121
- GLOBAL_MEAN = GLOBAL_MEAN.data
122
- if isinstance(GLOBAL_STD, np.ma.MaskedArray):
123
- GLOBAL_STD = GLOBAL_STD.data
124
-
125
- np.save(mean_path, GLOBAL_MEAN)
126
- np.save(std_path, GLOBAL_STD)
127
-
128
- GLOBAL_MEAN_T = torch.tensor(GLOBAL_MEAN).float().to(device)
129
- GLOBAL_STD_T = torch.tensor(GLOBAL_STD).float().to(device)
130
-
131
-
132
- class TextToPoseDataset(Dataset):
133
- def __init__(self, csv_file, pose_folder, tokenizer, is_train=True):
134
- self.data = pd.read_csv(csv_file)
135
- self.pose_folder = pose_folder
136
- self.tokenizer = tokenizer
137
- self.is_train = is_train # enable augment only during training
138
-
139
- def __len__(self):
140
- return len(self.data)
141
-
142
- def load_pose_data_and_mask(self, filename):
143
- pose_path = os.path.join(self.pose_folder, filename)
144
- with open(pose_path, "rb") as f:
145
- pose = Pose.read(f.read())
146
-
147
- keypoints = np.array(selected_keypoint_indices)
148
- pose_data = np.squeeze(pose.body.data, axis=1)[:, keypoints, :]
149
- confidence = np.squeeze(pose.body.confidence, axis=1)[:, keypoints]
150
-
151
- return pose_data, confidence
152
-
153
- def apply_augmentations(self, pose_data, confidence):
154
- T = pose_data.shape[0]
155
-
156
- # Temporal warp (resample frame indices with small noise)
157
- if T > TARGET_NUM_FRAMES and np.random.rand() < 0.5:
158
- indices = np.linspace(0, T - 1, TARGET_NUM_FRAMES)
159
- jitter = np.random.uniform(-0.5, 0.5, size=indices.shape)
160
- indices = np.clip(indices + jitter, 0, T - 1).astype(int)
161
- pose_data = pose_data[indices]
162
- confidence = confidence[indices]
163
-
164
- # Mirror (flip X-axis)
165
- if np.random.rand() < 0.3:
166
- pose_data[..., 0] *= -1
167
-
168
- # Jitter (small Gaussian noise)
169
- if np.random.rand() < 0.3:
170
- pose_data += np.random.normal(0, 0.02, pose_data.shape)
171
-
172
- return pose_data, confidence
173
-
174
- def __getitem__(self, idx):
175
- row = self.data.iloc[idx]
176
- filename = row["filename"]
177
- text = row["text"]
178
-
179
- input_ids = self.tokenizer(
180
- text, padding="max_length", truncation=True,
181
- max_length=MAX_TEXT_LEN, return_tensors="pt"
182
- )
183
-
184
- pose_data, confidence = self.load_pose_data_and_mask(filename)
185
-
186
- if self.is_train:
187
- pose_data, confidence = self.apply_augmentations(pose_data, confidence)
188
-
189
- # OLD Flatten
190
- pose_data_flat = pose_data.reshape(pose_data.shape[0], -1)
191
- confidence_flat = np.repeat(confidence, 3, axis=1)
192
- mask_flat = (confidence_flat > 0.5).astype(np.float32)
193
-
194
- # Pad or warp to fixed length
195
- current_frames = pose_data_flat.shape[0]
196
- if current_frames < TARGET_NUM_FRAMES:
197
- pad_len = TARGET_NUM_FRAMES - current_frames
198
- pose_pad = np.zeros((pad_len, POSE_DIM))
199
- mask_pad = np.zeros((pad_len, POSE_DIM))
200
- padded_pose = np.concatenate([pose_data_flat, pose_pad], axis=0)
201
- padded_mask = np.concatenate([mask_flat, mask_pad], axis=0)
202
- else:
203
- indices = np.linspace(0, current_frames - 1, TARGET_NUM_FRAMES).astype(int)
204
- padded_pose = pose_data_flat[indices]
205
- padded_mask = mask_flat[indices]
206
-
207
- # Normalize
208
- normalized_pose = (padded_pose - GLOBAL_MEAN) / GLOBAL_STD
209
-
210
- return (
211
- input_ids.input_ids.squeeze(0),
212
- input_ids.attention_mask.squeeze(0),
213
- torch.tensor(normalized_pose).float(),
214
- torch.tensor(padded_mask).float(),
215
- text
216
- )
217
-
218
-
219
- def collate_fn(batch):
220
- input_ids, attn_masks, poses, masks, words = zip(*batch)
221
- return (
222
- torch.stack(input_ids),
223
- torch.stack(attn_masks),
224
- torch.stack(poses),
225
- torch.stack(masks),
226
- list(words)
227
  )
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.data import Dataset, DataLoader, random_split
4
+ import pandas as pd
5
+ import numpy as np
6
+ from transformers import BertTokenizer
7
+ import os
8
+ from pose_format import Pose
9
+ import matplotlib.pyplot as plt
10
+ from matplotlib import animation
11
+ from fastdtw import fastdtw # Keep this import
12
+ from scipy.spatial.distance import cosine
13
+ from config import MAX_TEXT_LEN, TARGET_NUM_FRAMES, BATCH_SIZE, TEACHER_FORCING_RATIO, SMOOTHING_ENABLED
14
+
15
+ # ===== KEYPOINT SELECTION =====
16
+ selected_keypoint_indices = list(np.r_[0:25, 501:522, 522:543])
17
+ NUM_KEYPOINTS = len(selected_keypoint_indices)
18
+ POSE_DIM = NUM_KEYPOINTS * 3
19
+
20
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+
22
+ # ===== SMOOTHING =====
23
+ def selective_smoothing(preds):
24
+ smoothed = preds.clone()
25
+ body_indices = slice(0, 25 * 3)
26
+ for t in range(1, preds.shape[1] - 1):
27
+ smoothed[:, t, body_indices] = (
28
+ 0.25 * preds[:, t - 1, body_indices] +
29
+ 0.5 * preds[:, t, body_indices] +
30
+ 0.25 * preds[:, t + 1, body_indices]
31
+ )
32
+ return smoothed
33
+
34
+
35
+ # ===== HAND INDEX SETUP =====
36
+ hand_indices = list(range(15 * 3, POSE_DIM)) # hand joints = after body joints in flattened 3D vector
37
+ joint_weights = torch.ones(POSE_DIM).to(device)
38
+ joint_weights[hand_indices] *= 3.0
39
+
40
+ # ===== GLOBAL NORMALIZATION =====
41
+ def compute_global_mean_std(pose_folder, csv_file):
42
+ data = pd.read_csv(csv_file)
43
+ all_poses = []
44
+ # Store valid masks separately to ensure correct normalization
45
+ all_masks = []
46
+
47
+ for filename in data["filename"]:
48
+ pose_path = os.path.join(pose_folder, filename)
49
+ with open(pose_path, "rb") as f:
50
+ pose = Pose.read(f.read())
51
+
52
+ keypoints = np.array(selected_keypoint_indices)
53
+ # (T, 1, K, 3) -> (T, K, 3)
54
+ pose_data = np.squeeze(pose.body.data, axis=1)[:, keypoints, :]
55
+ # (T, 1, K) -> (T, K)
56
+ confidence = np.squeeze(pose.body.confidence, axis=1)[:, keypoints]
57
+
58
+ # Reshape to (T, K*3)
59
+ pose_data_flat = pose_data.reshape(pose_data.shape[0], -1)
60
+ # Reshape confidence to (T, K*3) - repeat confidence for each coordinate
61
+ confidence_flat = np.repeat(confidence, 3, axis=1)
62
+
63
+ # Create a mask based on confidence for the flattened data
64
+ mask_flat = (confidence_flat > 0.5).astype(np.float32)
65
+
66
+ # Append the full pose data and mask for interpolation later
67
+ all_poses.append(pose_data_flat)
68
+ all_masks.append(mask_flat)
69
+
70
+
71
+ # Pad or interpolate all poses and masks to a fixed length (TARGET_NUM_FRAMES)
72
+ padded_poses = []
73
+ padded_masks = []
74
+ for pose_data_flat, mask_flat in zip(all_poses, all_masks):
75
+ current_frames = pose_data_flat.shape[0]
76
+ if current_frames < TARGET_NUM_FRAMES:
77
+ pad_len = TARGET_NUM_FRAMES - current_frames
78
+ pose_pad = np.zeros((pad_len, POSE_DIM))
79
+ mask_pad = np.zeros((pad_len, POSE_DIM)) # Pad mask with zeros
80
+ padded_pose = np.concatenate([pose_data_flat, pose_pad], axis=0)
81
+ padded_mask = np.concatenate([mask_flat, mask_pad], axis=0)
82
+ else:
83
+ indices = np.linspace(0, current_frames - 1, TARGET_NUM_FRAMES).astype(int)
84
+ padded_pose = pose_data_flat[indices]
85
+ padded_mask = mask_flat[indices]
86
+
87
+ padded_poses.append(padded_pose)
88
+ padded_masks.append(padded_mask)
89
+
90
+ # Stack all padded poses and masks: (Total_Samples * TARGET_NUM_FRAMES, POSE_DIM)
91
+ stacked_poses = np.vstack(padded_poses)
92
+ stacked_masks = np.vstack(padded_masks)
93
+
94
+ # Compute mean and std using the masks to only include valid points
95
+ # Weighted average using mask as weights
96
+ mean = np.sum(stacked_poses * stacked_masks, axis=0) / (np.sum(stacked_masks, axis=0) + 1e-8) # Add epsilon for stability
97
+ # Compute variance, then sqrt for std
98
+ variance = np.sum(stacked_masks * (stacked_poses - mean)**2, axis=0) / (np.sum(stacked_masks, axis=0) + 1e-8)
99
+ std = np.sqrt(variance)
100
+
101
+ std[std == 0] = 1e-8 # Avoid division by zero
102
+ return mean, std
103
+
104
+ #POSE_FOLDER = "/content/drive/MyDrive/pose/words/ase"
105
+ CSV_FILE = "annotated.csv"
106
+ mean_path = "global_mean.npy"
107
+ std_path = "global_std.npy"
108
+
109
+ if os.path.exists(mean_path) and os.path.exists(std_path):
110
+ print("Loading global mean and std from file.")
111
+ GLOBAL_MEAN = np.load(mean_path)
112
+ GLOBAL_STD = np.load(std_path)
113
+ else:
114
+ print("Computing global mean and std from dataset.")
115
+ GLOBAL_MEAN, GLOBAL_STD = compute_global_mean_std(POSE_FOLDER, CSV_FILE)
116
+ # Save the computed mean and std
117
+ # Ensure they are not MaskedArrays if the computation somehow produced them
118
+ # If compute_global_mean_std is modified to return standard arrays, this is redundant but safe
119
+ if isinstance(GLOBAL_MEAN, np.ma.MaskedArray):
120
+ GLOBAL_MEAN = GLOBAL_MEAN.data
121
+ if isinstance(GLOBAL_STD, np.ma.MaskedArray):
122
+ GLOBAL_STD = GLOBAL_STD.data
123
+
124
+ np.save(mean_path, GLOBAL_MEAN)
125
+ np.save(std_path, GLOBAL_STD)
126
+
127
+ GLOBAL_MEAN_T = torch.tensor(GLOBAL_MEAN).float().to(device)
128
+ GLOBAL_STD_T = torch.tensor(GLOBAL_STD).float().to(device)
129
+
130
+
131
+ class TextToPoseDataset(Dataset):
132
+ def __init__(self, csv_file, pose_folder, tokenizer, is_train=True):
133
+ self.data = pd.read_csv(csv_file)
134
+ self.pose_folder = pose_folder
135
+ self.tokenizer = tokenizer
136
+ self.is_train = is_train # enable augment only during training
137
+
138
+ def __len__(self):
139
+ return len(self.data)
140
+
141
+ def load_pose_data_and_mask(self, filename):
142
+ pose_path = os.path.join(self.pose_folder, filename)
143
+ with open(pose_path, "rb") as f:
144
+ pose = Pose.read(f.read())
145
+
146
+ keypoints = np.array(selected_keypoint_indices)
147
+ pose_data = np.squeeze(pose.body.data, axis=1)[:, keypoints, :]
148
+ confidence = np.squeeze(pose.body.confidence, axis=1)[:, keypoints]
149
+
150
+ return pose_data, confidence
151
+
152
+ def apply_augmentations(self, pose_data, confidence):
153
+ T = pose_data.shape[0]
154
+
155
+ # Temporal warp (resample frame indices with small noise)
156
+ if T > TARGET_NUM_FRAMES and np.random.rand() < 0.5:
157
+ indices = np.linspace(0, T - 1, TARGET_NUM_FRAMES)
158
+ jitter = np.random.uniform(-0.5, 0.5, size=indices.shape)
159
+ indices = np.clip(indices + jitter, 0, T - 1).astype(int)
160
+ pose_data = pose_data[indices]
161
+ confidence = confidence[indices]
162
+
163
+ # Mirror (flip X-axis)
164
+ if np.random.rand() < 0.3:
165
+ pose_data[..., 0] *= -1
166
+
167
+ # Jitter (small Gaussian noise)
168
+ if np.random.rand() < 0.3:
169
+ pose_data += np.random.normal(0, 0.02, pose_data.shape)
170
+
171
+ return pose_data, confidence
172
+
173
+ def __getitem__(self, idx):
174
+ row = self.data.iloc[idx]
175
+ filename = row["filename"]
176
+ text = row["text"]
177
+
178
+ input_ids = self.tokenizer(
179
+ text, padding="max_length", truncation=True,
180
+ max_length=MAX_TEXT_LEN, return_tensors="pt"
181
+ )
182
+
183
+ pose_data, confidence = self.load_pose_data_and_mask(filename)
184
+
185
+ if self.is_train:
186
+ pose_data, confidence = self.apply_augmentations(pose_data, confidence)
187
+
188
+ # OLD Flatten
189
+ pose_data_flat = pose_data.reshape(pose_data.shape[0], -1)
190
+ confidence_flat = np.repeat(confidence, 3, axis=1)
191
+ mask_flat = (confidence_flat > 0.5).astype(np.float32)
192
+
193
+ # Pad or warp to fixed length
194
+ current_frames = pose_data_flat.shape[0]
195
+ if current_frames < TARGET_NUM_FRAMES:
196
+ pad_len = TARGET_NUM_FRAMES - current_frames
197
+ pose_pad = np.zeros((pad_len, POSE_DIM))
198
+ mask_pad = np.zeros((pad_len, POSE_DIM))
199
+ padded_pose = np.concatenate([pose_data_flat, pose_pad], axis=0)
200
+ padded_mask = np.concatenate([mask_flat, mask_pad], axis=0)
201
+ else:
202
+ indices = np.linspace(0, current_frames - 1, TARGET_NUM_FRAMES).astype(int)
203
+ padded_pose = pose_data_flat[indices]
204
+ padded_mask = mask_flat[indices]
205
+
206
+ # Normalize
207
+ normalized_pose = (padded_pose - GLOBAL_MEAN) / GLOBAL_STD
208
+
209
+ return (
210
+ input_ids.input_ids.squeeze(0),
211
+ input_ids.attention_mask.squeeze(0),
212
+ torch.tensor(normalized_pose).float(),
213
+ torch.tensor(padded_mask).float(),
214
+ text
215
+ )
216
+
217
+
218
+ def collate_fn(batch):
219
+ input_ids, attn_masks, poses, masks, words = zip(*batch)
220
+ return (
221
+ torch.stack(input_ids),
222
+ torch.stack(attn_masks),
223
+ torch.stack(poses),
224
+ torch.stack(masks),
225
+ list(words)
 
226
  )