akin23 committed on
Commit
1c96585
·
verified ·
1 Parent(s): fd83866

Update src/facerender/animate.py

Browse files
Files changed (1) hide show
  1. src/facerender/animate.py +91 -72
src/facerender/animate.py CHANGED
@@ -7,8 +7,7 @@ import numpy as np
7
  import warnings
8
  from skimage import img_as_ubyte
9
  import safetensors
10
- import safetensors.torch
11
-
12
  warnings.filterwarnings('ignore')
13
 
14
  import imageio
@@ -18,9 +17,9 @@ import torchvision
18
  from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector
19
  from src.facerender.modules.mapping import MappingNet
20
  from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator
21
- from src.facerender.modules.make_animation import make_animation
22
 
23
- from pydub import AudioSegment
24
  from src.utils.face_enhancer import enhancer_generator_with_len, enhancer_list
25
  from src.utils.paste_pic import paste_pic
26
  from src.utils.videoio import save_video_with_watermark
@@ -28,11 +27,11 @@ from src.utils.videoio import save_video_with_watermark
28
  try:
29
  import webui # in webui
30
  in_webui = True
31
- except:
32
  in_webui = False
33
 
34
 
35
- class AnimateFromCoeff():
36
 
37
  def __init__(self, sadtalker_path, device):
38
  with open(sadtalker_path['facerender_yaml']) as f:
@@ -60,53 +59,72 @@ class AnimateFromCoeff():
60
  for param in mapping.parameters():
61
  param.requires_grad = False
62
 
63
- if sadtalker_path is not None:
64
- if 'checkpoint' in sadtalker_path:
65
- self.load_cpk_facevid2vid_safetensor(sadtalker_path['checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=None)
66
- else:
67
- self.load_cpk_facevid2vid(sadtalker_path['free_view_checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator)
 
 
 
 
68
  else:
69
- raise AttributeError("Checkpoint should be specified for video head pose estimator.")
70
-
71
- if sadtalker_path['mappingnet_checkpoint'] is not None:
72
- self.load_cpk_mapping(sadtalker_path['mappingnet_checkpoint'], mapping=mapping)
 
 
 
 
 
 
 
 
 
 
 
73
  else:
74
- raise AttributeError("Checkpoint should be specified for video head pose estimator.")
75
 
76
  self.kp_extractor = kp_extractor
77
  self.generator = generator
78
  self.he_estimator = he_estimator
79
  self.mapping = mapping
 
80
 
81
  self.kp_extractor.eval()
82
  self.generator.eval()
83
  self.he_estimator.eval()
84
  self.mapping.eval()
85
 
86
- self.device = device
87
-
88
- def load_cpk_facevid2vid_safetensor(self, checkpoint_path, generator=None,
89
- kp_detector=None, he_estimator=None,
90
- device="cpu"):
91
 
92
  checkpoint = safetensors.torch.load_file(checkpoint_path)
93
 
94
  if generator is not None:
95
- x_generator = {k.replace('generator.', ''): v for k, v in checkpoint.items() if 'generator' in k}
96
- generator.load_state_dict(x_generator)
 
97
  if kp_detector is not None:
98
- x_generator = {k.replace('kp_extractor.', ''): v for k, v in checkpoint.items() if 'kp_extractor' in k}
99
- kp_detector.load_state_dict(x_generator)
 
100
  if he_estimator is not None:
101
- x_generator = {k.replace('he_estimator.', ''): v for k, v in checkpoint.items() if 'he_estimator' in k}
102
- he_estimator.load_state_dict(x_generator)
 
103
 
104
  return None
105
 
106
- def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discriminator=None,
107
- kp_detector=None, he_estimator=None, optimizer_generator=None,
108
- optimizer_discriminator=None, optimizer_kp_detector=None,
109
- optimizer_he_estimator=None, device="cpu"):
 
 
110
 
111
  checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
112
 
@@ -118,6 +136,7 @@ class AnimateFromCoeff():
118
  he_estimator.load_state_dict(checkpoint['he_estimator'])
119
  if discriminator is not None and 'discriminator' in checkpoint:
120
  discriminator.load_state_dict(checkpoint['discriminator'])
 
121
  if optimizer_generator is not None and 'optimizer_generator' in checkpoint:
122
  optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])
123
  if optimizer_discriminator is not None and 'optimizer_discriminator' in checkpoint:
@@ -129,45 +148,45 @@ class AnimateFromCoeff():
129
 
130
  return checkpoint.get('epoch', 0)
131
 
132
-
133
- def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None,
134
- optimizer_mapping=None, optimizer_discriminator=None, device='cpu'):
135
-
136
- # 1) Eğer .tar ile bitiyorsa, önce geçici bir klasöre aç
137
- if checkpoint_path.endswith(".tar"):
138
- tmpdir = tempfile.mkdtemp()
139
- with tarfile.open(checkpoint_path, "r") as tar:
140
- tar.extractall(path=tmpdir)
141
- # içinden ilk .pth dosyasını bul ve path'i ona yönlendir
142
- found = False
143
- for root, _, files in os.walk(tmpdir):
144
- for fname in files:
145
- if fname.endswith(".pth"):
146
- checkpoint_path = os.path.join(root, fname)
147
- found = True
 
 
148
  break
