akin23 commited on
Commit
df89477
·
verified ·
1 Parent(s): 812d5d7

Update src/facerender/animate.py

Browse files
Files changed (1) hide show
  1. src/facerender/animate.py +30 -21
src/facerender/animate.py CHANGED
@@ -68,7 +68,7 @@ class AnimateFromCoeff():
68
  raise AttributeError("Checkpoint should be specified for video head pose estimator.")
69
 
70
  if sadtalker_path['mappingnet_checkpoint'] is not None:
71
- self._load_cpk_mapping(sadtalker_path['mappingnet_checkpoint'], mapping=mapping)
72
  else:
73
  raise AttributeError("Checkpoint should be specified for video head pose estimator.")
74
 
@@ -129,26 +129,35 @@ class AnimateFromCoeff():
129
  return checkpoint.get('epoch', 0)
130
 
131
  def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None,
132
- optimizer_mapping=None, optimizer_discriminator=None, device='cpu'):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
- # Eğer dosya bir klasörse ve archive/data.pkl varsa, orayı kullan
135
- if os.path.isdir(checkpoint_path):
136
- archive_path = os.path.join(checkpoint_path, "archive", "data.pkl")
137
- if os.path.isfile(archive_path):
138
- checkpoint_path = archive_path
139
-
140
- # Checkpoint yükle
141
- checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
142
-
143
- if mapping is not None and 'mapping' in checkpoint:
144
- mapping.load_state_dict(checkpoint['mapping'])
145
- if discriminator is not None and 'discriminator' in checkpoint:
146
- discriminator.load_state_dict(checkpoint['discriminator'])
147
- if optimizer_mapping is not None and 'optimizer_mapping' in checkpoint:
148
- optimizer_mapping.load_state_dict(checkpoint['optimizer_mapping'])
149
- if optimizer_discriminator is not None and 'optimizer_discriminator' in checkpoint:
150
- optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
151
-
152
- return checkpoint.get('epoch', 0)
153
 
 
 
 
 
 
 
 
 
154
 
 
 
68
  raise AttributeError("Checkpoint should be specified for video head pose estimator.")
69
 
70
  if sadtalker_path['mappingnet_checkpoint'] is not None:
71
+ self.load_cpk_mapping(sadtalker_path['mappingnet_checkpoint'], mapping=mapping)
72
  else:
73
  raise AttributeError("Checkpoint should be specified for video head pose estimator.")
74
 
 
129
  return checkpoint.get('epoch', 0)
130
 
131
  def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None,
132
+ optimizer_mapping=None, optimizer_discriminator=None, device='cpu'):
133
+
134
+ # .tar içeriği kontrolü
135
+ if checkpoint_path.endswith(".tar"):
136
+ try:
137
+ with tarfile.open(checkpoint_path, "r") as tar:
138
+ members = tar.getnames()
139
+ if not any(name.startswith("storages") for name in members):
140
+ print("⚠️ 'storages' klasörü .tar dosyasında bulunamadı. Devam ediliyor...")
141
+ else:
142
+ print("✔️ 'storages' bulundu.")
143
+ except Exception as e:
144
+ print(f"Tar kontrol hatası: {e}")
145
+
146
+ # Eğer klasörse ve archive/data.pkl varsa
147
+ if os.path.isdir(checkpoint_path):
148
+ archive_path = os.path.join(checkpoint_path, "archive", "data.pkl")
149
+ if os.path.isfile(archive_path):
150
+ checkpoint_path = archive_path
151
 
152
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
+ if mapping is not None and 'mapping' in checkpoint:
155
+ mapping.load_state_dict(checkpoint['mapping'])
156
+ if discriminator is not None and 'discriminator' in checkpoint:
157
+ discriminator.load_state_dict(checkpoint['discriminator'])
158
+ if optimizer_mapping is not None and 'optimizer_mapping' in checkpoint:
159
+ optimizer_mapping.load_state_dict(checkpoint['optimizer_mapping'])
160
+ if optimizer_discriminator is not None and 'optimizer_discriminator' in checkpoint:
161
+ optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
162
 
163
+ return checkpoint.get('epoch', 0)