Commit 992b457 (parent: 8dd4264): add

app.py CHANGED

This commit disables the face VQ-VAE branch in BaseTrainer: the VQVAEConvZero construction and its checkpoint load are commented out, the matching eval() call is removed, and latent_face_top is set to None instead of being computed with map2latent.
@@ -154,8 +154,8 @@ class BaseTrainer(object):
         self.args.vae_layer = 2
         self.args.vae_length = 256
         self.args.vae_test_dim = 106
-        self.vq_model_face = getattr(vq_model_module, "VQVAEConvZero")(self.args).to(self.rank)
-        other_tools.load_checkpoints(self.vq_model_face, "./datasets/hub/pretrained_vq/face_vertex_1layer_790.bin", args.e_name)
+        # self.vq_model_face = getattr(vq_model_module, "VQVAEConvZero")(self.args).to(self.rank)
+        # other_tools.load_checkpoints(self.vq_model_face, "./datasets/hub/pretrained_vq/face_vertex_1layer_790.bin", args.e_name)


         vq_type = self.args.vqvae_type
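Rather than commenting the construction out, the same effect can be achieved by gating the face branch behind a flag so that later code can test for its absence. A minimal sketch, assuming a hypothetical args.use_face_vq option (not in the repo) and the vq_model_module / other_tools helpers already imported in app.py:

    # Hypothetical gating of the face VQ-VAE; `use_face_vq` is an assumed flag.
    if getattr(self.args, "use_face_vq", False):
        self.vq_model_face = getattr(vq_model_module, "VQVAEConvZero")(self.args).to(self.rank)
        other_tools.load_checkpoints(self.vq_model_face,
                                     "./datasets/hub/pretrained_vq/face_vertex_1layer_790.bin",
                                     args.e_name)
    else:
        self.vq_model_face = None  # downstream code must check for None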
@@ -260,7 +260,7 @@ class BaseTrainer(object):
         self.args.vae_length = 240


-        self.vq_model_face.eval()
+        # self.vq_model_face.eval()
         self.vq_model_upper.eval()
         self.vq_model_hands.eval()
         self.vq_model_lower.eval()
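With the face model potentially absent, the eval() switch can be made robust with a None check instead of commenting out a single line; a sketch under the same gating assumption as above:

    # Put every VQ-VAE that actually exists into eval mode; skips a face model set to None.
    for vq in (getattr(self, "vq_model_face", None), self.vq_model_upper,
               self.vq_model_hands, self.vq_model_lower):
        if vq is not None:
            vq.eval()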
@@ -375,7 +375,7 @@ class BaseTrainer(object):
         tar_trans_v = (tar_trans_v - self.trans_mean)/self.trans_std
         tar_pose_lower = torch.cat([tar_pose_lower,tar_trans_v], dim=-1)

-        latent_face_top = self.vq_model_face.map2latent(tar_pose_face) # bs*n/4
+        latent_face_top = None#self.vq_model_face.map2latent(tar_pose_face) # bs*n/4
         latent_upper_top = self.vq_model_upper.map2latent(tar_pose_upper)
         latent_hands_top = self.vq_model_hands.map2latent(tar_pose_hands)
         latent_lower_top = self.vq_model_lower.map2latent(tar_pose_lower)
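Since latent_face_top is now None, anything downstream that concatenates or otherwise consumes the four latents has to tolerate the missing face stream. The diff does not show that consumer, so the following is only a sketch of the guard it would need, with torch.cat as an assumed fusion step:

    # Hypothetical downstream fusion; the real consumer is not shown in this diff.
    latents = [latent_upper_top, latent_hands_top, latent_lower_top]
    if latent_face_top is not None:
        latents.insert(0, latent_face_top)
    fused = torch.cat(latents, dim=-1)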