149
- if found:
150
- break
151
- if not found:
152
- raise FileNotFoundError(f"{checkpoint_path} içinde .pth dosyası bulunamadı.")
153
-
154
- # 2) Eğer doğrudan klasör olarak yüklediysen de içindeki archive/data.pkl’e bak
155
- if os.path.isdir(checkpoint_path):
156
- possible = os.path.join(checkpoint_path, "archive", "data.pkl")
157
- if os.path.isfile(possible):
158
- checkpoint_path = possible
159
-
160
- # 3) Artık checkpoint_path kesin .pth veya .safetensors uzantılı bir dosya:
161
- checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
162
-
163
- # 4) State dict’leri yüklüyoruz
164
- if mapping is not None and 'mapping' in checkpoint:
165
- mapping.load_state_dict(checkpoint['mapping'])
166
- if discriminator is not None and 'discriminator' in checkpoint:
167
- discriminator.load_state_dict(checkpoint['discriminator'])
168
- if optimizer_mapping is not None and 'optimizer_mapping' in checkpoint:
169
- optimizer_mapping.load_state_dict(checkpoint['optimizer_mapping'])
170
- if optimizer_discriminator is not None and 'optimizer_discriminator' in checkpoint:
171
- optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
172
-
173
- return checkpoint.get('epoch', 0)
 
7
  import warnings
8
  from skimage import img_as_ubyte
9
  import safetensors
10
+ import safetensors.torch
 
11
  warnings.filterwarnings('ignore')
12
 
13
  import imageio
 
17
  from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector
18
  from src.facerender.modules.mapping import MappingNet
19
  from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator
20
+ from src.facerender.modules.make_animation import make_animation
21
 
22
+ from pydub import AudioSegment
23
  from src.utils.face_enhancer import enhancer_generator_with_len, enhancer_list
24
  from src.utils.paste_pic import paste_pic
25
  from src.utils.videoio import save_video_with_watermark
 
27
# Detect whether the optional `webui` module is importable; `in_webui`
# records the outcome for the rest of the module.
try:
    import webui  # in webui
except ImportError:
    in_webui = False
else:
    in_webui = True
32
 
33
 
34
+ class AnimateFromCoeff:
35
 
36
  def __init__(self, sadtalker_path, device):
37
  with open(sadtalker_path['facerender_yaml']) as f:
 
59
  for param in mapping.parameters():
60
  param.requires_grad = False
61
 
