hantupocong commited on
Commit
e55790d
·
verified ·
1 Parent(s): 0af5be2

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +137 -138
model.py CHANGED
@@ -1,138 +1,137 @@
1
- import torch
2
- import torch.nn as nn
3
- from torch.utils.data import Dataset, DataLoader, random_split
4
- import pandas as pd
5
- import numpy as np
6
- from transformers import BertTokenizer
7
- import os
8
- from pose_format import Pose
9
- import matplotlib.pyplot as plt
10
- from matplotlib import animation
11
- from IPython.display import HTML
12
- from fastdtw import fastdtw # Keep this import
13
- from scipy.spatial.distance import cosine
14
- from config import MAX_TEXT_LEN, TARGET_NUM_FRAME, BATCH_SIZE, TEACHER_FORCING_RATIO, SMOOTHING_ENABLED
15
- from transformers import BertModel # ✅ Import BERT
16
-
17
- tokenizer = BertTokenizer.from_pretrained("indobenchmark/indobert-base-p2")
18
-
19
- # ===== KEYPOINT SELECTION =====
20
- selected_keypoint_indices = list(np.r_[0:25, 501:522, 522:543])
21
- NUM_KEYPOINTS = len(selected_keypoint_indices)
22
- POSE_DIM = NUM_KEYPOINTS * 3
23
-
24
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
-
26
- class TextToPoseSeq2Seq(nn.Module):
27
- def __init__(self, vocab_size, hidden_dim=512, pose_dim=POSE_DIM, max_len=MAX_TEXT_LEN, target_len=TARGET_NUM_FRAMES):
28
- super().__init__()
29
- self.hidden_dim = hidden_dim
30
- self.target_len = target_len
31
- self.pose_dim = pose_dim
32
-
33
- # === BERT Encoder ===
34
- self.encoder = BertModel.from_pretrained("indobenchmark/indobert-base-p2")
35
-
36
- # === GRU Decoder ===
37
- self.input_proj = nn.Linear(pose_dim, hidden_dim)
38
- bert_hidden = self.encoder.config.hidden_size
39
- self.gru_cell = nn.GRUCell(hidden_dim + bert_hidden, hidden_dim)
40
- self.dropout = nn.Dropout(0.3)
41
-
42
- self.fc_pose = nn.Linear(hidden_dim, pose_dim)
43
- self.fc_conf = nn.Linear(hidden_dim, NUM_KEYPOINTS)
44
- self.output_scale = 1.0
45
-
46
- def forward(self, input_ids, attention_mask=None, target_pose=None, teacher_forcing_ratio=TEACHER_FORCING_RATIO):
47
- B = input_ids.size(0)
48
- pose_outputs = []
49
- conf_outputs = []
50
- input_pose = torch.zeros(B, self.pose_dim).to(input_ids.device)
51
-
52
- # === BERT Encoding ===
53
- encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
54
- context = encoder_outputs.last_hidden_state[:, 0, :] # [CLS] token
55
-
56
- h = torch.zeros(B, self.hidden_dim).to(input_ids.device)
57
-
58
- for t in range(self.target_len):
59
- use_teacher = self.training and target_pose is not None and torch.rand(1).item() < teacher_forcing_ratio
60
- if use_teacher and t > 0:
61
- input_pose = target_pose[:, t - 1, :]
62
- elif t > 0:
63
- input_pose = pose_outputs[-1].squeeze(1).detach()
64
-
65
- pose_emb = self.input_proj(input_pose)
66
- gru_input = torch.cat([pose_emb, context], dim=-1)
67
- h = self.gru_cell(gru_input, h)
68
- h = self.dropout(h)
69
-
70
- pred_pose = self.fc_pose(h) * self.output_scale
71
- pred_conf = torch.sigmoid(self.fc_conf(h))
72
-
73
- pose_outputs.append(pred_pose.unsqueeze(1))
74
- conf_outputs.append(pred_conf.unsqueeze(1))
75
- input_pose = pred_pose.detach()
76
-
77
- return torch.cat(pose_outputs, dim=1), torch.cat(conf_outputs, dim=1)
78
-
79
- # ===== METRICS =====
80
- def mpjpe(pred, target, mask=None):
81
- # Shapes: (B, T, POSE_DIM)
82
- pred = pred.view(pred.size(0), pred.size(1), NUM_KEYPOINTS, 3)
83
- target = target.view(target.size(0), target.size(1), NUM_KEYPOINTS, 3)
84
-
85
- error = torch.norm(pred - target, dim=3) # (B, T, NUM_KEYPOINTS)
86
-
87
- if mask is not None:
88
- mask = mask.view(pred.size(0), pred.size(1), NUM_KEYPOINTS) # (B, T, K)
89
- masked_error = error * mask
90
- return masked_error.sum() / (mask.sum() + 1e-8)
91
- else:
92
- return error.mean()
93
-
94
-
95
- def per_joint_mpjpe(pred, target, mask=None):
96
- pred = pred.view(-1, NUM_KEYPOINTS, 3)
97
- target = target.view(-1, NUM_KEYPOINTS, 3)
98
- error = torch.norm(pred - target, dim=2) # (B*T, K)
99
-
100
- if mask is not None:
101
- mask = mask.view(-1, NUM_KEYPOINTS) # (B*T, K)
102
- masked_error = error * mask
103
- joint_means = masked_error.sum(dim=0) / (mask.sum(dim=0) + 1e-8)
104
- return joint_means.cpu().numpy()
105
- else:
106
- return error.mean(dim=0).cpu().numpy()
107
-
108
- def pose_velocity(pose_seq):
109
- # pose_seq shape is assumed to be (B, T, POSE_DIM)
110
- # Calculate difference along the time dimension (dim=1)
111
- diffs = pose_seq[:, 1:, :] - pose_seq[:, :-1, :]
112
- # Reshape to (B, T-1, NUM_KEYPOINTS, 3) to get per-joint velocity
113
- diffs = diffs.view(diffs.size(0), diffs.size(1), NUM_KEYPOINTS, 3)
114
- # Norm over coordinate dimension (dim=3), then mean over batch and time
115
- return torch.norm(diffs, dim=3).mean().item()
116
-
117
-
118
- def cosine_similarity(pred, target):
119
- # pred and target are (B, T, POSE_DIM)
120
- pred = pred.view(-1, POSE_DIM).cpu().numpy()
121
- target = target.view(-1, POSE_DIM).cpu().numpy()
122
- # Cosine similarity is usually calculated per sample or per timestep.
123
- # Calculating on flattened data across batch and time might not be meaningful.
124
- # Returning a scalar mean of pairwise similarities could be an alternative.
125
- # For simplicity, calculating similarity of flattened arrays.
126
- if np.linalg.norm(pred) == 0 or np.linalg.norm(target) == 0:
127
- return 0.0 # Handle zero vectors
128
- return 1 - cosine(pred.flatten(), target.flatten())
129
-
130
- def dtw_distance(pred, target):
131
- # pred and target are (B, T, POSE_DIM)
132
- # DTW is typically computed sequence-wise (T, POSE_DIM)
133
- # Computing on the first sample of the batch as an example
134
- pred_seq = pred[0].view(-1, POSE_DIM).cpu().numpy()
135
- target_seq = target[0].view(-1, POSE_DIM).cpu().numpy()
136
- # Use Euclidean distance as the distance metric for DTW
137
- distance, _ = fastdtw(pred_seq, target_seq, dist=lambda a, b: np.linalg.norm(a - b))
138
- return distance
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.data import Dataset, DataLoader, random_split
4
+ import pandas as pd
5
+ import numpy as np
6
+ from transformers import BertTokenizer
7
+ import os
8
+ from pose_format import Pose
9
+ import matplotlib.pyplot as plt
10
+ from matplotlib import animation
11
+ from fastdtw import fastdtw # Keep this import
12
+ from scipy.spatial.distance import cosine
13
+ from config import MAX_TEXT_LEN, TARGET_NUM_FRAME, BATCH_SIZE, TEACHER_FORCING_RATIO, SMOOTHING_ENABLED
14
+ from transformers import BertModel # Import BERT
15
+
16
+ tokenizer = BertTokenizer.from_pretrained("indobenchmark/indobert-base-p2")
17
+
18
+ # ===== KEYPOINT SELECTION =====
19
+ selected_keypoint_indices = list(np.r_[0:25, 501:522, 522:543])
20
+ NUM_KEYPOINTS = len(selected_keypoint_indices)
21
+ POSE_DIM = NUM_KEYPOINTS * 3
22
+
23
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+
25
+ class TextToPoseSeq2Seq(nn.Module):
26
+ def __init__(self, vocab_size, hidden_dim=512, pose_dim=POSE_DIM, max_len=MAX_TEXT_LEN, target_len=TARGET_NUM_FRAMES):
27
+ super().__init__()
28
+ self.hidden_dim = hidden_dim
29
+ self.target_len = target_len
30
+ self.pose_dim = pose_dim
31
+
32
+ # === BERT Encoder ===
33
+ self.encoder = BertModel.from_pretrained("indobenchmark/indobert-base-p2")
34
+
35
+ # === GRU Decoder ===
36
+ self.input_proj = nn.Linear(pose_dim, hidden_dim)
37
+ bert_hidden = self.encoder.config.hidden_size
38
+ self.gru_cell = nn.GRUCell(hidden_dim + bert_hidden, hidden_dim)
39
+ self.dropout = nn.Dropout(0.3)
40
+
41
+ self.fc_pose = nn.Linear(hidden_dim, pose_dim)
42
+ self.fc_conf = nn.Linear(hidden_dim, NUM_KEYPOINTS)
43
+ self.output_scale = 1.0
44
+
45
+ def forward(self, input_ids, attention_mask=None, target_pose=None, teacher_forcing_ratio=TEACHER_FORCING_RATIO):
46
+ B = input_ids.size(0)
47
+ pose_outputs = []
48
+ conf_outputs = []
49
+ input_pose = torch.zeros(B, self.pose_dim).to(input_ids.device)
50
+
51
+ # === BERT Encoding ===
52
+ encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
53
+ context = encoder_outputs.last_hidden_state[:, 0, :] # [CLS] token
54
+
55
+ h = torch.zeros(B, self.hidden_dim).to(input_ids.device)
56
+
57
+ for t in range(self.target_len):
58
+ use_teacher = self.training and target_pose is not None and torch.rand(1).item() < teacher_forcing_ratio
59
+ if use_teacher and t > 0:
60
+ input_pose = target_pose[:, t - 1, :]
61
+ elif t > 0:
62
+ input_pose = pose_outputs[-1].squeeze(1).detach()
63
+
64
+ pose_emb = self.input_proj(input_pose)
65
+ gru_input = torch.cat([pose_emb, context], dim=-1)
66
+ h = self.gru_cell(gru_input, h)
67
+ h = self.dropout(h)
68
+
69
+ pred_pose = self.fc_pose(h) * self.output_scale
70
+ pred_conf = torch.sigmoid(self.fc_conf(h))
71
+
72
+ pose_outputs.append(pred_pose.unsqueeze(1))
73
+ conf_outputs.append(pred_conf.unsqueeze(1))
74
+ input_pose = pred_pose.detach()
75
+
76
+ return torch.cat(pose_outputs, dim=1), torch.cat(conf_outputs, dim=1)
77
+
78
+ # ===== METRICS =====
79
+ def mpjpe(pred, target, mask=None):
80
+ # Shapes: (B, T, POSE_DIM)
81
+ pred = pred.view(pred.size(0), pred.size(1), NUM_KEYPOINTS, 3)
82
+ target = target.view(target.size(0), target.size(1), NUM_KEYPOINTS, 3)
83
+
84
+ error = torch.norm(pred - target, dim=3) # (B, T, NUM_KEYPOINTS)
85
+
86
+ if mask is not None:
87
+ mask = mask.view(pred.size(0), pred.size(1), NUM_KEYPOINTS) # (B, T, K)
88
+ masked_error = error * mask
89
+ return masked_error.sum() / (mask.sum() + 1e-8)
90
+ else:
91
+ return error.mean()
92
+
93
+
94
+ def per_joint_mpjpe(pred, target, mask=None):
95
+ pred = pred.view(-1, NUM_KEYPOINTS, 3)
96
+ target = target.view(-1, NUM_KEYPOINTS, 3)
97
+ error = torch.norm(pred - target, dim=2) # (B*T, K)
98
+
99
+ if mask is not None:
100
+ mask = mask.view(-1, NUM_KEYPOINTS) # (B*T, K)
101
+ masked_error = error * mask
102
+ joint_means = masked_error.sum(dim=0) / (mask.sum(dim=0) + 1e-8)
103
+ return joint_means.cpu().numpy()
104
+ else:
105
+ return error.mean(dim=0).cpu().numpy()
106
+
107
+ def pose_velocity(pose_seq):
108
+ # pose_seq shape is assumed to be (B, T, POSE_DIM)
109
+ # Calculate difference along the time dimension (dim=1)
110
+ diffs = pose_seq[:, 1:, :] - pose_seq[:, :-1, :]
111
+ # Reshape to (B, T-1, NUM_KEYPOINTS, 3) to get per-joint velocity
112
+ diffs = diffs.view(diffs.size(0), diffs.size(1), NUM_KEYPOINTS, 3)
113
+ # Norm over coordinate dimension (dim=3), then mean over batch and time
114
+ return torch.norm(diffs, dim=3).mean().item()
115
+
116
+
117
+ def cosine_similarity(pred, target):
118
+ # pred and target are (B, T, POSE_DIM)
119
+ pred = pred.view(-1, POSE_DIM).cpu().numpy()
120
+ target = target.view(-1, POSE_DIM).cpu().numpy()
121
+ # Cosine similarity is usually calculated per sample or per timestep.
122
+ # Calculating on flattened data across batch and time might not be meaningful.
123
+ # Returning a scalar mean of pairwise similarities could be an alternative.
124
+ # For simplicity, calculating similarity of flattened arrays.
125
+ if np.linalg.norm(pred) == 0 or np.linalg.norm(target) == 0:
126
+ return 0.0 # Handle zero vectors
127
+ return 1 - cosine(pred.flatten(), target.flatten())
128
+
129
+ def dtw_distance(pred, target):
130
+ # pred and target are (B, T, POSE_DIM)
131
+ # DTW is typically computed sequence-wise (T, POSE_DIM)
132
+ # Computing on the first sample of the batch as an example
133
+ pred_seq = pred[0].view(-1, POSE_DIM).cpu().numpy()
134
+ target_seq = target[0].view(-1, POSE_DIM).cpu().numpy()
135
+ # Use Euclidean distance as the distance metric for DTW
136
+ distance, _ = fastdtw(pred_seq, target_seq, dist=lambda a, b: np.linalg.norm(a - b))
137
+ return distance