akin23 commited on
Commit
db61bf4
·
verified ·
1 Parent(s): cfa9dce

Update src/facerender/animate.py

Browse files
Files changed (1) hide show
  1. src/facerender/animate.py +17 -35
src/facerender/animate.py CHANGED
@@ -153,42 +153,24 @@ class AnimateFromCoeff:
153
  optimizer_mapping=None, optimizer_discriminator=None,
154
  device='cpu'):
155
 
156
- # Eğer .tar ile bitiyorsa, önce aç ve içinden .pth veya .pkl ara
157
  if checkpoint_path.endswith('.tar'):
158
- tmpdir = tempfile.mkdtemp()
159
- with tarfile.open(checkpoint_path, 'r') as tar:
160
- tar.extractall(path=tmpdir)
161
-
162
- found_pth = None
163
- found_pkl = None
164
- for root, _, files in os.walk(tmpdir):
165
- for fname in files:
166
- if fname.endswith('.pth') and found_pth is None:
167
- found_pth = os.path.join(root, fname)
168
- if fname.endswith('.pkl') and found_pkl is None:
169
- found_pkl = os.path.join(root, fname)
170
- if found_pth:
171
- break
172
-
173
- if found_pth:
174
- checkpoint_path = found_pth
175
- elif found_pkl:
176
- checkpoint_path = found_pkl
177
- else:
178
- raise FileNotFoundError(
179
- f"{checkpoint_path} içinden ne .pth ne de .pkl dosyası bulunabildi."
180
- )
181
-
182
- # Eğer bir klasör yoluna geldi ise (nadiren kullanılır), archive altındaki data.pkl’e bak
183
- if os.path.isdir(checkpoint_path):
184
- possible = os.path.join(checkpoint_path, 'archive', 'data.pkl')
185
- if os.path.isfile(possible):
186
- checkpoint_path = possible
187
-
188
- # Artık checkpoint_path kesin .pth veya .pkl
189
- checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
190
 
191
- # State dict’leri ilgili modellere yükle
 
 
 
 
 
 
 
 
 
 
 
 
192
  if mapping is not None and 'mapping' in checkpoint:
193
  mapping.load_state_dict(checkpoint['mapping'])
194
  if discriminator is not None and 'discriminator' in checkpoint:
@@ -198,5 +180,5 @@ class AnimateFromCoeff:
198
  if optimizer_discriminator is not None and 'optimizer_discriminator' in checkpoint:
199
  optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
200
 
201
- # Epoch bilgisi varsa döndür, yoksa 0
202
  return checkpoint.get('epoch', 0)
 
153
  optimizer_mapping=None, optimizer_discriminator=None,
154
  device='cpu'):
155
 
156
+ # 1) Eğer bir .tar checkpoint ise, direkt torch.load ile oku
157
  if checkpoint_path.endswith('.tar'):
158
+ # Burada torch, içindeki storages klasörünü de kullanarak yükleyecek
159
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
+ else:
162
+ # 2) Eğer klasör verilmişse archive/data.pkl'e bak
163
+ if os.path.isdir(checkpoint_path):
164
+ candidate = os.path.join(checkpoint_path, 'archive', 'data.pkl')
165
+ if os.path.isfile(candidate):
166
+ checkpoint_path = candidate
167
+ else:
168
+ raise FileNotFoundError(f"{checkpoint_path} bir klasör ama archive/data.pkl bulunamadı.")
169
+
170
+ # 3) .pth veya .pkl ise yine torch.load
171
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
172
+
173
+ # 4) State dict’leri ilgili objelere yükle
174
  if mapping is not None and 'mapping' in checkpoint:
175
  mapping.load_state_dict(checkpoint['mapping'])
176
  if discriminator is not None and 'discriminator' in checkpoint:
 
180
  if optimizer_discriminator is not None and 'optimizer_discriminator' in checkpoint:
181
  optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
182
 
183
+ # 5) Epoch bilgisi varsa dön
184
  return checkpoint.get('epoch', 0)