akin23 committed on
Commit
4ac7833
·
verified ·
1 Parent(s): a3bff50

Update src/facerender/animate.py

Browse files
Files changed (1) hide show
  1. src/facerender/animate.py +51 -178
src/facerender/animate.py CHANGED
@@ -10,13 +10,10 @@ import safetensors.torch
10
 
11
  warnings.filterwarnings('ignore')
12
 
13
-
14
-
15
  import imageio
16
  import torch
17
  import torchvision
18
 
19
-
20
  from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector
21
  from src.facerender.modules.mapping import MappingNet
22
  from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator
@@ -33,46 +30,47 @@ try:
33
  except:
34
  in_webui = False
35
 
 
36
  class AnimateFromCoeff():
37
 
38
  def __init__(self, sadtalker_path, device):
39
-
40
  with open(sadtalker_path['facerender_yaml']) as f:
41
  config = yaml.safe_load(f)
42
 
43
  generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'],
44
- **config['model_params']['common_params'])
45
  kp_extractor = KPDetector(**config['model_params']['kp_detector_params'],
46
- **config['model_params']['common_params'])
47
  he_estimator = HEEstimator(**config['model_params']['he_estimator_params'],
48
- **config['model_params']['common_params'])
49
  mapping = MappingNet(**config['model_params']['mapping_params'])
50
 
51
  generator.to(device)
52
  kp_extractor.to(device)
53
  he_estimator.to(device)
54
  mapping.to(device)
 
55
  for param in generator.parameters():
56
  param.requires_grad = False
57
  for param in kp_extractor.parameters():
58
- param.requires_grad = False
59
  for param in he_estimator.parameters():
60
  param.requires_grad = False
61
  for param in mapping.parameters():
62
  param.requires_grad = False
63
 
64
  if sadtalker_path is not None:
65
- if 'checkpoint' in sadtalker_path: # use safe tensor
66
  self.load_cpk_facevid2vid_safetensor(sadtalker_path['checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=None)
67
  else:
68
  self.load_cpk_facevid2vid(sadtalker_path['free_view_checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator)
69
  else:
70
  raise AttributeError("Checkpoint should be specified for video head pose estimator.")
71
 
72
- if sadtalker_path['mappingnet_checkpoint'] is not None:
73
  self.load_cpk_mapping(sadtalker_path['mappingnet_checkpoint'], mapping=mapping)
74
  else:
75
- raise AttributeError("Checkpoint should be specified for video head pose estimator.")
76
 
77
  self.kp_extractor = kp_extractor
78
  self.generator = generator
@@ -83,201 +81,76 @@ class AnimateFromCoeff():
83
  self.generator.eval()
84
  self.he_estimator.eval()
85
  self.mapping.eval()
86
-
87
  self.device = device
88
-
89
  def load_cpk_facevid2vid_safetensor(self, checkpoint_path, generator=None,
90
- kp_detector=None, he_estimator=None,
91
- device="cpu"):
92
 
93
  checkpoint = safetensors.torch.load_file(checkpoint_path)
94
 
95
  if generator is not None:
96
- x_generator = {}
97
- for k,v in checkpoint.items():
98
- if 'generator' in k:
99
- x_generator[k.replace('generator.', '')] = v
100
  generator.load_state_dict(x_generator)
101
  if kp_detector is not None:
102
- x_generator = {}
103
- for k,v in checkpoint.items():
104
- if 'kp_extractor' in k:
105
- x_generator[k.replace('kp_extractor.', '')] = v
106
  kp_detector.load_state_dict(x_generator)
107
  if he_estimator is not None:
108
- x_generator = {}
109
- for k,v in checkpoint.items():
110
- if 'he_estimator' in k:
111
- x_generator[k.replace('he_estimator.', '')] = v
112
  he_estimator.load_state_dict(x_generator)
113
-
114
  return None
115
 
116
  def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discriminator=None,
117
- kp_detector=None, he_estimator=None, optimizer_generator=None,
118
- optimizer_discriminator=None, optimizer_kp_detector=None,
119
- optimizer_he_estimator=None, device="cpu"):
 
120
  checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
 
121
  if generator is not None:
122
  generator.load_state_dict(checkpoint['generator'])
123
  if kp_detector is not None:
124
  kp_detector.load_state_dict(checkpoint['kp_detector'])
125
  if he_estimator is not None:
126
  he_estimator.load_state_dict(checkpoint['he_estimator'])
127
- if discriminator is not None:
128
- try:
129
- discriminator.load_state_dict(checkpoint['discriminator'])
130
- except:
131
- print ('No discriminator in the state-dict. Dicriminator will be randomly initialized')
132
- if optimizer_generator is not None:
133
  optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])
134
- if optimizer_discriminator is not None:
135
- try:
136
- optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
137
- except RuntimeError as e:
138
- print ('No discriminator optimizer in the state-dict. Optimizer will be not initialized')
139
- if optimizer_kp_detector is not None:
140
  optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector'])
141
- if optimizer_he_estimator is not None:
142
  optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator'])
143
 
144
- return checkpoint['epoch']
145
-
146
-
147
-
148
- def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None,
149
- optimizer_mapping=None, optimizer_discriminator=None, device='cpu'):
150
-
151
- # Eğer .tar dosyasıysa içeriğini kontrol et
152
- if checkpoint_path.endswith(".tar"):
153
- try:
154
- with tarfile.open(checkpoint_path, "r") as tar:
155
- members = tar.getnames()
156
- if not any(name.startswith("storages") for name in members):
157
- print("⚠️ 'storages' klasörü .tar dosyasında bulunamadı. Devam ediliyor...")
158
- else:
159
- print("✔️ 'storages' bulundu.")
160
- except Exception as e:
161
- print(f"Tar kontrol hatası: {e}")
162
-
163
- # Checkpoint yükle
164
- checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
165
-
166
- # Mapping yükleme
167
- if mapping is not None and 'mapping' in checkpoint:
168
- mapping.load_state_dict(checkpoint['mapping'])
169
-
170
- # Diğer parametreler varsa
171
- if discriminator is not None and 'discriminator' in checkpoint:
172
- discriminator.load_state_dict(checkpoint['discriminator'])
173
 
