akin23 commited on
Commit
d4f06c4
·
verified ·
1 Parent(s): 0067cbf

Update src/facerender/animate.py

Browse files
Files changed (1) hide show
  1. src/facerender/animate.py +35 -25
src/facerender/animate.py CHANGED
@@ -128,29 +128,39 @@ class AnimateFromCoeff():
128
 
129
  return checkpoint.get('epoch', 0)
130
 
131
- def _load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None,
132
- optimizer_mapping=None, optimizer_discriminator=None, device='cpu'):
133
-
134
- if checkpoint_path.endswith(".tar"):
135
- try:
136
- with tarfile.open(checkpoint_path, "r") as tar:
137
- extract_dir = checkpoint_path.replace(".tar", "")
138
- os.makedirs(extract_dir, exist_ok=True)
139
- tar.extractall(path=extract_dir)
140
- checkpoint_path = os.path.join(extract_dir, "data.pkl")
141
- except Exception as e:
142
- print(f"Tar açma hatası: {e}")
143
- return 0
144
-
145
- checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
146
-
147
- if mapping is not None and 'mapping' in checkpoint:
148
- mapping.load_state_dict(checkpoint['mapping'])
149
- if discriminator is not None and 'discriminator' in checkpoint:
150
- discriminator.load_state_dict(checkpoint['discriminator'])
151
- if optimizer_mapping is not None and 'optimizer_mapping' in checkpoint:
152
- optimizer_mapping.load_state_dict(checkpoint['optimizer_mapping'])
153
- if optimizer_discriminator is not None and 'optimizer_discriminator' in checkpoint:
154
- optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
 
 
 
 
 
 
 
 
 
 
 
155
 
156
- return checkpoint.get('epoch', 0)
 
128
 
129
  return checkpoint.get('epoch', 0)
130
 
131
+ def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None,
132
+ optimizer_mapping=None, optimizer_discriminator=None, device='cpu'):
133
+
134
+ # Eğer dosya değil de klasörse ve içinde archive/data.pkl varsa onu yükle
135
+ if os.path.isdir(checkpoint_path):
136
+ archive_path = os.path.join(checkpoint_path, "archive", "data.pkl")
137
+ if os.path.isfile(archive_path):
138
+ print(f"📦 data.pkl bulundu: {archive_path}")
139
+ checkpoint_path = archive_path
140
+ else:
141
+ raise FileNotFoundError("archive/data.pkl bulunamadı!")
142
+
143
+ # Eğer tar dosyasıysa sadece kontrol amaçlı bakıyoruz
144
+ elif checkpoint_path.endswith(".tar"):
145
+ try:
146
+ with tarfile.open(checkpoint_path, "r") as tar:
147
+ members = tar.getnames()
148
+ if not any("storages" in name for name in members):
149
+ print("⚠️ 'storages' klasörü .tar içinde yok. Devam ediliyor...")
150
+ except Exception as e:
151
+ print(f"Tar kontrol hatası: {e}")
152
+
153
+ # Yükleme işlemi
154
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
155
+
156
+ if mapping is not None and 'mapping' in checkpoint:
157
+ mapping.load_state_dict(checkpoint['mapping'])
158
+ if discriminator is not None and 'discriminator' in checkpoint:
159
+ discriminator.load_state_dict(checkpoint['discriminator'])
160
+ if optimizer_mapping is not None and 'optimizer_mapping' in checkpoint:
161
+ optimizer_mapping.load_state_dict(checkpoint['optimizer_mapping'])
162
+ if optimizer_discriminator is not None and 'optimizer_discriminator' in checkpoint:
163
+ optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
164
+
165
+ return checkpoint.get('epoch', 0)
166