# Face-render model loading for SadTalker (AnimateFromCoeff).
# NOTE(review): the original capture carried web-page residue here
# ("Spaces:" / "Running"), which is not valid Python.
# Standard library
import os
import tarfile
import tempfile
import warnings

# Suppress warnings before the heavyweight third-party imports below,
# matching the original ordering (some of them warn at import time).
warnings.filterwarnings('ignore')

# Third party
import cv2
import imageio
import numpy as np
import safetensors
import safetensors.torch
import torch
import torchvision
import yaml
from pydub import AudioSegment
from skimage import img_as_ubyte

# Project local
from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector
from src.facerender.modules.mapping import MappingNet
from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator
from src.facerender.modules.make_animation import make_animation
from src.utils.face_enhancer import enhancer_generator_with_len, enhancer_list
from src.utils.paste_pic import paste_pic
from src.utils.videoio import save_video_with_watermark

# Detect whether we are running inside the stable-diffusion-webui host.
try:
    import webui  # in webui
    in_webui = True
except ImportError:
    in_webui = False
class AnimateFromCoeff:
    """Builds the SadTalker face-render networks (generator, keypoint
    detector, head-pose estimator, mapping net), moves them to the target
    device, freezes them, and restores their weights from the configured
    checkpoints.
    """

    def __init__(self, sadtalker_path, device):
        """Construct and initialize all sub-networks.

        Args:
            sadtalker_path: dict of paths; must contain 'facerender_yaml'
                and 'mappingnet_checkpoint', plus either 'checkpoint'
                (safetensors) or 'free_view_checkpoint' (torch .pth/.tar).
            device: torch device the models are moved to.

        Raises:
            AttributeError: if 'mappingnet_checkpoint' is missing/None.
        """
        with open(sadtalker_path['facerender_yaml']) as f:
            config = yaml.safe_load(f)

        common = config['model_params']['common_params']
        generator = OcclusionAwareSPADEGenerator(
            **config['model_params']['generator_params'], **common)
        kp_extractor = KPDetector(
            **config['model_params']['kp_detector_params'], **common)
        he_estimator = HEEstimator(
            **config['model_params']['he_estimator_params'], **common)
        mapping = MappingNet(**config['model_params']['mapping_params'])

        # Inference-only: move every model to the device and freeze it.
        for model in (generator, kp_extractor, he_estimator, mapping):
            model.to(device)
            for param in model.parameters():
                param.requires_grad = False

        # FaceVid2Vid checkpoint loading: safetensors file when
        # 'checkpoint' is present, legacy torch checkpoint otherwise.
        if 'checkpoint' in sadtalker_path:
            self.load_cpk_facevid2vid_safetensor(
                sadtalker_path['checkpoint'],
                kp_detector=kp_extractor,
                generator=generator,
                he_estimator=None,
                device=device,
            )
        else:
            self.load_cpk_facevid2vid(
                sadtalker_path['free_view_checkpoint'],
                kp_detector=kp_extractor,
                generator=generator,
                he_estimator=he_estimator,
                device=device,
            )

        # MappingNet checkpoint is mandatory.
        if sadtalker_path.get('mappingnet_checkpoint') is not None:
            self.load_cpk_mapping(
                sadtalker_path['mappingnet_checkpoint'],
                mapping=mapping,
                device=device,
            )
        else:
            raise AttributeError("mappingnet_checkpoint path belirtmelisiniz.")

        self.kp_extractor = kp_extractor
        self.generator = generator
        self.he_estimator = he_estimator
        self.mapping = mapping
        self.device = device

        self.kp_extractor.eval()
        self.generator.eval()
        self.he_estimator.eval()
        self.mapping.eval()

    def load_cpk_facevid2vid_safetensor(self, checkpoint_path,
                                        generator=None, kp_detector=None,
                                        he_estimator=None, device="cpu"):
        """Restore weights from a safetensors checkpoint whose keys are
        prefixed 'generator.', 'kp_extractor.' and 'he_estimator.'.

        Only models passed as non-None are loaded. Returns None.
        """
        checkpoint = safetensors.torch.load_file(checkpoint_path)
        if generator is not None:
            state = {k.replace('generator.', ''): v
                     for k, v in checkpoint.items() if k.startswith('generator.')}
            generator.load_state_dict(state)
        if kp_detector is not None:
            state = {k.replace('kp_extractor.', ''): v
                     for k, v in checkpoint.items() if k.startswith('kp_extractor.')}
            kp_detector.load_state_dict(state)
        if he_estimator is not None:
            state = {k.replace('he_estimator.', ''): v
                     for k, v in checkpoint.items() if k.startswith('he_estimator.')}
            he_estimator.load_state_dict(state)
        return None

    def load_cpk_facevid2vid(self, checkpoint_path,
                             generator=None, discriminator=None,
                             kp_detector=None, he_estimator=None,
                             optimizer_generator=None, optimizer_discriminator=None,
                             optimizer_kp_detector=None, optimizer_he_estimator=None,
                             device="cpu"):
        """Restore weights (and optionally optimizer state) from a torch
        checkpoint with top-level keys 'generator', 'kp_detector', etc.

        Returns:
            The stored 'epoch' value, or 0 when absent.
        """
        checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
        if generator is not None:
            generator.load_state_dict(checkpoint['generator'])
        if kp_detector is not None:
            kp_detector.load_state_dict(checkpoint['kp_detector'])
        if he_estimator is not None:
            he_estimator.load_state_dict(checkpoint['he_estimator'])
        if discriminator is not None and 'discriminator' in checkpoint:
            discriminator.load_state_dict(checkpoint['discriminator'])
        # Optimizer states are optional; load them only when both the
        # optimizer object and its checkpoint entry exist.
        if optimizer_generator is not None and 'optimizer_generator' in checkpoint:
            optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])
        if optimizer_discriminator is not None and 'optimizer_discriminator' in checkpoint:
            optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
        if optimizer_kp_detector is not None and 'optimizer_kp_detector' in checkpoint:
            optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector'])
        if optimizer_he_estimator is not None and 'optimizer_he_estimator' in checkpoint:
            optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator'])
        return checkpoint.get('epoch', 0)

    def load_cpk_mapping(self, checkpoint_path,
                         mapping=None, discriminator=None,
                         optimizer_mapping=None, optimizer_discriminator=None,
                         device='cpu'):
        """Restore the mapping net (and optional discriminator/optimizers).

        BUG FIX: the original defined this method twice, nested — calling
        it only re-defined an inner function and returned None, so the
        mapping checkpoint was never actually loaded. Merged into one method.

        Accepts a plain torch checkpoint file, a .tar/.pth.tar archive that
        contains one, or a directory holding 'archive/data.pkl'.

        Returns:
            The stored 'epoch' value, or 0 when absent.
        """
        # 1) Archive: extract and locate an inner .pth (preferred) or .pkl.
        if checkpoint_path.endswith(('.tar', '.pth.tar')):
            tmpdir = tempfile.mkdtemp()
            # NOTE(review): extractall on an untrusted archive is prone to
            # path traversal; checkpoints are assumed to be trusted here.
            with tarfile.open(checkpoint_path, 'r') as tar:
                tar.extractall(path=tmpdir)
            candidate_pth = None
            candidate_pkl = None
            for root, _, files in os.walk(tmpdir):
                for fname in files:
                    if fname.endswith('.pth') and candidate_pth is None:
                        candidate_pth = os.path.join(root, fname)
                    if fname.endswith('.pkl') and candidate_pkl is None:
                        candidate_pkl = os.path.join(root, fname)
                if candidate_pth:
                    break
            if candidate_pth:
                checkpoint_path = candidate_pth
            elif candidate_pkl:
                checkpoint_path = candidate_pkl
            else:
                raise FileNotFoundError(
                    f"{checkpoint_path} içinden ne .pth ne de .pkl dosyası bulunabildi."
                )
        # 2) Directory: fall through to archive/data.pkl if present.
        if os.path.isdir(checkpoint_path):
            possible = os.path.join(checkpoint_path, 'archive', 'data.pkl')
            if os.path.isfile(possible):
                checkpoint_path = possible
        # 3) Load the resolved file with torch.
        checkpoint = torch.load(checkpoint_path,
                                map_location=torch.device(device))
        # 4) Assign state dicts to whichever models/optimizers were given.
        if mapping is not None and 'mapping' in checkpoint:
            mapping.load_state_dict(checkpoint['mapping'])
        if discriminator is not None and 'discriminator' in checkpoint:
            discriminator.load_state_dict(checkpoint['discriminator'])
        if optimizer_mapping is not None and 'optimizer_mapping' in checkpoint:
            optimizer_mapping.load_state_dict(checkpoint['optimizer_mapping'])
        if optimizer_discriminator is not None and 'optimizer_discriminator' in checkpoint:
            optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
        # 5) Epoch if recorded, else 0.
        return checkpoint.get('epoch', 0)