174
- if optimizer_mapping is not None and 'optimizer_mapping' in checkpoint:
175
- optimizer_mapping.load_state_dict(checkpoint['optimizer_mapping'])
176
-
177
- if optimizer_discriminator is not None and 'optimizer_discriminator' in checkpoint:
178
- optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
179
-
180
- return checkpoint.get('epoch', 0)
181
-
182
-
183
- def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256):
184
-
185
- source_image=x['source_image'].type(torch.FloatTensor)
186
- source_semantics=x['source_semantics'].type(torch.FloatTensor)
187
- target_semantics=x['target_semantics_list'].type(torch.FloatTensor)
188
- source_image=source_image.to(self.device)
189
- source_semantics=source_semantics.to(self.device)
190
- target_semantics=target_semantics.to(self.device)
191
- if 'yaw_c_seq' in x:
192
- yaw_c_seq = x['yaw_c_seq'].type(torch.FloatTensor)
193
- yaw_c_seq = x['yaw_c_seq'].to(self.device)
194
- else:
195
- yaw_c_seq = None
196
- if 'pitch_c_seq' in x:
197
- pitch_c_seq = x['pitch_c_seq'].type(torch.FloatTensor)
198
- pitch_c_seq = x['pitch_c_seq'].to(self.device)
199
- else:
200
- pitch_c_seq = None
201
- if 'roll_c_seq' in x:
202
- roll_c_seq = x['roll_c_seq'].type(torch.FloatTensor)
203
- roll_c_seq = x['roll_c_seq'].to(self.device)
204
- else:
205
- roll_c_seq = None
206
-
207
- frame_num = x['frame_num']
208
-
209
- predictions_video = make_animation(source_image, source_semantics, target_semantics,
210
- self.generator, self.kp_extractor, self.he_estimator, self.mapping,
211
- yaw_c_seq, pitch_c_seq, roll_c_seq, use_exp = True)
212
-
213
- predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:])
214
- predictions_video = predictions_video[:frame_num]
215
-
216
- video = []
217
- for idx in range(predictions_video.shape[0]):
218
- image = predictions_video[idx]
219
- image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32)
220
- video.append(image)
221
- result = img_as_ubyte(video)
222
-
223
- ### the generated video is 256x256, so we keep the aspect ratio,
224
- original_size = crop_info[0]
225
- if original_size:
226
- result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ]
227
-
228
- video_name = x['video_name'] + '.mp4'
229
- path = os.path.join(video_save_dir, 'temp_'+video_name)
230
-
231
- imageio.mimsave(path, result, fps=float(25))
232
-
233
- av_path = os.path.join(video_save_dir, video_name)
234
- return_path = av_path
235
-
236
- audio_path = x['audio_path']
237
- audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]
238
- new_audio_path = os.path.join(video_save_dir, audio_name+'.wav')
239
- start_time = 0
240
- # cog will not keep the .mp3 filename
241
- sound = AudioSegment.from_file(audio_path)
242
- frames = frame_num
243
- end_time = start_time + frames*1/25*1000
244
- word1=sound.set_frame_rate(16000)
245
- word = word1[start_time:end_time]
246
- word.export(new_audio_path, format="wav")
247
-
248
- save_video_with_watermark(path, new_audio_path, av_path, watermark= False)
249
- print(f'The generated video is named {video_save_dir}/{video_name}')
250
-
251
- if 'full' in preprocess.lower():
252
- # only add watermark to the full image.
253
- video_name_full = x['video_name'] + '_full.mp4'
254
- full_video_path = os.path.join(video_save_dir, video_name_full)
255
- return_path = full_video_path
256
- paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop= True if 'ext' in preprocess.lower() else False)
257
- print(f'The generated video is named {video_save_dir}/{video_name_full}')
258
- else:
259
- full_video_path = av_path
260
-
261
- #### paste back then enhancers
262
- if enhancer:
263
- video_name_enhancer = x['video_name'] + '_enhanced.mp4'
264
- enhanced_path = os.path.join(video_save_dir, 'temp_'+video_name_enhancer)
265
- av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer)
266
- return_path = av_path_enhancer
267
 
 
268
  try:
