akin23 commited on
Commit
0067cbf
·
verified ·
1 Parent(s): aa02566

Update src/facerender/animate.py

Browse files
Files changed (1) hide show
  1. src/facerender/animate.py +26 -30
src/facerender/animate.py CHANGED
@@ -68,7 +68,7 @@ class AnimateFromCoeff():
68
  raise AttributeError("Checkpoint should be specified for video head pose estimator.")
69
 
70
  if sadtalker_path['mappingnet_checkpoint'] is not None:
71
- self.load_cpk_mapping(sadtalker_path['mappingnet_checkpoint'], mapping=mapping)
72
  else:
73
  raise AttributeError("Checkpoint should be specified for video head pose estimator.")
74
 
@@ -128,33 +128,29 @@ class AnimateFromCoeff():
128
 
129
  return checkpoint.get('epoch', 0)
130
 
131
- def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None,
132
- optimizer_mapping=None, optimizer_discriminator=None, device='cpu'):
133
-
134
- # Eğer .tar ise aç ve klasör haline getir
135
- if checkpoint_path.endswith(".tar"):
136
- try:
137
- print(f"Extracting tar file: {checkpoint_path}")
138
- with tarfile.open(checkpoint_path, "r") as tar:
139
- extract_dir = checkpoint_path.replace(".tar", "")
140
- os.makedirs(extract_dir, exist_ok=True)
141
- tar.extractall(path=extract_dir)
142
- checkpoint_path = os.path.join(extract_dir, "data.pkl")
143
- except Exception as e:
144
- print(f"Tar açma hatası: {e}")
145
- return 0 # fail gracefully
146
-
147
- # Checkpoint (data.pkl) yükle
148
- checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
149
-
150
- if mapping is not None and 'mapping' in checkpoint:
151
- mapping.load_state_dict(checkpoint['mapping'])
152
- if discriminator is not None and 'discriminator' in checkpoint:
153
- discriminator.load_state_dict(checkpoint['discriminator'])
154
- if optimizer_mapping is not None and 'optimizer_mapping' in checkpoint:
155
- optimizer_mapping.load_state_dict(checkpoint['optimizer_mapping'])
156
- if optimizer_discriminator is not None and 'optimizer_discriminator' in checkpoint:
157
- optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
158
-
159
- return checkpoint.get('epoch', 0)
160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  raise AttributeError("Checkpoint should be specified for video head pose estimator.")
69
 
70
  if sadtalker_path['mappingnet_checkpoint'] is not None:
71
+ self._load_cpk_mapping(sadtalker_path['mappingnet_checkpoint'], mapping=mapping)
72
  else:
73
  raise AttributeError("Checkpoint should be specified for video head pose estimator.")
74
 
 
128
 
129
  return checkpoint.get('epoch', 0)
130
 
131
+ def _load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None,
132
+ optimizer_mapping=None, optimizer_discriminator=None, device='cpu'):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
+ if checkpoint_path.endswith(".tar"):
135
+ try:
136
+ with tarfile.open(checkpoint_path, "r") as tar:
137
+ extract_dir = checkpoint_path.replace(".tar", "")
138
+ os.makedirs(extract_dir, exist_ok=True)
139
+ tar.extractall(path=extract_dir)
140
+ checkpoint_path = os.path.join(extract_dir, "data.pkl")
141
+ except Exception as e:
142
+ print(f"Tar açma hatası: {e}")
143
+ return 0
144
+
145
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
146
+
147
+ if mapping is not None and 'mapping' in checkpoint:
148
+ mapping.load_state_dict(checkpoint['mapping'])
149
+ if discriminator is not None and 'discriminator' in checkpoint:
150
+ discriminator.load_state_dict(checkpoint['discriminator'])
151
+ if optimizer_mapping is not None and 'optimizer_mapping' in checkpoint:
152
+ optimizer_mapping.load_state_dict(checkpoint['optimizer_mapping'])
153
+ if optimizer_discriminator is not None and 'optimizer_discriminator' in checkpoint:
154
+ optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
155
+
156
+ return checkpoint.get('epoch', 0)