akin23 committed
Commit fd83866 · verified · 1 Parent(s): df89477

Update src/facerender/animate.py

Files changed (1)
  1. src/facerender/animate.py +43 -33
src/facerender/animate.py CHANGED
@@ -2,6 +2,7 @@ import os
  import cv2
  import yaml
  import tarfile
+ import tempfile
  import numpy as np
  import warnings
  from skimage import img_as_ubyte
@@ -128,36 +129,45 @@ class AnimateFromCoeff():
 
          return checkpoint.get('epoch', 0)
 
-     def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None,
-                          optimizer_mapping=None, optimizer_discriminator=None, device='cpu'):
-
-         # Check the contents of the .tar archive
-         if checkpoint_path.endswith(".tar"):
-             try:
-                 with tarfile.open(checkpoint_path, "r") as tar:
-                     members = tar.getnames()
-                     if not any(name.startswith("storages") for name in members):
-                         print("⚠️ 'storages' folder was not found in the .tar file. Continuing...")
-                     else:
-                         print("✔️ 'storages' found.")
-             except Exception as e:
-                 print(f"Tar check error: {e}")
-
-         # If it is a directory and archive/data.pkl exists
-         if os.path.isdir(checkpoint_path):
-             archive_path = os.path.join(checkpoint_path, "archive", "data.pkl")
-             if os.path.isfile(archive_path):
-                 checkpoint_path = archive_path
-
-         checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
-
-         if mapping is not None and 'mapping' in checkpoint:
-             mapping.load_state_dict(checkpoint['mapping'])
-         if discriminator is not None and 'discriminator' in checkpoint:
-             discriminator.load_state_dict(checkpoint['discriminator'])
-         if optimizer_mapping is not None and 'optimizer_mapping' in checkpoint:
-             optimizer_mapping.load_state_dict(checkpoint['optimizer_mapping'])
-         if optimizer_discriminator is not None and 'optimizer_discriminator' in checkpoint:
-             optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
-
-         return checkpoint.get('epoch', 0)
+
+     def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None,
+                          optimizer_mapping=None, optimizer_discriminator=None, device='cpu'):
+
+         # 1) If the path ends with .tar, first extract it into a temporary directory
+         if checkpoint_path.endswith(".tar"):
+             tmpdir = tempfile.mkdtemp()
+             with tarfile.open(checkpoint_path, "r") as tar:
+                 tar.extractall(path=tmpdir)
+             # find the first .pth file inside and point the path at it
+             found = False
+             for root, _, files in os.walk(tmpdir):
+                 for fname in files:
+                     if fname.endswith(".pth"):
+                         checkpoint_path = os.path.join(root, fname)
+                         found = True
+                         break
+                 if found:
+                     break
+             if not found:
+                 raise FileNotFoundError(f"No .pth file was found inside {checkpoint_path}.")
+
+         # 2) If the checkpoint was passed as a directory, also look for archive/data.pkl inside it
+         if os.path.isdir(checkpoint_path):
+             possible = os.path.join(checkpoint_path, "archive", "data.pkl")
+             if os.path.isfile(possible):
+                 checkpoint_path = possible
+
+         # 3) At this point checkpoint_path is a concrete file with a .pth or .safetensors extension
+         checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
+
+         # 4) Load the state dicts
+         if mapping is not None and 'mapping' in checkpoint:
+             mapping.load_state_dict(checkpoint['mapping'])
+         if discriminator is not None and 'discriminator' in checkpoint:
+             discriminator.load_state_dict(checkpoint['discriminator'])
+         if optimizer_mapping is not None and 'optimizer_mapping' in checkpoint:
+             optimizer_mapping.load_state_dict(checkpoint['optimizer_mapping'])
+         if optimizer_discriminator is not None and 'optimizer_discriminator' in checkpoint:
+             optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
+
+         return checkpoint.get('epoch', 0)
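
For context, a minimal usage sketch of the updated loader (not part of the commit; the helper name restore_mapping and the example checkpoint path are assumptions). It shows how an AnimateFromCoeff-style object could call load_cpk_mapping with a checkpoint that may arrive as a .tar archive, an extracted directory containing archive/data.pkl, or a plain .pth file.

# Hypothetical usage sketch -- illustration only, not code from this commit.
# `animator` is assumed to be an AnimateFromCoeff instance exposing the
# load_cpk_mapping method shown in the diff above; `mapping_net` is the
# mapping network whose weights should be restored in place.
def restore_mapping(animator, mapping_net, checkpoint_path, device="cpu"):
    # load_cpk_mapping resolves .tar / directory / .pth paths internally,
    # fills mapping_net via load_state_dict, and returns the stored epoch (or 0)
    epoch = animator.load_cpk_mapping(checkpoint_path, mapping=mapping_net, device=device)
    mapping_net.to(device).eval()
    return epoch

# Example call with a placeholder path:
# restore_mapping(animator, mapping_net, "checkpoints/mapping.pth.tar", device="cuda")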