akin23 commited on
Commit
5546741
·
verified ·
1 Parent(s): db61bf4

Update src/facerender/animate.py

Browse files
Files changed (1) hide show
  1. src/facerender/animate.py +49 -19
src/facerender/animate.py CHANGED
@@ -153,24 +153,53 @@ class AnimateFromCoeff:
153
  optimizer_mapping=None, optimizer_discriminator=None,
154
  device='cpu'):
155
 
156
- # 1) Eğer bir .tar checkpoint ise, direkt torch.load ile oku
157
- if checkpoint_path.endswith('.tar'):
158
- # Burada torch, içindeki storages klasörünü de kullanarak yükleyecek
159
- checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
160
-
161
- else:
162
- # 2) Eğer klasör verilmişse archive/data.pkl'e bak
163
- if os.path.isdir(checkpoint_path):
164
- candidate = os.path.join(checkpoint_path, 'archive', 'data.pkl')
165
- if os.path.isfile(candidate):
166
- checkpoint_path = candidate
167
- else:
168
- raise FileNotFoundError(f"{checkpoint_path} bir klasör ama archive/data.pkl bulunamadı.")
169
-
170
- # 3) .pth veya .pkl ise yine torch.load
171
- checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
172
-
173
- # 4) State dict’leri ilgili objelere yükle
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  if mapping is not None and 'mapping' in checkpoint:
175
  mapping.load_state_dict(checkpoint['mapping'])
176
  if discriminator is not None and 'discriminator' in checkpoint:
@@ -180,5 +209,6 @@ class AnimateFromCoeff:
180
  if optimizer_discriminator is not None and 'optimizer_discriminator' in checkpoint:
181
  optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
182
 
183
- # 5) Epoch bilgisi varsa dön
184
  return checkpoint.get('epoch', 0)
 
 
153
  optimizer_mapping=None, optimizer_discriminator=None,
154
  device='cpu'):
155
 
156
+ def load_cpk_mapping(self,
157
+ checkpoint_path,
158
+ mapping=None,
159
+ discriminator=None,
160
+ optimizer_mapping=None,
161
+ optimizer_discriminator=None,
162
+ device='cpu'):
163
+ import os, tarfile, tempfile, torch
164
+
165
+ # 1) Eğer .tar veya .pth.tar ile bitiyorsa:
166
+ if checkpoint_path.endswith('.tar') or checkpoint_path.endswith('.pth.tar'):
167
+ tmpdir = tempfile.mkdtemp()
168
+ with tarfile.open(checkpoint_path, 'r') as tar:
169
+ tar.extractall(path=tmpdir)
170
+
171
+ # 1.a) Önce .pth arıyoruz, bulamazsak .pkl
172
+ candidate_pth = None
173
+ candidate_pkl = None
174
+ for root, _, files in os.walk(tmpdir):
175
+ for f in files:
176
+ if f.endswith('.pth') and candidate_pth is None:
177
+ candidate_pth = os.path.join(root, f)
178
+ if f.endswith('.pkl') and candidate_pkl is None:
179
+ candidate_pkl = os.path.join(root, f)
180
+ if candidate_pth:
181
+ break
182
+
183
+ if candidate_pth:
184
+ checkpoint_path = candidate_pth
185
+ elif candidate_pkl:
186
+ checkpoint_path = candidate_pkl
187
+ else:
188
+ raise FileNotFoundError(
189
+ f"{checkpoint_path} içinden ne .pth ne de .pkl dosyası bulunabildi."
190
+ )
191
+
192
+ # 2) Eğer checkpoint_path bir klasörse, archive/data.pkl’e bak
193
+ if os.path.isdir(checkpoint_path):
194
+ possible = os.path.join(checkpoint_path, 'archive', 'data.pkl')
195
+ if os.path.isfile(possible):
196
+ checkpoint_path = possible
197
+
198
+ # 3) Torch ile gerçek dosyayı yükle
199
+ checkpoint = torch.load(checkpoint_path,
200
+ map_location=torch.device(device))
201
+
202
+ # 4) State dict’leri ilgili modellere ata
203
  if mapping is not None and 'mapping' in checkpoint:
204
  mapping.load_state_dict(checkpoint['mapping'])
205
  if discriminator is not None and 'discriminator' in checkpoint:
 
209
  if optimizer_discriminator is not None and 'optimizer_discriminator' in checkpoint:
210
  optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
211
 
212
+ # 5) Epoch bilgisi varsa dön, yoksa 0
213
  return checkpoint.get('epoch', 0)
214
+