62
+ # FaceVid2Vid checkpoint yükleme
63
+ if 'checkpoint' in sadtalker_path:
64
+ self.load_cpk_facevid2vid_safetensor(
65
+ sadtalker_path['checkpoint'],
66
+ kp_detector=kp_extractor,
67
+ generator=generator,
68
+ he_estimator=None,
69
+ device=device
70
+ )
71
  else:
72
+ self.load_cpk_facevid2vid(
73
+ sadtalker_path['free_view_checkpoint'],
74
+ kp_detector=kp_extractor,
75
+ generator=generator,
76
+ he_estimator=he_estimator,
77
+ device=device
78
+ )
79
+
80
+ # MappingNet checkpoint yükleme
81
+ if sadtalker_path.get('mappingnet_checkpoint') is not None:
82
+ self.load_cpk_mapping(
83
+ sadtalker_path['mappingnet_checkpoint'],
84
+ mapping=mapping,
85
+ device=device
86
+ )
87
  else:
88
+ raise AttributeError("mappingnet_checkpoint path belirtmelisiniz.")
89
 
90
  self.kp_extractor = kp_extractor
91
  self.generator = generator
92
  self.he_estimator = he_estimator
93
  self.mapping = mapping
94
+ self.device = device
95
 
96
  self.kp_extractor.eval()
97
  self.generator.eval()
98
  self.he_estimator.eval()
99
  self.mapping.eval()
100
 
101
def load_cpk_facevid2vid_safetensor(self, checkpoint_path,
                                    generator=None, kp_detector=None,
                                    he_estimator=None, device="cpu"):
    """Load face-vid2vid sub-module weights from a single safetensors file.

    The file is read as one flat ``name -> tensor`` mapping. For every
    module that is passed in, the entries whose key starts with that
    module's prefix (``generator.``, ``kp_extractor.``, ``he_estimator.``)
    are selected, the prefix is stripped, and the resulting dict is handed
    to the module's ``load_state_dict``.

    Args:
        checkpoint_path: Path of the ``.safetensors`` checkpoint.
        generator: Module receiving the ``generator.*`` entries, or None
            to skip them.
        kp_detector: Module receiving the ``kp_extractor.*`` entries, or
            None to skip them.
        he_estimator: Module receiving the ``he_estimator.*`` entries, or
            None to skip them.
        device: Accepted so the signature matches the other loaders; it is
            not used by this method.

    Returns:
        None.
    """
    checkpoint = safetensors.torch.load_file(checkpoint_path)

    targets = (
        ('generator.', generator),
        ('kp_extractor.', kp_detector),
        ('he_estimator.', he_estimator),
    )
    for prefix, module in targets:
        if module is None:
            continue
        # Slice the leading prefix off rather than using str.replace():
        # replace() would also delete any later occurrence of the same
        # text inside a parameter name and corrupt the key.
        state = {
            key[len(prefix):]: tensor
            for key, tensor in checkpoint.items()
            if key.startswith(prefix)
        }
        module.load_state_dict(state)

    return None
121
 
122
+ def load_cpk_facevid2vid(self, checkpoint_path,
123
+ generator=None, discriminator=None,
124
+ kp_detector=None, he_estimator=None,
125
+ optimizer_generator=None, optimizer_discriminator=None,
126
+ optimizer_kp_detector=None, optimizer_he_estimator=None,
127
+ device="cpu"):
128
 
129
  checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
130
 
 
136
  he_estimator.load_state_dict(checkpoint['he_estimator'])
137
  if discriminator is not None and 'discriminator' in checkpoint:
138
  discriminator.load_state_dict(checkpoint['discriminator'])
139
+ # Optimizeler varsa yükle
140
  if optimizer_generator is not None and 'optimizer_generator' in checkpoint:
141
  optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])
142
  if optimizer_discriminator is not None and 'optimizer_discriminator' in checkpoint:
 
148
 
149
  return checkpoint.get('epoch', 0)
150
 
