import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd
import numpy as np
from transformers import BertTokenizer, BertModel
import os
from pose_format import Pose
import matplotlib.pyplot as plt
from matplotlib import animation
from fastdtw import fastdtw  # dynamic time warping for the DTW metric below
from scipy.spatial.distance import cosine
from config import MAX_TEXT_LEN, TARGET_NUM_FRAMES, BATCH_SIZE, TEACHER_FORCING_RATIO, SMOOTHING_ENABLED

tokenizer = BertTokenizer.from_pretrained("indobenchmark/indobert-base-p2")

# ===== KEYPOINT SELECTION =====
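# Indices 0:25 are assumed to cover the upper-body pose landmarks, and
# 501:522 / 522:543 the left/right hand landmarks of a 543-point holistic
# skeleton (consistent with the MediaPipe Holistic layout: 33 pose + 468 face
# + 21 + 21 hand points).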
selected_keypoint_indices = list(np.r_[0:25, 501:522, 522:543])
NUM_KEYPOINTS = len(selected_keypoint_indices)
POSE_DIM = NUM_KEYPOINTS * 3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class TextToPoseSeq2Seq(nn.Module):
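    """Seq2seq text-to-pose model: IndoBERT encodes the input sentence and its
    [CLS] embedding conditions a GRUCell decoder that autoregressively emits
    ``target_len`` pose frames (``pose_dim`` coordinates each) together with
    per-keypoint confidence scores, with optional teacher forcing in training."""
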
    def __init__(self, vocab_size, hidden_dim=512, pose_dim=POSE_DIM, max_len=MAX_TEXT_LEN, target_len=TARGET_NUM_FRAMES):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.target_len = target_len
        self.pose_dim = pose_dim

        # === BERT Encoder ===
        self.encoder = BertModel.from_pretrained("indobenchmark/indobert-base-p2")

        # === GRU Decoder ===
        self.input_proj = nn.Linear(pose_dim, hidden_dim)
        bert_hidden = self.encoder.config.hidden_size
        self.gru_cell = nn.GRUCell(hidden_dim + bert_hidden, hidden_dim)
        self.dropout = nn.Dropout(0.3)

        self.fc_pose = nn.Linear(hidden_dim, pose_dim)
        self.fc_conf = nn.Linear(hidden_dim, NUM_KEYPOINTS)
        self.output_scale = 1.0

    def forward(self, input_ids, attention_mask=None, target_pose=None, teacher_forcing_ratio=TEACHER_FORCING_RATIO):
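        """Generate a pose sequence autoregressively.

        Args:
            input_ids: token ids of shape (B, MAX_TEXT_LEN).
            attention_mask: matching attention mask for the BERT encoder.
            target_pose: optional ground-truth frames (B, target_len, pose_dim),
                used for teacher forcing while the model is in training mode.

        Returns:
            Tuple of predicted poses (B, target_len, pose_dim) and
            confidences (B, target_len, NUM_KEYPOINTS).
        """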
        B = input_ids.size(0)
        pose_outputs = []
        conf_outputs = []
        input_pose = torch.zeros(B, self.pose_dim).to(input_ids.device)

        # === BERT Encoding ===
        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        context = encoder_outputs.last_hidden_state[:, 0, :]  # [CLS] token

        h = torch.zeros(B, self.hidden_dim).to(input_ids.device)

        for t in range(self.target_len):
            use_teacher = self.training and target_pose is not None and torch.rand(1).item() < teacher_forcing_ratio
            if use_teacher and t > 0:
                input_pose = target_pose[:, t - 1, :]
            elif t > 0:
                input_pose = pose_outputs[-1].squeeze(1).detach()

            pose_emb = self.input_proj(input_pose)
            gru_input = torch.cat([pose_emb, context], dim=-1)
            h = self.gru_cell(gru_input, h)
            h = self.dropout(h)

            pred_pose = self.fc_pose(h) * self.output_scale
            pred_conf = torch.sigmoid(self.fc_conf(h))

            pose_outputs.append(pred_pose.unsqueeze(1))
            conf_outputs.append(pred_conf.unsqueeze(1))
            input_pose = pred_pose.detach()

        return torch.cat(pose_outputs, dim=1), torch.cat(conf_outputs, dim=1)

# ===== METRICS =====
def mpjpe(pred, target, mask=None):
    # Shapes: (B, T, POSE_DIM)
    pred = pred.view(pred.size(0), pred.size(1), NUM_KEYPOINTS, 3)
    target = target.view(target.size(0), target.size(1), NUM_KEYPOINTS, 3)

    error = torch.norm(pred - target, dim=3)  # (B, T, NUM_KEYPOINTS)

    if mask is not None:
        mask = mask.view(pred.size(0), pred.size(1), NUM_KEYPOINTS)  # (B, T, K)
        masked_error = error * mask
        return masked_error.sum() / (mask.sum() + 1e-8)
    else:
        return error.mean()


def per_joint_mpjpe(pred, target, mask=None):
    pred = pred.detach().view(-1, NUM_KEYPOINTS, 3)
    target = target.detach().view(-1, NUM_KEYPOINTS, 3)
    error = torch.norm(pred - target, dim=2)  # (B*T, K)

    if mask is not None:
        mask = mask.view(-1, NUM_KEYPOINTS)  # (B*T, K)
        masked_error = error * mask
        joint_means = masked_error.sum(dim=0) / (mask.sum(dim=0) + 1e-8)
        return joint_means.cpu().numpy()
    else:
        return error.mean(dim=0).cpu().numpy()

def pose_velocity(pose_seq):
    # pose_seq shape is assumed to be (B, T, POSE_DIM)
    # Calculate difference along the time dimension (dim=1)
    diffs = pose_seq[:, 1:, :] - pose_seq[:, :-1, :]
    # Reshape to (B, T-1, NUM_KEYPOINTS, 3) to get per-joint velocity
    diffs = diffs.view(diffs.size(0), diffs.size(1), NUM_KEYPOINTS, 3)
    # Norm over coordinate dimension (dim=3), then mean over batch and time
    return torch.norm(diffs, dim=3).mean().item()


def cosine_similarity(pred, target):
    # pred and target are (B, T, POSE_DIM)
    pred = pred.detach().view(-1, POSE_DIM).cpu().numpy().flatten()
    target = target.detach().view(-1, POSE_DIM).cpu().numpy().flatten()
    # Similarity is computed on the flattened sequences for simplicity;
    # a per-timestep mean would be a finer-grained alternative.
    if np.linalg.norm(pred) == 0 or np.linalg.norm(target) == 0:
        return 0.0  # avoid NaN from all-zero pose vectors
    return 1 - cosine(pred, target)

def dtw_distance(pred, target):
    # pred and target are (B, T, POSE_DIM). DTW compares one sequence pair at a
    # time, so only the first sample of the batch is used as a representative.
    pred_seq = pred[0].detach().view(-1, POSE_DIM).cpu().numpy()
    target_seq = target[0].detach().view(-1, POSE_DIM).cpu().numpy()
    # Euclidean distance between frames is the local cost for fastdtw
    distance, _ = fastdtw(pred_seq, target_seq, dist=lambda a, b: np.linalg.norm(a - b))
    return distance
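

# --- Minimal smoke-test sketch (an illustrative addition, not part of the
# training pipeline). It shows the expected tensor shapes for a forward pass
# and exercises the metric helpers above on random target data; running it
# downloads the IndoBERT weights on first use.
if __name__ == "__main__":
    model = TextToPoseSeq2Seq(vocab_size=tokenizer.vocab_size).to(device)
    model.eval()

    # Tokenize a sample Indonesian sentence to the fixed text length
    encoded = tokenizer(
        "selamat pagi",
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=MAX_TEXT_LEN,
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    with torch.no_grad():
        pred_pose, pred_conf = model(input_ids, attention_mask=attention_mask)
    print(pred_pose.shape)  # (1, TARGET_NUM_FRAMES, POSE_DIM)
    print(pred_conf.shape)  # (1, TARGET_NUM_FRAMES, NUM_KEYPOINTS)

    # Dummy target sequence just to exercise the metrics
    target = torch.randn_like(pred_pose)
    print("MPJPE:", mpjpe(pred_pose, target).item())
    print("Velocity:", pose_velocity(pred_pose))
    print("Cosine similarity:", cosine_similarity(pred_pose, target))
    print("DTW distance:", dtw_distance(pred_pose, target))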