import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd
import numpy as np
from transformers import BertTokenizer, BertModel
import os
from pose_format import Pose
import matplotlib.pyplot as plt
from matplotlib import animation
from fastdtw import fastdtw
from scipy.spatial.distance import cosine

from config import MAX_TEXT_LEN, TARGET_NUM_FRAMES, BATCH_SIZE, TEACHER_FORCING_RATIO, SMOOTHING_ENABLED

tokenizer = BertTokenizer.from_pretrained("indobenchmark/indobert-base-p2")

# ===== KEYPOINT SELECTION =====
# Indices 0-24 plus two 21-point ranges (presumably the hand blocks in the
# full holistic keypoint layout of the source .pose files).
selected_keypoint_indices = list(np.r_[0:25, 501:522, 522:543])
NUM_KEYPOINTS = len(selected_keypoint_indices)
POSE_DIM = NUM_KEYPOINTS * 3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class TextToPoseSeq2Seq(nn.Module):
    def __init__(self, vocab_size, hidden_dim=512, pose_dim=POSE_DIM, max_len=MAX_TEXT_LEN, target_len=TARGET_NUM_FRAMES):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.target_len = target_len
        self.pose_dim = pose_dim

        # === BERT Encoder ===
        self.encoder = BertModel.from_pretrained("indobenchmark/indobert-base-p2")

        # === GRU Decoder ===
        self.input_proj = nn.Linear(pose_dim, hidden_dim)
        bert_hidden = self.encoder.config.hidden_size
        self.gru_cell = nn.GRUCell(hidden_dim + bert_hidden, hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc_pose = nn.Linear(hidden_dim, pose_dim)
        self.fc_conf = nn.Linear(hidden_dim, NUM_KEYPOINTS)
        self.output_scale = 1.0

    def forward(self, input_ids, attention_mask=None, target_pose=None, teacher_forcing_ratio=TEACHER_FORCING_RATIO):
        B = input_ids.size(0)
        pose_outputs = []
        conf_outputs = []
        input_pose = torch.zeros(B, self.pose_dim, device=input_ids.device)

        # === BERT Encoding ===
        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        context = encoder_outputs.last_hidden_state[:, 0, :]  # [CLS] token

        h = torch.zeros(B, self.hidden_dim, device=input_ids.device)

        for t in range(self.target_len):
            use_teacher = self.training and target_pose is not None and torch.rand(1).item() < teacher_forcing_ratio
            if use_teacher and t > 0:
                input_pose = target_pose[:, t - 1, :]
            elif t > 0:
                input_pose = pose_outputs[-1].squeeze(1).detach()

            pose_emb = self.input_proj(input_pose)
            gru_input = torch.cat([pose_emb, context], dim=-1)
            h = self.gru_cell(gru_input, h)
            h = self.dropout(h)

            pred_pose = self.fc_pose(h) * self.output_scale
            pred_conf = torch.sigmoid(self.fc_conf(h))

            pose_outputs.append(pred_pose.unsqueeze(1))
            conf_outputs.append(pred_conf.unsqueeze(1))

            input_pose = pred_pose.detach()

        return torch.cat(pose_outputs, dim=1), torch.cat(conf_outputs, dim=1)
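

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the training pipeline):
# how a trained TextToPoseSeq2Seq instance can be driven at inference time.
# `example_generate` is a hypothetical helper name; it assumes the model has
# already been moved to `device` and only relies on the tokenizer, config
# values, and device defined above.
# ---------------------------------------------------------------------------
def example_generate(model, sentence):
    """Tokenize one sentence and return predicted poses of shape
    (1, TARGET_NUM_FRAMES, POSE_DIM) and confidences of shape
    (1, TARGET_NUM_FRAMES, NUM_KEYPOINTS)."""
    model.eval()
    enc = tokenizer(
        sentence,
        padding="max_length",
        truncation=True,
        max_length=MAX_TEXT_LEN,
        return_tensors="pt",
    )
    with torch.no_grad():
        poses, confs = model(
            enc["input_ids"].to(device),
            attention_mask=enc["attention_mask"].to(device),
        )
    return poses, confs
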
# ===== METRICS =====
def mpjpe(pred, target, mask=None):
    # Shapes: (B, T, POSE_DIM)
    pred = pred.view(pred.size(0), pred.size(1), NUM_KEYPOINTS, 3)
    target = target.view(target.size(0), target.size(1), NUM_KEYPOINTS, 3)
    error = torch.norm(pred - target, dim=3)  # (B, T, NUM_KEYPOINTS)
    if mask is not None:
        mask = mask.view(pred.size(0), pred.size(1), NUM_KEYPOINTS)  # (B, T, K)
        masked_error = error * mask
        return masked_error.sum() / (mask.sum() + 1e-8)
    else:
        return error.mean()
def per_joint_mpjpe(pred, target, mask=None):
    pred = pred.view(-1, NUM_KEYPOINTS, 3)
    target = target.view(-1, NUM_KEYPOINTS, 3)
    error = torch.norm(pred - target, dim=2)  # (B*T, K)
    if mask is not None:
        mask = mask.view(-1, NUM_KEYPOINTS)  # (B*T, K)
        masked_error = error * mask
        joint_means = masked_error.sum(dim=0) / (mask.sum(dim=0) + 1e-8)
        return joint_means.detach().cpu().numpy()
    else:
        return error.mean(dim=0).detach().cpu().numpy()
def pose_velocity(pose_seq):
    # pose_seq shape is assumed to be (B, T, POSE_DIM)
    # Difference along the time dimension (dim=1)
    diffs = pose_seq[:, 1:, :] - pose_seq[:, :-1, :]
    # Reshape to (B, T-1, NUM_KEYPOINTS, 3) to get per-joint velocity
    diffs = diffs.view(diffs.size(0), diffs.size(1), NUM_KEYPOINTS, 3)
    # Norm over the coordinate dimension (dim=3), then mean over batch, time, and joints
    return torch.norm(diffs, dim=3).mean().item()
def cosine_similarity(pred, target):
    # pred and target are (B, T, POSE_DIM).
    # Cosine similarity is normally computed per sample or per timestep;
    # here the whole batch is flattened into a single vector for simplicity,
    # which yields one scalar. See cosine_similarity_per_frame below for a
    # per-timestep variant.
    pred = pred.detach().view(-1, POSE_DIM).cpu().numpy()
    target = target.detach().view(-1, POSE_DIM).cpu().numpy()
    if np.linalg.norm(pred) == 0 or np.linalg.norm(target) == 0:
        return 0.0  # Cosine distance is undefined for zero vectors
    return 1 - cosine(pred.flatten(), target.flatten())
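

# Per-timestep variant, as suggested in the comment above: compute cosine
# similarity frame by frame and average. This is a sketch (the name and the
# mean reduction are my choice), built on torch.nn.functional.cosine_similarity.
def cosine_similarity_per_frame(pred, target):
    # pred and target are (B, T, POSE_DIM); similarity is computed per frame
    # over the flattened keypoint coordinates, then averaged to a scalar.
    sims = torch.nn.functional.cosine_similarity(
        pred.detach().view(-1, POSE_DIM),
        target.detach().view(-1, POSE_DIM),
        dim=1,
        eps=1e-8,
    )
    return sims.mean().item()
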
def dtw_distance(pred, target):
    # pred and target are (B, T, POSE_DIM).
    # DTW is computed sequence-wise on (T, POSE_DIM); as a cheap proxy, only
    # the first sample of the batch is used here.
    pred_seq = pred[0].detach().view(-1, POSE_DIM).cpu().numpy()
    target_seq = target[0].detach().view(-1, POSE_DIM).cpu().numpy()
    # Euclidean distance between frames as the DTW local cost
    distance, _ = fastdtw(pred_seq, target_seq, dist=lambda a, b: np.linalg.norm(a - b))
    return distance
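

# Batch-level variant of dtw_distance (a sketch, not part of the original
# metrics): run fastdtw on every sequence in the batch and report the mean
# alignment cost instead of only the first sample's.
def dtw_distance_batch(pred, target):
    pred_np = pred.detach().cpu().numpy()      # (B, T, POSE_DIM)
    target_np = target.detach().cpu().numpy()  # (B, T, POSE_DIM)
    distances = []
    for pred_seq, target_seq in zip(pred_np, target_np):
        d, _ = fastdtw(pred_seq, target_seq, dist=lambda a, b: np.linalg.norm(a - b))
        distances.append(d)
    return float(np.mean(distances))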