import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import BertModel, BertTokenizer
from pose_format import Pose
import matplotlib.pyplot as plt
from matplotlib import animation
from fastdtw import fastdtw  # Keep this import
from scipy.spatial.distance import cosine

from config import (
    MAX_TEXT_LEN,
    TARGET_NUM_FRAMES,
    BATCH_SIZE,
    TEACHER_FORCING_RATIO,
    SMOOTHING_ENABLED,
)

# Pretrained checkpoint shared by the tokenizer and the model's encoder.
BERT_MODEL_NAME = "indobenchmark/indobert-base-p2"

tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)

# ===== KEYPOINT SELECTION =====
# Indices 0-24, 501-521 and 522-542 of the source keypoint list (67 in total).
# NOTE(review): presumably body + left hand + right hand of a holistic pose
# layout — confirm against the pose_format header used in preprocessing.
selected_keypoint_indices = list(np.r_[0:25, 501:522, 522:543])
NUM_KEYPOINTS = len(selected_keypoint_indices)
POSE_DIM = NUM_KEYPOINTS * 3  # each keypoint is flattened as 3 values

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class TextToPoseSeq2Seq(nn.Module):
    """BERT encoder + GRUCell decoder mapping token ids to a pose sequence.

    The text is encoded once with BERT and the position-0 ([CLS]) hidden state
    is used as a fixed context vector. A GRUCell then decodes ``target_len``
    frames autoregressively; each step is conditioned on the previous frame
    (ground truth under teacher forcing, otherwise the model's own detached
    prediction) concatenated with the context.

    Args:
        vocab_size: Unused (BERT has its own vocabulary); kept for interface
            compatibility.
        hidden_dim: GRU hidden size.
        pose_dim: Size of one flattened pose frame.
        max_len: Unused; kept for interface compatibility.
        target_len: Number of frames to decode.
    """

    def __init__(self, vocab_size, hidden_dim=512, pose_dim=POSE_DIM,
                 max_len=MAX_TEXT_LEN, target_len=TARGET_NUM_FRAMES):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.target_len = target_len
        self.pose_dim = pose_dim

        # === BERT Encoder ===
        self.encoder = BertModel.from_pretrained(BERT_MODEL_NAME)

        # === GRU Decoder ===
        self.input_proj = nn.Linear(pose_dim, hidden_dim)
        bert_hidden = self.encoder.config.hidden_size
        self.gru_cell = nn.GRUCell(hidden_dim + bert_hidden, hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc_pose = nn.Linear(hidden_dim, pose_dim)
        # One confidence per keypoint. NOTE(review): sized by the module-level
        # NUM_KEYPOINTS, not derived from ``pose_dim``.
        self.fc_conf = nn.Linear(hidden_dim, NUM_KEYPOINTS)
        self.output_scale = 1.0

    def forward(self, input_ids, attention_mask=None, target_pose=None,
                teacher_forcing_ratio=TEACHER_FORCING_RATIO):
        """Decode a pose sequence from token ids.

        Args:
            input_ids: Token-id tensor; dim 0 is the batch.
            attention_mask: Optional attention mask passed through to BERT.
            target_pose: Optional ground truth for teacher forcing, indexed as
                ``target_pose[:, t - 1, :]`` (so it needs at least
                ``target_len - 1`` frames of size ``pose_dim``).
            teacher_forcing_ratio: Per-step probability of feeding the ground
                truth previous frame; only used in training mode.

        Returns:
            Tuple ``(poses, confidences)`` of shapes
            ``(B, target_len, pose_dim)`` and ``(B, target_len, NUM_KEYPOINTS)``;
            confidences are sigmoid outputs in [0, 1].
        """
        batch_size = input_ids.size(0)
        ids_device = input_ids.device
        # Match the model's parameter dtype so half/bf16 models also work.
        param_dtype = self.input_proj.weight.dtype

        # === BERT Encoding ===
        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        context = encoder_outputs.last_hidden_state[:, 0, :]  # [CLS] token

        h = torch.zeros(batch_size, self.hidden_dim, device=ids_device, dtype=param_dtype)
        # The first decoding step is fed an all-zero "start" pose.
        input_pose = torch.zeros(batch_size, self.pose_dim, device=ids_device, dtype=param_dtype)

        pose_outputs = []
        conf_outputs = []
        for t in range(self.target_len):
            # Evaluated every step (including t == 0), as before.
            use_teacher = (
                self.training
                and target_pose is not None
                and torch.rand(1).item() < teacher_forcing_ratio
            )
            if t > 0:
                if use_teacher:
                    input_pose = target_pose[:, t - 1, :]
                else:
                    # No gradient flows back through the model's own previous output.
                    input_pose = pose_outputs[-1].squeeze(1).detach()

            pose_emb = self.input_proj(input_pose)
            gru_input = torch.cat([pose_emb, context], dim=-1)
            h = self.gru_cell(gru_input, h)

            # Dropout on the output path only: the recurrent state `h` carried to
            # the next step is left intact, so noise does not compound per frame.
            h_out = self.dropout(h)
            pred_pose = self.fc_pose(h_out) * self.output_scale
            pred_conf = torch.sigmoid(self.fc_conf(h_out))

            pose_outputs.append(pred_pose.unsqueeze(1))
            conf_outputs.append(pred_conf.unsqueeze(1))

        return torch.cat(pose_outputs, dim=1), torch.cat(conf_outputs, dim=1)


# ===== METRICS =====
def mpjpe(pred, target, mask=None):
    """Mean per-joint position error (Euclidean norm over each keypoint's 3 values).

    Args:
        pred: Tensor reshapeable to (B, T, NUM_KEYPOINTS, 3), e.g. (B, T, POSE_DIM).
        target: Tensor with the same layout as ``pred``.
        mask: Optional per-joint weights reshapeable to (B, T, NUM_KEYPOINTS);
            when given, the result is the mask-weighted mean error.

    Returns:
        0-dim tensor (gradients are preserved, so it can be used as a loss).
    """
    batch, frames = pred.size(0), pred.size(1)
    # reshape (not view): inputs may be non-contiguous.
    pred = pred.reshape(batch, frames, NUM_KEYPOINTS, 3)
    target = target.reshape(target.size(0), target.size(1), NUM_KEYPOINTS, 3)
    error = torch.norm(pred - target, dim=3)  # (B, T, NUM_KEYPOINTS)
    if mask is None:
        return error.mean()
    mask = mask.reshape(batch, frames, NUM_KEYPOINTS).to(device=error.device, dtype=error.dtype)
    return (error * mask).sum() / (mask.sum() + 1e-8)


def per_joint_mpjpe(pred, target, mask=None):
    """Position error averaged over all frames, reported separately per keypoint.

    Args:
        pred: Tensor reshapeable to (N, NUM_KEYPOINTS, 3).
        target: Tensor with the same layout as ``pred``.
        mask: Optional per-joint weights reshapeable to (N, NUM_KEYPOINTS).

    Returns:
        NumPy array of length NUM_KEYPOINTS.
    """
    pred = pred.reshape(-1, NUM_KEYPOINTS, 3)
    target = target.reshape(-1, NUM_KEYPOINTS, 3)
    error = torch.norm(pred - target, dim=2)  # (N, NUM_KEYPOINTS)
    if mask is None:
        joint_means = error.mean(dim=0)
    else:
        mask = mask.reshape(-1, NUM_KEYPOINTS).to(device=error.device, dtype=error.dtype)
        joint_means = (error * mask).sum(dim=0) / (mask.sum(dim=0) + 1e-8)
    # detach(): .numpy() raises on tensors that require grad.
    return joint_means.detach().cpu().numpy()


def pose_velocity(pose_seq):
    """Mean per-keypoint displacement between consecutive frames.

    Args:
        pose_seq: Tensor of shape (B, T, POSE_DIM).

    Returns:
        Python float; 0.0 when there are no consecutive frame pairs
        (empty batch or T < 2) instead of the NaN an empty mean would give.
    """
    if pose_seq.size(0) == 0 or pose_seq.size(1) < 2:
        return 0.0
    # Difference along the time dimension.
    diffs = pose_seq[:, 1:, :] - pose_seq[:, :-1, :]
    diffs = diffs.reshape(diffs.size(0), diffs.size(1), NUM_KEYPOINTS, 3)
    # Norm over the 3 values per keypoint, then mean over batch, time and keypoints.
    return torch.norm(diffs, dim=3).mean().item()


def cosine_similarity(pred, target):
    """Cosine similarity of ``pred`` and ``target``, each flattened to one vector.

    The inputs (presumably (B, T, POSE_DIM)) are flattened across batch, time
    and coordinates. The result is therefore a single global score, not an
    average of per-frame or per-sample similarities.

    Returns:
        Float in [-1, 1]; 0.0 if either input has zero norm.
    """
    pred_vec = pred.detach().cpu().numpy().ravel()
    target_vec = target.detach().cpu().numpy().ravel()
    if np.linalg.norm(pred_vec) == 0 or np.linalg.norm(target_vec) == 0:
        return 0.0  # cosine is undefined for zero vectors
    return 1 - cosine(pred_vec, target_vec)


def dtw_distance(pred, target):
    """Approximate DTW distance (fastdtw, Euclidean per-frame cost).

    Only the first sample of the batch (``pred[0]`` vs ``target[0]``) is
    compared; the other samples are ignored.

    Args:
        pred: Tensor whose first sample reshapes to (T, POSE_DIM).
        target: Tensor whose first sample reshapes to (T', POSE_DIM).

    Returns:
        The accumulated DTW distance returned by fastdtw.
    """
    pred_seq = pred[0].detach().reshape(-1, POSE_DIM).cpu().numpy()
    target_seq = target[0].detach().reshape(-1, POSE_DIM).cpu().numpy()
    distance, _ = fastdtw(pred_seq, target_seq, dist=lambda a, b: np.linalg.norm(a - b))
    return distance