269
- enhanced_images_gen_with_len = enhancer_generator_with_len(full_video_path, method=enhancer, bg_upsampler=background_enhancer)
270
- imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25))
271
- except:
272
- enhanced_images_gen_with_len = enhancer_list(full_video_path, method=enhancer, bg_upsampler=background_enhancer)
273
- imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25))
274
-
275
- save_video_with_watermark(enhanced_path, new_audio_path, av_path_enhancer, watermark= False)
276
- print(f'The generated video is named {video_save_dir}/{video_name_enhancer}')
277
- os.remove(enhanced_path)
278
 
279
- os.remove(path)
280
- os.remove(new_audio_path)
281
 
282
- return return_path
 
 
 
 
 
 
 
283
 
 
 
10
 
11
  warnings.filterwarnings('ignore')
12
 
 
 
13
  import imageio
14
  import torch
15
  import torchvision
16
 
 
17
  from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector
18
  from src.facerender.modules.mapping import MappingNet
19
  from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator
 
30
  except:
31
  in_webui = False
32
 
33
+
34
  class AnimateFromCoeff():
35
 
36
  def __init__(self, sadtalker_path, device):
 
37
  with open(sadtalker_path['facerender_yaml']) as f:
38
  config = yaml.safe_load(f)
39
 
40
  generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'],
41
+ **config['model_params']['common_params'])
42
  kp_extractor = KPDetector(**config['model_params']['kp_detector_params'],
43
+ **config['model_params']['common_params'])
44
  he_estimator = HEEstimator(**config['model_params']['he_estimator_params'],
45
+ **config['model_params']['common_params'])
46
  mapping = MappingNet(**config['model_params']['mapping_params'])
47
 
48
  generator.to(device)
49
  kp_extractor.to(device)
50
  he_estimator.to(device)
51
  mapping.to(device)
52
+
53
  for param in generator.parameters():
54
  param.requires_grad = False
55
  for param in kp_extractor.parameters():
56
+ param.requires_grad = False
57
  for param in he_estimator.parameters():
58
  param.requires_grad = False
59
  for param in mapping.parameters():
60
  param.requires_grad = False
61
 
62
  if sadtalker_path is not None:
