import os
import cv2
import yaml
import tarfile
import tempfile
import numpy as np
import warnings
from skimage import img_as_ubyte
import safetensors
import safetensors.torch
warnings.filterwarnings('ignore')

import imageio
import torch
import torchvision

from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector
from src.facerender.modules.mapping import MappingNet
from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator
from src.facerender.modules.make_animation import make_animation

from pydub import AudioSegment
from src.utils.face_enhancer import enhancer_generator_with_len, enhancer_list
from src.utils.paste_pic import paste_pic
from src.utils.videoio import save_video_with_watermark

try:
    import webui  # in webui
    in_webui = True
except ImportError:
    in_webui = False


def _strip_prefix(state_dict, prefix):
    """Return the entries of *state_dict* whose key starts with *prefix*, prefix removed.

    Only the leading occurrence is stripped, so a key that happens to contain the
    same text deeper inside (e.g. ``generator.sub.generator.w``) keeps its inner part.
    """
    return {key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)}


class AnimateFromCoeff:
    """Holds the frozen FaceVid2Vid sub-networks and the MappingNet used for rendering.

    The networks are built from the YAML config at ``sadtalker_path['facerender_yaml']``,
    moved to ``device``, have ``requires_grad`` disabled, are loaded from their
    checkpoints and finally put in ``eval()`` mode.
    """

    def __init__(self, sadtalker_path, device):
        """Build the networks and load their weights.

        Args:
            sadtalker_path: mapping of paths. Reads ``'facerender_yaml'``,
                ``'mappingnet_checkpoint'`` and either ``'checkpoint'`` (a safetensors
                file) or, when that key is absent, ``'free_view_checkpoint'``
                (a ``torch.load``-able file).
            device: device passed to ``Module.to`` and to the checkpoint loaders.

        Raises:
            AttributeError: if ``'mappingnet_checkpoint'`` is missing or ``None``.
        """
        with open(sadtalker_path['facerender_yaml']) as f:
            config = yaml.safe_load(f)

        model_params = config['model_params']
        common_params = model_params['common_params']
        generator = OcclusionAwareSPADEGenerator(**model_params['generator_params'], **common_params)
        kp_extractor = KPDetector(**model_params['kp_detector_params'], **common_params)
        he_estimator = HEEstimator(**model_params['he_estimator_params'], **common_params)
        mapping = MappingNet(**model_params['mapping_params'])

        # Inference only: move every network to the device and freeze its parameters.
        networks = (generator, kp_extractor, he_estimator, mapping)
        for network in networks:
            network.to(device)
        for network in networks:
            for param in network.parameters():
                param.requires_grad = False

        # FaceVid2Vid checkpoint loading.
        if 'checkpoint' in sadtalker_path:
            # The safetensors path is called with he_estimator=None, so the head-pose
            # estimator keeps its initial weights here.
            # NOTE(review): presumably it is not needed at inference — confirm.
            self.load_cpk_facevid2vid_safetensor(
                sadtalker_path['checkpoint'],
                kp_detector=kp_extractor,
                generator=generator,
                he_estimator=None,
                device=device,
            )
        else:
            self.load_cpk_facevid2vid(
                sadtalker_path['free_view_checkpoint'],
                kp_detector=kp_extractor,
                generator=generator,
                he_estimator=he_estimator,
                device=device,
            )

        # MappingNet checkpoint loading (mandatory).
        if sadtalker_path.get('mappingnet_checkpoint') is not None:
            self.load_cpk_mapping(
                sadtalker_path['mappingnet_checkpoint'],
                mapping=mapping,
                device=device,
            )
        else:
            raise AttributeError("mappingnet_checkpoint path belirtmelisiniz.")

        self.kp_extractor = kp_extractor
        self.generator = generator
        self.he_estimator = he_estimator
        self.mapping = mapping
        self.device = device

        self.kp_extractor.eval()
        self.generator.eval()
        self.he_estimator.eval()
        self.mapping.eval()

    def load_cpk_facevid2vid_safetensor(self, checkpoint_path, generator=None,
                                        kp_detector=None, he_estimator=None,
                                        device="cpu"):
        """Load FaceVid2Vid weights from a single safetensors file.

        Keys in the file are namespaced ``generator.*``, ``kp_extractor.*`` and
        ``he_estimator.*``. Each module that is not ``None`` receives the entries under
        its prefix, with the prefix stripped, via a strict ``load_state_dict``.

        ``device`` is accepted for signature parity with ``load_cpk_facevid2vid`` but is
        not used. ``safetensors.torch.load_file`` loads onto its default device (CPU),
        and ``load_state_dict`` copies the values into the modules' existing parameters.

        Returns:
            None
        """
        checkpoint = safetensors.torch.load_file(checkpoint_path)

        targets = (
            (generator, 'generator.'),
            (kp_detector, 'kp_extractor.'),
            (he_estimator, 'he_estimator.'),
        )
        for module, prefix in targets:
            if module is not None:
                module.load_state_dict(_strip_prefix(checkpoint, prefix))
        return None

    def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discriminator=None,
                             kp_detector=None, he_estimator=None, optimizer_generator=None,
                             optimizer_discriminator=None, optimizer_kp_detector=None,
                             optimizer_he_estimator=None, device="cpu"):
        """Load FaceVid2Vid weights (and optional training state) via ``torch.load``.

        ``generator``, ``kp_detector`` and ``he_estimator`` are required entries. If one
        of these modules is given but its key is missing, indexing raises ``KeyError``.
        The discriminator and the optimizers are optional. Each is loaded only when both
        the target object and the matching key are present.

        NOTE(review): ``torch.load`` unpickles arbitrary objects — only use with
        trusted checkpoint files.

        Returns:
            The checkpoint's ``'epoch'`` entry, or 0 if absent.
        """
        checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))

        if generator is not None:
            generator.load_state_dict(checkpoint['generator'])
        if kp_detector is not None:
            kp_detector.load_state_dict(checkpoint['kp_detector'])
        if he_estimator is not None:
            he_estimator.load_state_dict(checkpoint['he_estimator'])

        optional_targets = (
            (discriminator, 'discriminator'),
            (optimizer_generator, 'optimizer_generator'),
            (optimizer_discriminator, 'optimizer_discriminator'),
            (optimizer_kp_detector, 'optimizer_kp_detector'),
            (optimizer_he_estimator, 'optimizer_he_estimator'),
        )
        for target, key in optional_targets:
            if target is not None and key in checkpoint:
                target.load_state_dict(checkpoint[key])

        return checkpoint.get('epoch', 0)

    @staticmethod
    def _extract_checkpoint_from_tar(archive_path, dest_dir):
        """Extract the tar archive *archive_path* into *dest_dir* and pick the file to load.

        The first ``.pth`` member found wins. Failing that, the first ``.pkl`` member
        is used.

        Raises:
            FileNotFoundError: if the archive contains neither a .pth nor a .pkl file.
        """
        with tarfile.open(archive_path, 'r') as tar:
            if hasattr(tarfile, 'data_filter'):
                # Refuse absolute paths, '..' components and links that escape dest_dir.
                tar.extractall(path=dest_dir, filter='data')
            else:
                tar.extractall(path=dest_dir)

        candidate_pth = None
        candidate_pkl = None
        for root, _, files in os.walk(dest_dir):
            for name in files:
                if name.endswith('.pth') and candidate_pth is None:
                    candidate_pth = os.path.join(root, name)
                if name.endswith('.pkl') and candidate_pkl is None:
                    candidate_pkl = os.path.join(root, name)
            if candidate_pth:
                break

        if candidate_pth:
            return candidate_pth
        if candidate_pkl:
            return candidate_pkl
        raise FileNotFoundError(
            f"{archive_path} içinden ne .pth ne de .pkl dosyası bulunabildi."
        )

    def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None,
                         optimizer_mapping=None, optimizer_discriminator=None, device='cpu'):
        """Load MappingNet weights (and optional training state) from a checkpoint.

        ``checkpoint_path`` may be one of the following:

        * A file ending in ``.tar`` that really is a tar archive (checked with
          ``tarfile.is_tarfile``). It is extracted into a temporary directory, the
          contained ``.pth`` (else ``.pkl``) file is loaded, and the directory is
          removed afterwards.
        * A directory. ``<dir>/archive/data.pkl`` is used if it exists.
        * Any other file, which is passed straight to ``torch.load``. This includes
          ``*.pth.tar`` files that are plain ``torch.save`` outputs rather than tar
          archives.

        ``mapping`` is required to be present in the checkpoint when passed. The
        discriminator and optimizers are loaded only when both the target and the key
        exist.

        NOTE(review): ``torch.load`` unpickles arbitrary objects — only use with
        trusted checkpoint files.

        Returns:
            The checkpoint's ``'epoch'`` entry, or 0 if absent.

        Raises:
            FileNotFoundError: if a tar archive holds no .pth/.pkl file.
            KeyError: if ``mapping`` is given but the checkpoint has no ``'mapping'``.
        """
        map_location = torch.device(device)

        is_tar_archive = (
            checkpoint_path.endswith('.tar')
            and os.path.isfile(checkpoint_path)
            and tarfile.is_tarfile(checkpoint_path)
        )
        if is_tar_archive:
            # Load while the extracted files still exist, then let the temp dir go.
            with tempfile.TemporaryDirectory() as tmpdir:
                member_path = self._extract_checkpoint_from_tar(checkpoint_path, tmpdir)
                checkpoint = torch.load(member_path, map_location=map_location)
        else:
            if os.path.isdir(checkpoint_path):
                # NOTE(review): this hands torch.load the bare data.pkl from an unpacked
                # zip checkpoint; whether torch.load can read it without the sibling
                # storage files is doubtful — confirm.
                possible = os.path.join(checkpoint_path, 'archive', 'data.pkl')
                if os.path.isfile(possible):
                    checkpoint_path = possible
            checkpoint = torch.load(checkpoint_path, map_location=map_location)

        if mapping is not None:
            # A missing entry would leave MappingNet at its initial weights without
            # any signal, so fail loudly instead.
            if 'mapping' not in checkpoint:
                raise KeyError(f"checkpoint {checkpoint_path!r} has no 'mapping' entry")
            mapping.load_state_dict(checkpoint['mapping'])

        optional_targets = (
            (discriminator, 'discriminator'),
            (optimizer_mapping, 'optimizer_mapping'),
            (optimizer_discriminator, 'optimizer_discriminator'),
        )
        for target, key in optional_targets:
            if target is not None and key in checkpoint:
                target.load_state_dict(checkpoint[key])

        return checkpoint.get('epoch', 0)