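# Commit metadata scraped from the Hugging Face file view (not part of the module):
#   author: akin23
#   message: "Update src/facerender/animate.py"
#   revision: 50da102 (verified)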
import os
import cv2
import yaml
import tarfile
import tempfile
import numpy as np
import warnings
from skimage import img_as_ubyte
import safetensors
import safetensors.torch
warnings.filterwarnings('ignore')
import imageio
import torch
import torchvision
from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector
from src.facerender.modules.mapping import MappingNet
from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator
from src.facerender.modules.make_animation import make_animation
from pydub import AudioSegment
from src.utils.face_enhancer import enhancer_generator_with_len, enhancer_list
from src.utils.paste_pic import paste_pic
from src.utils.videoio import save_video_with_watermark
# Detect whether this module is running inside the WebUI: the ``webui``
# module is only importable there.
try:
    import webui  # noqa: F401  (imported purely as an environment probe)
except ImportError:
    in_webui = False
else:
    in_webui = True
class AnimateFromCoeff:
    """Frozen face-render networks built from a SadTalker-style path mapping.

    Holds four inference-only sub-networks, all built from the
    ``facerender_yaml`` config and populated from checkpoints:

    * the SPADE generator,
    * the keypoint extractor,
    * the head-pose estimator,
    * the mapping network.
    """

    def __init__(self, sadtalker_path, device):
        """Build the networks, load their weights, freeze them and set eval mode.

        Args:
            sadtalker_path: Mapping of resource paths. The keys read are:
                ``'facerender_yaml'`` (model config);
                ``'checkpoint'`` (safetensors file), or otherwise
                ``'free_view_checkpoint'`` (torch checkpoint);
                ``'mappingnet_checkpoint'`` (MappingNet checkpoint).
            device: Device passed to ``Module.to`` and used as ``map_location``
                when loading torch checkpoints.

        Raises:
            AttributeError: If ``'mappingnet_checkpoint'`` is missing or None.
        """
        with open(sadtalker_path['facerender_yaml']) as f:
            config = yaml.safe_load(f)
        model_params = config['model_params']
        common_params = model_params['common_params']

        generator = OcclusionAwareSPADEGenerator(**model_params['generator_params'],
                                                 **common_params)
        kp_extractor = KPDetector(**model_params['kp_detector_params'],
                                  **common_params)
        he_estimator = HEEstimator(**model_params['he_estimator_params'],
                                   **common_params)
        mapping = MappingNet(**model_params['mapping_params'])

        # Inference only: move every sub-network to the device and freeze it.
        for module in (generator, kp_extractor, he_estimator, mapping):
            module.to(device)
            for param in module.parameters():
                param.requires_grad = False

        # Load the FaceVid2Vid weights.
        if 'checkpoint' in sadtalker_path:
            # NOTE(review): he_estimator is not loaded on this path
            # (he_estimator=None) -- confirm that is intended.
            self.load_cpk_facevid2vid_safetensor(
                sadtalker_path['checkpoint'],
                kp_detector=kp_extractor,
                generator=generator,
                he_estimator=None,
                device=device,
            )
        else:
            self.load_cpk_facevid2vid(
                sadtalker_path['free_view_checkpoint'],
                kp_detector=kp_extractor,
                generator=generator,
                he_estimator=he_estimator,
                device=device,
            )

        # Load the MappingNet weights; a checkpoint is mandatory.
        mapping_checkpoint = sadtalker_path.get('mappingnet_checkpoint')
        if mapping_checkpoint is None:
            # Message: "you must specify the mappingnet_checkpoint path."
            raise AttributeError("mappingnet_checkpoint path belirtmelisiniz.")
        self.load_cpk_mapping(mapping_checkpoint, mapping=mapping, device=device)

        self.kp_extractor = kp_extractor
        self.generator = generator
        self.he_estimator = he_estimator
        self.mapping = mapping
        self.device = device

        self.kp_extractor.eval()
        self.generator.eval()
        self.he_estimator.eval()
        self.mapping.eval()

    @staticmethod
    def _sub_state_dict(checkpoint, prefix):
        """Return the entries of *checkpoint* whose key starts with *prefix*.

        Only the leading *prefix* is removed from each key. ``str.replace``
        would also strip later occurrences inside the key.
        """
        return {key[len(prefix):]: value
                for key, value in checkpoint.items()
                if key.startswith(prefix)}

    @staticmethod
    def _load_if_present(checkpoint, targets):
        """Restore optional state into each target.

        For each ``(target, key)`` pair, call ``target.load_state_dict`` only
        when the target is not None and *key* exists in *checkpoint*.
        """
        for target, key in targets:
            if target is not None and key in checkpoint:
                target.load_state_dict(checkpoint[key])

    def load_cpk_facevid2vid_safetensor(self, checkpoint_path,
                                        generator=None, kp_detector=None,
                                        he_estimator=None, device="cpu"):
        """Load FaceVid2Vid weights from a flat safetensors file.

        Each network's tensors are stored under a key prefix:
        ``generator.``, ``kp_extractor.`` or ``he_estimator.``.
        Networks passed as None are skipped.

        Args:
            checkpoint_path: Path of the ``.safetensors`` file.
            generator, kp_detector, he_estimator: Modules to populate, or None.
            device: Accepted for signature compatibility but unused.
                ``load_file`` runs with its default device, and
                ``load_state_dict`` copies into the modules' own tensors.

        Returns:
            None.
        """
        checkpoint = safetensors.torch.load_file(checkpoint_path)
        for module, prefix in ((generator, 'generator.'),
                               (kp_detector, 'kp_extractor.'),
                               (he_estimator, 'he_estimator.')):
            if module is not None:
                module.load_state_dict(self._sub_state_dict(checkpoint, prefix))
        return None

    def load_cpk_facevid2vid(self, checkpoint_path,
                             generator=None, discriminator=None,
                             kp_detector=None, he_estimator=None,
                             optimizer_generator=None, optimizer_discriminator=None,
                             optimizer_kp_detector=None, optimizer_he_estimator=None,
                             device="cpu"):
        """Load FaceVid2Vid weights (and optional training state) via ``torch.load``.

        Network weights are mandatory for every network passed in: a missing
        ``'generator'``, ``'kp_detector'`` or ``'he_estimator'`` key raises
        ``KeyError``. The discriminator and optimizer states are restored only
        when both the target and its key are present.

        Returns:
            The checkpoint's ``'epoch'`` entry, or 0 if absent.
        """
        checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
        for module, key in ((generator, 'generator'),
                            (kp_detector, 'kp_detector'),
                            (he_estimator, 'he_estimator')):
            if module is not None:
                module.load_state_dict(checkpoint[key])
        self._load_if_present(checkpoint, (
            (discriminator, 'discriminator'),
            (optimizer_generator, 'optimizer_generator'),
            (optimizer_discriminator, 'optimizer_discriminator'),
            (optimizer_kp_detector, 'optimizer_kp_detector'),
            (optimizer_he_estimator, 'optimizer_he_estimator'),
        ))
        return checkpoint.get('epoch', 0)

    @staticmethod
    def _resolve_mapping_checkpoint(checkpoint_path, extract_dir):
        """Return the path that ``torch.load`` should open for a mapping checkpoint.

        A genuine tar archive whose name ends in ``.tar`` is extracted into
        *extract_dir*. The first ``.pth`` member found is used; if there is
        none, the first ``.pkl`` member is used instead.

        Files that merely carry a ``.tar``/``.pth.tar`` suffix are returned
        unchanged, because ``torch.load`` reads them itself. This includes the
        zip-based files that ``torch.save`` writes.

        For a directory, ``<dir>/archive/data.pkl`` is used when it exists.

        Raises:
            FileNotFoundError: If an extracted archive holds no ``.pth`` or
                ``.pkl`` file.
        """
        path = os.fspath(checkpoint_path)
        if path.endswith('.tar') and tarfile.is_tarfile(path):
            with tarfile.open(path, 'r') as tar:
                if hasattr(tarfile, 'data_filter'):
                    # Refuse absolute paths, '..' escapes and unsafe links.
                    tar.extractall(path=extract_dir, filter='data')
                else:
                    # NOTE(review): this Python version has no extraction filters,
                    # so only extract archives from a trusted source.
                    tar.extractall(path=extract_dir)
            pth_file = None
            pkl_file = None
            for root, _, files in os.walk(extract_dir):
                for name in files:
                    if pth_file is None and name.endswith('.pth'):
                        pth_file = os.path.join(root, name)
                    if pkl_file is None and name.endswith('.pkl'):
                        pkl_file = os.path.join(root, name)
                if pth_file:
                    break
            found = pth_file or pkl_file
            if found is None:
                # Message: "neither a .pth nor a .pkl file could be found in ..."
                raise FileNotFoundError(
                    f"{checkpoint_path} içinden ne .pth ne de .pkl dosyası bulunabildi."
                )
            path = found
        if os.path.isdir(path):
            candidate = os.path.join(path, 'archive', 'data.pkl')
            if os.path.isfile(candidate):
                path = candidate
        return path

    def load_cpk_mapping(self, checkpoint_path,
                         mapping=None, discriminator=None,
                         optimizer_mapping=None, optimizer_discriminator=None,
                         device='cpu'):
        """Load MappingNet weights (and optional training state).

        *checkpoint_path* may be a torch checkpoint file, a tar archive
        wrapping one, or a directory (see ``_resolve_mapping_checkpoint``).

        Args:
            checkpoint_path: Checkpoint location (str or path-like).
            mapping: Module to populate. When given, the checkpoint must
                contain a ``'mapping'`` key.
            discriminator, optimizer_mapping, optimizer_discriminator:
                Restored only when both the target and its key are present.
            device: ``map_location`` for ``torch.load``.

        Returns:
            The checkpoint's ``'epoch'`` entry, or 0 if absent.

        Raises:
            FileNotFoundError: If a tar archive contains no ``.pth``/``.pkl``.
            KeyError: If *mapping* is given but the checkpoint has no
                ``'mapping'`` entry. The network must not be left silently
                at its constructor weights.
        """
        # Extracted files only need to live until torch.load returns.
        # TemporaryDirectory removes them afterwards.
        with tempfile.TemporaryDirectory() as tmpdir:
            resolved = self._resolve_mapping_checkpoint(checkpoint_path, tmpdir)
            checkpoint = torch.load(resolved, map_location=torch.device(device))

        if mapping is not None:
            mapping.load_state_dict(checkpoint['mapping'])
        self._load_if_present(checkpoint, (
            (discriminator, 'discriminator'),
            (optimizer_mapping, 'optimizer_mapping'),
            (optimizer_discriminator, 'optimizer_discriminator'),
        ))
        return checkpoint.get('epoch', 0)