63
+ if 'checkpoint' in sadtalker_path:
64
  self.load_cpk_facevid2vid_safetensor(sadtalker_path['checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=None)
65
  else:
66
  self.load_cpk_facevid2vid(sadtalker_path['free_view_checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator)
67
  else:
68
  raise AttributeError("Checkpoint should be specified for video head pose estimator.")
69
 
70
+ if sadtalker_path['mappingnet_checkpoint'] is not None:
71
  self.load_cpk_mapping(sadtalker_path['mappingnet_checkpoint'], mapping=mapping)
72
  else:
73
+ raise AttributeError("Checkpoint should be specified for video head pose estimator.")
74
 
75
  self.kp_extractor = kp_extractor
76
  self.generator = generator
 
81
  self.generator.eval()
82
  self.he_estimator.eval()
83
  self.mapping.eval()
84
+
85
  self.device = device
86
+
87
  def load_cpk_facevid2vid_safetensor(self, checkpoint_path, generator=None,
88
+ kp_detector=None, he_estimator=None,
89
+ device="cpu"):
90
 
91
  checkpoint = safetensors.torch.load_file(checkpoint_path)
92
 
93
  if generator is not None:
94
+ x_generator = {k.replace('generator.', ''): v for k, v in checkpoint.items() if 'generator' in k}
 
 
 
95
  generator.load_state_dict(x_generator)
96
  if kp_detector is not None:
97
+ x_generator = {k.replace('kp_extractor.', ''): v for k, v in checkpoint.items() if 'kp_extractor' in k}
 
 
 
98
  kp_detector.load_state_dict(x_generator)
99
  if he_estimator is not None:
100
+ x_generator = {k.replace('he_estimator.', ''): v for k, v in checkpoint.items() if 'he_estimator' in k}
 
 
 
101
  he_estimator.load_state_dict(x_generator)
102
+
103
  return None
104
 
105
  def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discriminator=None,
106
+ kp_detector=None, he_estimator=None, optimizer_generator=None,
107
+ optimizer_discriminator=None, optimizer_kp_detector=None,
108
+ optimizer_he_estimator=None, device="cpu"):
109
+
110
  checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
111
+
112
  if generator is not None:
113
  generator.load_state_dict(checkpoint['generator'])
114
  if kp_detector is not None:
115
  kp_detector.load_state_dict(checkpoint['kp_detector'])
116
  if he_estimator is not None:
117
  he_estimator.load_state_dict(checkpoint['he_estimator'])
118
+ if discriminator is not None and 'discriminator' in checkpoint:
119
+ discriminator.load_state_dict(checkpoint['discriminator'])
120
+ if optimizer_generator is not None and 'optimizer_generator' in checkpoint:
 
 
 
121
  optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])
122
+ if optimizer_discriminator is not None and 'optimizer_discriminator' in checkpoint:
123
+ optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
124
+ if optimizer_kp_detector is not None and 'optimizer_kp_detector' in checkpoint:
 
 
 
125
  optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector'])
126
+ if optimizer_he_estimator is not None and 'optimizer_he_estimator' in checkpoint:
127
  optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator'])
128
 
129
+ return checkpoint.get('epoch', 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
+ def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None,
132
+ optimizer_mapping=None, optimizer_discriminator=None, device='cpu'):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
+ if checkpoint_path.endswith(".tar"):
135
  try:
136
+ with tarfile.open(checkpoint_path, "r") as tar:
137
+ members = tar.getnames()
138
+ if not any(name.startswith("storages") for name in members):
139
+ print("⚠️ 'storages' klasörü .tar dosyasında bulunamadı. Devam ediliyor...")
140
+ else:
141
+ print("✔️ 'storages' bulundu.")
142
+ except Exception as e:
143
+ print(f"Tar kontrol hatası: {e}")
 
144
 
145
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
 
146
 
147
+ if mapping is not None and 'mapping' in checkpoint:
148
+ mapping.load_state_dict(checkpoint['mapping'])
149
+ if discriminator is not None and 'discriminator' in checkpoint:
150
+ discriminator.load_state_dict(checkpoint['discriminator'])
151
+ if optimizer_mapping is not None and 'optimizer_mapping' in checkpoint:
152
+ optimizer_mapping.load_state_dict(checkpoint['optimizer_mapping'])
153
+ if optimizer_discriminator is not None and 'optimizer_discriminator' in checkpoint:
154
+ optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
155
 
156
+ return checkpoint.get('epoch', 0)