akin23 commited on
Commit
aa02566
·
verified ·
1 Parent(s): 4ac7833

Update src/facerender/animate.py

Browse files
Files changed (1) hide show
  1. src/facerender/animate.py +28 -24
src/facerender/animate.py CHANGED
@@ -129,28 +129,32 @@ class AnimateFromCoeff():
129
  return checkpoint.get('epoch', 0)
130
 
131
  def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None,
132
- optimizer_mapping=None, optimizer_discriminator=None, device='cpu'):
133
-
134
- if checkpoint_path.endswith(".tar"):
135
- try:
136
- with tarfile.open(checkpoint_path, "r") as tar:
137
- members = tar.getnames()
138
- if not any(name.startswith("storages") for name in members):
139
- print("⚠️ 'storages' klasörü .tar dosyasında bulunamadı. Devam ediliyor...")
140
- else:
141
- print("✔️ 'storages' bulundu.")
142
- except Exception as e:
143
- print(f"Tar kontrol hatası: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
- checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
146
-
147
- if mapping is not None and 'mapping' in checkpoint:
148
- mapping.load_state_dict(checkpoint['mapping'])
149
- if discriminator is not None and 'discriminator' in checkpoint:
150
- discriminator.load_state_dict(checkpoint['discriminator'])
151
- if optimizer_mapping is not None and 'optimizer_mapping' in checkpoint:
152
- optimizer_mapping.load_state_dict(checkpoint['optimizer_mapping'])
153
- if optimizer_discriminator is not None and 'optimizer_discriminator' in checkpoint:
154
- optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
155
-
156
- return checkpoint.get('epoch', 0)
 
129
  return checkpoint.get('epoch', 0)
130
 
131
  def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None,
132
+ optimizer_mapping=None, optimizer_discriminator=None, device='cpu'):
133
+
134
+ # Eğer .tar ise aç ve klasör haline getir
135
+ if checkpoint_path.endswith(".tar"):
136
+ try:
137
+ print(f"Extracting tar file: {checkpoint_path}")
138
+ with tarfile.open(checkpoint_path, "r") as tar:
139
+ extract_dir = checkpoint_path.replace(".tar", "")
140
+ os.makedirs(extract_dir, exist_ok=True)
141
+ tar.extractall(path=extract_dir)
142
+ checkpoint_path = os.path.join(extract_dir, "data.pkl")
143
+ except Exception as e:
144
+ print(f"Tar açma hatası: {e}")
145
+ return 0 # fail gracefully
146
+
147
+ # Checkpoint (data.pkl) yükle
148
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
149
+
150
+ if mapping is not None and 'mapping' in checkpoint:
151
+ mapping.load_state_dict(checkpoint['mapping'])
152
+ if discriminator is not None and 'discriminator' in checkpoint:
153
+ discriminator.load_state_dict(checkpoint['discriminator'])
154
+ if optimizer_mapping is not None and 'optimizer_mapping' in checkpoint:
155
+ optimizer_mapping.load_state_dict(checkpoint['optimizer_mapping'])
156
+ if optimizer_discriminator is not None and 'optimizer_discriminator' in checkpoint:
157
+ optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
158
+
159
+ return checkpoint.get('epoch', 0)
160