import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd
import numpy as np
from transformers import BertTokenizer
import os
from pose_format import Pose
import matplotlib.pyplot as plt
from matplotlib import animation
from fastdtw import fastdtw  # approximate dynamic time warping, used for evaluation elsewhere
from scipy.spatial.distance import cosine
from config import MAX_TEXT_LEN, TARGET_NUM_FRAMES, BATCH_SIZE, TEACHER_FORCING_RATIO, SMOOTHING_ENABLED

# ===== KEYPOINT SELECTION =====
# From the 543 holistic landmarks: the first 25 pose landmarks (upper body),
# then the 21 left-hand and 21 right-hand landmarks.
selected_keypoint_indices = list(np.r_[0:25, 501:522, 522:543])
NUM_KEYPOINTS = len(selected_keypoint_indices)  # 67
POSE_DIM = NUM_KEYPOINTS * 3  # x, y, z per keypoint -> 201
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ===== SMOOTHING =====
def selective_smoothing(preds):
    """Temporally smooth only the body joints with a 1-2-1 kernel,
    leaving the hand joints untouched to preserve fine detail."""
    smoothed = preds.clone()
    body_indices = slice(0, 25 * 3)
    for t in range(1, preds.shape[1] - 1):
        smoothed[:, t, body_indices] = (
            0.25 * preds[:, t - 1, body_indices] +
            0.5 * preds[:, t, body_indices] +
            0.25 * preds[:, t + 1, body_indices]
        )
    return smoothed
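
# Illustrative usage sketch (assumed, not part of this file's pipeline;
# `model`, `input_ids`, and `attn_mask` are hypothetical): smoothing is meant
# to run on decoded pose sequences of shape (B, T, POSE_DIM), gated by the
# SMOOTHING_ENABLED flag imported from config.
#
#   preds = model(input_ids, attn_mask)  # (B, T, POSE_DIM)
#   if SMOOTHING_ENABLED:
#       preds = selective_smoothing(preds)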

# ===== HAND INDEX SETUP =====
# Hand joints come after the 25 body joints in the flattened vector,
# so they start at offset 25 * 3
hand_indices = list(range(25 * 3, POSE_DIM))
joint_weights = torch.ones(POSE_DIM).to(device)
joint_weights[hand_indices] *= 3.0  # upweight hands in the loss
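
# A minimal sketch of how joint_weights and a validity mask could combine into
# a masked, joint-weighted MSE. Illustrative only; the actual training loss may
# differ. preds/targets are (B, T, POSE_DIM), mask is 0/1 of the same shape.
def weighted_masked_mse(preds, targets, mask):
    sq_err = joint_weights * mask * (preds - targets) ** 2
    return sq_err.sum() / (mask.sum() + 1e-8)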

# ===== GLOBAL NORMALIZATION =====
def compute_global_mean_std(pose_folder, csv_file):
    """Compute per-dimension mean/std over all valid (confidence > 0.5) points."""
    data = pd.read_csv(csv_file)
    all_poses = []
    # Store valid masks separately to ensure correct normalization
    all_masks = []
    for filename in data["filename"]:
        pose_path = os.path.join(pose_folder, filename)
        with open(pose_path, "rb") as f:
            pose = Pose.read(f.read())
        keypoints = np.array(selected_keypoint_indices)
        # (T, 1, K, 3) -> (T, K, 3); fill masked entries with 0 so plain
        # ndarrays flow through the rest of the pipeline
        pose_data = np.ma.filled(np.squeeze(pose.body.data, axis=1), 0)[:, keypoints, :]
        # (T, 1, K) -> (T, K)
        confidence = np.squeeze(pose.body.confidence, axis=1)[:, keypoints]
        # Flatten to (T, K*3)
        pose_data_flat = pose_data.reshape(pose_data.shape[0], -1)
        # Repeat confidence for each coordinate: (T, K) -> (T, K*3)
        confidence_flat = np.repeat(confidence, 3, axis=1)
        # Binary validity mask for the flattened data
        mask_flat = (confidence_flat > 0.5).astype(np.float32)
        all_poses.append(pose_data_flat)
        all_masks.append(mask_flat)
    # Pad or subsample every sequence (and its mask) to TARGET_NUM_FRAMES
    padded_poses = []
    padded_masks = []
    for pose_data_flat, mask_flat in zip(all_poses, all_masks):
        current_frames = pose_data_flat.shape[0]
        if current_frames < TARGET_NUM_FRAMES:
            pad_len = TARGET_NUM_FRAMES - current_frames
            pose_pad = np.zeros((pad_len, POSE_DIM))
            mask_pad = np.zeros((pad_len, POSE_DIM))  # padded frames are invalid
            padded_pose = np.concatenate([pose_data_flat, pose_pad], axis=0)
            padded_mask = np.concatenate([mask_flat, mask_pad], axis=0)
        else:
            indices = np.linspace(0, current_frames - 1, TARGET_NUM_FRAMES).astype(int)
            padded_pose = pose_data_flat[indices]
            padded_mask = mask_flat[indices]
        padded_poses.append(padded_pose)
        padded_masks.append(padded_mask)
    # Stack: (num_samples * TARGET_NUM_FRAMES, POSE_DIM)
    stacked_poses = np.vstack(padded_poses)
    stacked_masks = np.vstack(padded_masks)
    # Masked mean: weight each point by its validity mask
    mean = np.sum(stacked_poses * stacked_masks, axis=0) / (np.sum(stacked_masks, axis=0) + 1e-8)
    # Masked variance, then sqrt for std
    variance = np.sum(stacked_masks * (stacked_poses - mean) ** 2, axis=0) / (np.sum(stacked_masks, axis=0) + 1e-8)
    std = np.sqrt(variance)
    std[std == 0] = 1e-8  # avoid division by zero when normalizing
    return mean, std

POSE_FOLDER = "/content/drive/MyDrive/pose/words/ase"
CSV_FILE = "annotated.csv"
mean_path = "global_mean.npy"
std_path = "global_std.npy"
if os.path.exists(mean_path) and os.path.exists(std_path):
    print("Loading global mean and std from file.")
    GLOBAL_MEAN = np.load(mean_path)
    GLOBAL_STD = np.load(std_path)
else:
    print("Computing global mean and std from dataset.")
    GLOBAL_MEAN, GLOBAL_STD = compute_global_mean_std(POSE_FOLDER, CSV_FILE)
    # Guard against MaskedArrays leaking out of the computation; redundant
    # when plain ndarrays are returned, but harmless
    if isinstance(GLOBAL_MEAN, np.ma.MaskedArray):
        GLOBAL_MEAN = GLOBAL_MEAN.data
    if isinstance(GLOBAL_STD, np.ma.MaskedArray):
        GLOBAL_STD = GLOBAL_STD.data
    # Cache the statistics so later runs can load them directly
    np.save(mean_path, GLOBAL_MEAN)
    np.save(std_path, GLOBAL_STD)
GLOBAL_MEAN_T = torch.tensor(GLOBAL_MEAN).float().to(device)
GLOBAL_STD_T = torch.tensor(GLOBAL_STD).float().to(device)
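
# Illustrative inverse of the normalization applied in __getitem__ below, for
# mapping model outputs back to raw keypoint coordinates. A sketch, not
# referenced elsewhere in this file.
def denormalize_pose(normalized):
    # (..., POSE_DIM) in normalized space -> raw coordinates
    return normalized * GLOBAL_STD_T + GLOBAL_MEAN_T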

class TextToPoseDataset(Dataset):
    def __init__(self, csv_file, pose_folder, tokenizer, is_train=True):
        self.data = pd.read_csv(csv_file)
        self.pose_folder = pose_folder
        self.tokenizer = tokenizer
        self.is_train = is_train  # enable augmentation only during training

    def __len__(self):
        return len(self.data)

    def load_pose_data_and_mask(self, filename):
        pose_path = os.path.join(self.pose_folder, filename)
        with open(pose_path, "rb") as f:
            pose = Pose.read(f.read())
        keypoints = np.array(selected_keypoint_indices)
        # (T, 1, K, 3) -> (T, K, 3); fill masked entries with 0
        pose_data = np.ma.filled(np.squeeze(pose.body.data, axis=1), 0)[:, keypoints, :]
        # (T, 1, K) -> (T, K)
        confidence = np.squeeze(pose.body.confidence, axis=1)[:, keypoints]
        return pose_data, confidence
    def apply_augmentations(self, pose_data, confidence):
        T = pose_data.shape[0]
        # Temporal warp (resample frame indices with small noise)
        if T > TARGET_NUM_FRAMES and np.random.rand() < 0.5:
            indices = np.linspace(0, T - 1, TARGET_NUM_FRAMES)
            jitter = np.random.uniform(-0.5, 0.5, size=indices.shape)
            indices = np.clip(indices + jitter, 0, T - 1).astype(int)
            pose_data = pose_data[indices]
            confidence = confidence[indices]
        # Mirror (flip X-axis)
        if np.random.rand() < 0.3:
            pose_data[..., 0] *= -1
        # Jitter (small Gaussian noise)
        if np.random.rand() < 0.3:
            pose_data += np.random.normal(0, 0.02, pose_data.shape)
        return pose_data, confidence
    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        filename = row["filename"]
        text = row["text"]
        input_ids = self.tokenizer(
            text, padding="max_length", truncation=True,
            max_length=MAX_TEXT_LEN, return_tensors="pt"
        )
        pose_data, confidence = self.load_pose_data_and_mask(filename)
        if self.is_train:
            pose_data, confidence = self.apply_augmentations(pose_data, confidence)
        # Flatten to (T, K*3) and build the validity mask
        pose_data_flat = pose_data.reshape(pose_data.shape[0], -1)
        confidence_flat = np.repeat(confidence, 3, axis=1)
        mask_flat = (confidence_flat > 0.5).astype(np.float32)
        # Pad or subsample to a fixed length
        current_frames = pose_data_flat.shape[0]
        if current_frames < TARGET_NUM_FRAMES:
            pad_len = TARGET_NUM_FRAMES - current_frames
            pose_pad = np.zeros((pad_len, POSE_DIM))
            mask_pad = np.zeros((pad_len, POSE_DIM))
            padded_pose = np.concatenate([pose_data_flat, pose_pad], axis=0)
            padded_mask = np.concatenate([mask_flat, mask_pad], axis=0)
        else:
            indices = np.linspace(0, current_frames - 1, TARGET_NUM_FRAMES).astype(int)
            padded_pose = pose_data_flat[indices]
            padded_mask = mask_flat[indices]
        # Normalize with the global statistics
        normalized_pose = (padded_pose - GLOBAL_MEAN) / GLOBAL_STD
        return (
            input_ids.input_ids.squeeze(0),
            input_ids.attention_mask.squeeze(0),
            torch.tensor(normalized_pose).float(),
            torch.tensor(padded_mask).float(),
            text
        )

def collate_fn(batch):
    input_ids, attn_masks, poses, masks, words = zip(*batch)
    return (
        torch.stack(input_ids),
        torch.stack(attn_masks),
        torch.stack(poses),
        torch.stack(masks),
        list(words)
    )
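
# Minimal smoke test wiring the dataset and collate_fn together. A sketch:
# the split ratio and tokenizer checkpoint are assumptions, not taken from
# this file.
if __name__ == "__main__":
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    dataset = TextToPoseDataset(CSV_FILE, POSE_FOLDER, tokenizer, is_train=True)
    train_len = int(0.9 * len(dataset))
    train_set, _val_set = random_split(dataset, [train_len, len(dataset) - train_len])
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
    input_ids, attn_masks, poses, masks, words = next(iter(train_loader))
    # Expect (B, MAX_TEXT_LEN), (B, TARGET_NUM_FRAMES, POSE_DIM), ...
    print(input_ids.shape, poses.shape, masks.shape)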