151
def load_cpk_mapping(self, checkpoint_path,
                     mapping=None, discriminator=None,
                     optimizer_mapping=None, optimizer_discriminator=None,
                     device='cpu'):
    """Load a MappingNet checkpoint and restore the objects passed in.

    ``checkpoint_path`` may be:

    * a regular file readable by ``torch.load`` (this includes files that
      are only *named* ``*.tar`` but are not tar archives),
    * a real tar archive ending in ``.tar``; the first ``.pth`` file found
      inside it is loaded,
    * a directory; if it contains ``archive/data.pkl`` that file is loaded.

    Args:
        checkpoint_path: Location of the checkpoint (see above).
        mapping: Module restored from the ``'mapping'`` entry, if present.
        discriminator: Module restored from the ``'discriminator'`` entry,
            if present.
        optimizer_mapping: Optimizer restored from ``'optimizer_mapping'``,
            if present.
        optimizer_discriminator: Optimizer restored from
            ``'optimizer_discriminator'``, if present.
        device: Device given to ``torch.load`` as ``map_location``.

    Returns:
        The checkpoint's ``'epoch'`` entry, or 0 when it has none.

    Raises:
        FileNotFoundError: The tar archive contains no ``.pth`` file.
        ValueError: A tar member would be extracted outside the temporary
            extraction directory.

    NOTE(review): ``torch.load`` unpickles data; only load checkpoints
    from a trusted source.
    """
    # Imported locally so the method does not depend on module-level
    # imports that may not exist in the file.
    import os
    import tarfile
    import tempfile

    map_location = torch.device(device)

    # 1) A genuine tar archive: unpack it and load the first .pth inside.
    #    Files that merely carry a ".tar" suffix fall through to the
    #    plain torch.load path below instead of failing in tarfile.
    if checkpoint_path.endswith(".tar") and tarfile.is_tarfile(checkpoint_path):
        # The context manager removes the extracted files once the
        # checkpoint has been read into memory.
        with tempfile.TemporaryDirectory() as tmpdir:
            extract_root = os.path.realpath(tmpdir)
            with tarfile.open(checkpoint_path, "r") as tar:
                safe_members = []
                for member in tar.getmembers():
                    # Links and special files are never needed here.
                    if not (member.isfile() or member.isdir()):
                        continue
                    target = os.path.realpath(
                        os.path.join(extract_root, member.name))
                    if os.path.commonpath([extract_root, target]) != extract_root:
                        raise ValueError(
                            f"Unsafe path in archive {checkpoint_path}: "
                            f"{member.name}")
                    safe_members.append(member)
                tar.extractall(path=extract_root, members=safe_members)

            inner_path = next(
                (os.path.join(root, fname)
                 for root, _, files in os.walk(extract_root)
                 for fname in files
                 if fname.endswith(".pth")),
                None,
            )
            if inner_path is None:
                raise FileNotFoundError(f"{checkpoint_path} içinde .pth dosyası bulunamadı.")
            checkpoint = torch.load(inner_path, map_location=map_location)
    else:
        # 2) A directory was given: use archive/data.pkl when it exists.
        # NOTE(review): confirm torch.load can read a bare data.pkl
        # taken from an unpacked checkpoint.
        if os.path.isdir(checkpoint_path):
            possible = os.path.join(checkpoint_path, "archive", "data.pkl")
            if os.path.isfile(possible):
                checkpoint_path = possible

        # 3) Load the checkpoint file.
        checkpoint = torch.load(checkpoint_path, map_location=map_location)

    # 4) Restore every object that was supplied and has a stored state.
    restorable = (
        ('mapping', mapping),
        ('discriminator', discriminator),
        ('optimizer_mapping', optimizer_mapping),
        ('optimizer_discriminator', optimizer_discriminator),
    )
    for key, target_obj in restorable:
        if target_obj is not None and key in checkpoint:
            target_obj.load_state_dict(checkpoint[key])

    return checkpoint.get('epoch', 0)