| import argparse |
| import json |
| import os |
| import time |
| from pathlib import Path |
|
|
| import cv2 |
| import numpy as np |
| import torch |
| import torchvision |
| import tyro |
| import yaml |
| from loguru import logger |
| from PIL import Image |
|
|
| from external.human_matting import StyleMatteEngine as HumanMattingEngine |
| from external.landmark_detection.FaceBoxesV2.faceboxes_detector import \ |
| FaceBoxesDetector |
| from external.landmark_detection.infer_image import Alignment |
| from external.vgghead_detector import VGGHeadDetector |
| from vhap.config.base import BaseTrackingConfig |
| from vhap.export_as_nerf_dataset import (NeRFDatasetWriter, |
| TrackedFLAMEDatasetWriter, split_json) |
| from vhap.model.tracker import GlobalTracker |
|
|
| |
# Non-zero status codes returned by the FlameTrackingSingleImage stages
# (preprocess / optimize / export) when they fail; success is reported as 0.
ERROR_CODE = dict(FailedToDetect=1, FailedToOptimize=2, FailedToExport=3)
|
|
|
|
def expand_bbox(bbox, scale=1.1):
    """Return a square box, concentric with ``bbox``, enlarged by ``scale``.

    The coordinates are read from the last dimension of ``bbox`` in the
    order ``(xmin, ymin, xmax, ymax)``. The side length of the returned
    square is ``sqrt(width * height) * scale``; the result is stacked back
    along the last dimension in the same coordinate order.
    """
    left, top, right, bottom = bbox.unbind(dim=-1)

    mid_x = (left + right) / 2
    mid_y = (top + bottom) / 2

    # Half of the side of the expanded square (geometric mean of the
    # original width and height, times ``scale``).
    half_side = torch.sqrt((bottom - top) * (right - left)) * scale / 2

    corners = (
        mid_x - half_side,
        mid_y - half_side,
        mid_x + half_side,
        mid_y + half_side,
    )
    return torch.stack(corners, dim=-1)
|
|
|
|
def load_config(src_folder: Path):
    """Load the ``config.yml`` stored in (or directly below) ``src_folder``.

    If ``src_folder`` does not itself contain ``config.yml``, the
    sub-directory whose path sorts last is used instead.

    Args:
        src_folder: Folder that contains ``config.yml`` or sub-folders
            holding one.

    Returns:
        A tuple ``(folder, config_data)``: the folder in which
        ``config.yml`` was actually found and the object deserialized
        from that file.

    Raises:
        FileNotFoundError: If ``config.yml`` is found neither in
            ``src_folder`` nor in its last sub-directory. An explicit
            exception is used instead of ``assert`` so the check also
            runs under ``python -O``.
    """
    config_file_path = src_folder / 'config.yml'
    if not config_file_path.exists():
        # Only directories can hold a config; ignoring stray files keeps a
        # log or temp file from being picked as the "latest" entry.
        candidates = sorted(
            entry for entry in src_folder.iterdir() if entry.is_dir())
        if not candidates:
            raise FileNotFoundError(f'File not found: {config_file_path}')

        src_folder = candidates[-1]
        config_file_path = src_folder / 'config.yml'
        if not config_file_path.exists():
            raise FileNotFoundError(f'File not found: {config_file_path}')

    # NOTE(security): yaml.Loader can construct arbitrary Python objects.
    # It is kept because the file may contain Python-specific tags, but
    # this must only be used on config files the pipeline wrote itself.
    config_data = yaml.load(config_file_path.read_text(), Loader=yaml.Loader)
    return src_folder, config_data
|
|
|
|
class FlameTrackingSingleImage:
    """Three-stage tracking pipeline for a single image.

    The stages are meant to be called in order:

    1. ``preprocess(path)`` crops, mattes and saves the image plus its 2D
       landmarks under ``<output_dir>/preprocess/<image stem>``.
    2. ``optimize()`` runs ``GlobalTracker`` on that folder and writes to
       ``<output_dir>/tracking``.
    3. ``export()`` converts the tracking result into
       ``<output_dir>/export/<image stem>``.

    ``preprocess`` and ``optimize`` return ``0`` on success or a value
    from ``ERROR_CODE``; ``export`` returns a ``(code, path_or_'Failed')``
    tuple.
    """

    def __init__(
            self,
            output_dir,
            alignment_model_path='./pretrain_model/68_keypoints_model.pkl',
            vgghead_model_path='./pretrain_model/vgghead/vgg_heads_l.trcd',
            human_matting_path='./pretrain_model/matting/stylematte_synth.pt',
            facebox_model_path='./pretrain_model/FaceBoxesV2.pth',
            detect_iris_landmarks=False):
        """Load all pre-trained models onto ``cuda:0``.

        Args:
            output_dir: Root folder for the ``preprocess``, ``tracking``
                and ``export`` sub-folders.
            alignment_model_path: Weights for the landmark ``Alignment``
                model.
            vgghead_model_path: Weights for ``VGGHeadDetector``.
            human_matting_path: Weights for the human matting engine.
            facebox_model_path: Weights for ``FaceBoxesDetector``.
            detect_iris_landmarks: If true, ``fdlite`` is imported and
                iris landmarks are additionally extracted in
                ``preprocess``.

        Raises:
            AssertionError: If one of the model files does not exist.
        """
        logger.info(f'Output Directory: {output_dir}')

        start_time = time.time()
        logger.info('Loading Pre-trained Models...')

        self.output_dir = output_dir
        self.output_preprocess = os.path.join(output_dir, 'preprocess')
        self.output_tracking = os.path.join(output_dir, 'tracking')
        self.output_export = os.path.join(output_dir, 'export')
        self.device = 'cuda:0'
        # Filled in by preprocess(); optimize() and export() need it and
        # report an error code while it is still None.
        self.sub_output_dir = None

        # Facial landmark alignment model.
        assert os.path.exists(
            alignment_model_path), f'{alignment_model_path} does not exist!'
        args = self._parse_args()
        args.model_path = alignment_model_path
        self.alignment = Alignment(args,
                                   alignment_model_path,
                                   dl_framework='pytorch',
                                   device_ids=[0])

        # Head detector used to find the crop region.
        assert os.path.exists(
            vgghead_model_path), f'{vgghead_model_path} does not exist!'
        self.vgghead_encoder = VGGHeadDetector(
            device=self.device, vggheadmodel_path=vgghead_model_path)

        # Matting engine used to replace the background.
        assert os.path.exists(
            human_matting_path), f'{human_matting_path} does not exist!'
        self.matting_engine = HumanMattingEngine(
            device=self.device, human_matting_path=human_matting_path)

        # Face box detector that seeds the landmark alignment.
        assert os.path.exists(
            facebox_model_path), f'{facebox_model_path} does not exist!'
        self.detector = FaceBoxesDetector('FaceBoxes', facebox_model_path,
                                          True, self.device)

        self.detect_iris_landmarks_flag = detect_iris_landmarks
        if self.detect_iris_landmarks_flag:
            # Imported lazily so fdlite is only required when requested.
            from fdlite import FaceDetection, FaceLandmark, IrisLandmark
            self.iris_detect_faces = FaceDetection()
            self.iris_detect_face_landmarks = FaceLandmark()
            self.iris_detect_iris_landmarks = IrisLandmark()

        end_time = time.time()
        torch.cuda.empty_cache()
        logger.info(f'Finished Loading Pre-trained Models. Time: '
                    f'{end_time - start_time:.2f}s')

    def _parse_args(self):
        """Build the argument namespace handed to ``Alignment``.

        ``--output_dir`` and ``--config_name`` are read from ``sys.argv``
        when present. Unknown arguments are ignored rather than rejected:
        the command line is shared with the host script and with
        ``tyro.cli`` in ``optimize``, so a strict ``parse_args()`` would
        terminate the process on any flag meant for them.
        """
        parser = argparse.ArgumentParser(description='Evaluation script')
        parser.add_argument('--output_dir',
                            type=str,
                            help='Output directory',
                            default='output')
        parser.add_argument('--config_name',
                            type=str,
                            help='Configuration name',
                            default='alignment')
        known_args, _unknown = parser.parse_known_args()
        return known_args

    def preprocess(self, input_image_path):
        """Crop, matte and landmark the input image.

        Writes ``images/00000.png``, ``mask/00000.png``,
        ``alpha_maps/00000.jpg`` and ``landmark2d/landmarks.npz`` (plus
        ``landmark2d/iris.json`` when iris detection is enabled) below
        ``<output_dir>/preprocess/<image stem>`` and stores that folder in
        ``self.sub_output_dir``.

        Args:
            input_image_path: Path of the image to process.

        Returns:
            ``0`` on success, ``ERROR_CODE['FailedToDetect']`` if the file
            is missing, no head is found, or no facial landmarks could be
            computed.
        """
        if not os.path.exists(input_image_path):
            logger.warning(f'{input_image_path} does not exist!')
            return ERROR_CODE['FailedToDetect']

        start_time = time.time()
        logger.info('Starting Preprocessing...')
        frame_name = '00000.png'
        frame_index = 0

        # Head detection on the full image.
        frame = torchvision.io.read_image(input_image_path)
        try:
            _, frame_bbox, _ = self.vgghead_encoder(frame, frame_index)
        except Exception:
            logger.error('Failed to detect face')
            return ERROR_CODE['FailedToDetect']

        if frame_bbox is None:
            logger.error('Failed to detect face')
            return ERROR_CODE['FailedToDetect']

        # Enlarge the head box to a square and convert it to plain ints,
        # which is the argument type torchvision's crop() documents.
        left, top, right, bottom = expand_bbox(
            frame_bbox, scale=1.65).long().tolist()

        # Crop to the expanded box and resize to the working resolution.
        cropped_frame = torchvision.transforms.functional.crop(
            frame,
            top=top,
            left=left,
            height=bottom - top,
            width=right - left)
        cropped_frame = torchvision.transforms.functional.resize(
            cropped_frame, (1024, 1024), antialias=True)

        # Matting with background_rgb=1.0, then back to the 0-255 range.
        cropped_frame, mask = self.matting_engine(cropped_frame / 255.0,
                                                  return_type='matting',
                                                  background_rgb=1.0)
        cropped_frame = cropped_frame.cpu() * 255.0
        # CHW -> HWC and reversed channel order (presumably RGB -> BGR for
        # cv2.imwrite -- confirm against the matting engine's output).
        saved_image = np.round(cropped_frame.permute(
            1, 2, 0).numpy()).astype(np.uint8)[:, :, (2, 1, 0)]

        # Output folders for this image.
        self.sub_output_dir = os.path.join(
            self.output_preprocess,
            os.path.splitext(os.path.basename(input_image_path))[0])
        output_image_dir = os.path.join(self.sub_output_dir, 'images')
        output_mask_dir = os.path.join(self.sub_output_dir, 'mask')
        output_alpha_map_dir = os.path.join(self.sub_output_dir, 'alpha_maps')

        os.makedirs(output_image_dir, exist_ok=True)
        os.makedirs(output_mask_dir, exist_ok=True)
        os.makedirs(output_alpha_map_dir, exist_ok=True)

        # Save the image, the mask and an all-255 alpha map. The extension
        # is swapped on the file name only, so a '.png' that happens to be
        # part of a directory name is left untouched.
        image_path = os.path.join(output_image_dir, frame_name)
        alpha_name = os.path.splitext(frame_name)[0] + '.jpg'
        cv2.imwrite(image_path, saved_image)
        cv2.imwrite(os.path.join(output_mask_dir, frame_name),
                    np.array((mask.cpu() * 255.0)).astype(np.uint8))
        cv2.imwrite(os.path.join(output_alpha_map_dir, alpha_name),
                    (np.ones_like(saved_image) * 255).astype(np.uint8))

        # Landmarks are computed for every detection; the result of the
        # last detection is the one that is kept.
        detections, _ = self.detector.detect(saved_image, 0.8, 1)
        face_landmarks = None
        for detection in detections:
            x1_ori, y1_ori = detection[2], detection[3]
            x2_ori, y2_ori = x1_ori + detection[4], y1_ori + detection[5]

            scale = max(x2_ori - x1_ori, y2_ori - y1_ori) / 180
            center_w, center_h = (x1_ori + x2_ori) / 2, (y1_ori + y2_ori) / 2
            scale, center_w, center_h = float(scale), float(center_w), float(
                center_h)

            face_landmarks = self.alignment.analyze(saved_image, scale,
                                                    center_w, center_h)

        if face_landmarks is None:
            # No detection in the cropped image: report it instead of
            # failing later on the missing landmarks.
            logger.error('Failed to detect facial landmarks')
            return ERROR_CODE['FailedToDetect']

        # Store landmarks divided by the 1024-pixel crop size; the third
        # column is left at zero.
        normalized_landmarks = np.zeros((face_landmarks.shape[0], 3))
        normalized_landmarks[:, :2] = face_landmarks / 1024

        landmark_output_dir = os.path.join(self.sub_output_dir, 'landmark2d')
        os.makedirs(landmark_output_dir, exist_ok=True)

        landmark_data = {
            'bounding_box': [],
            'face_landmark_2d': normalized_landmarks[None, ...],
        }

        landmark_path = os.path.join(landmark_output_dir, 'landmarks.npz')
        np.savez(landmark_path, **landmark_data)

        if self.detect_iris_landmarks_flag:
            self._detect_iris_landmarks(image_path)

        end_time = time.time()
        torch.cuda.empty_cache()
        logger.info(
            f'Finished Processing Image. Time: {end_time - start_time:.2f}s')

        return 0

    def optimize(self):
        """Run ``GlobalTracker`` on the folder produced by ``preprocess``.

        The tracking configuration is built by ``tyro.cli`` and then
        pointed at ``self.sub_output_dir`` (input) and
        ``self.output_tracking`` (output).

        Returns:
            ``0`` on success, ``ERROR_CODE['FailedToOptimize']`` if
            ``preprocess`` has not produced an existing folder.
        """
        start_time = time.time()
        logger.info('Starting Optimization...')

        tyro.extras.set_accent_color('bright_yellow')
        config_data = tyro.cli(BaseTrackingConfig)

        if self.sub_output_dir is None or not os.path.exists(
                self.sub_output_dir):
            logger.error(f'Failed to load {self.sub_output_dir}')
            return ERROR_CODE['FailedToOptimize']

        # basename/dirname instead of splitting on '/' so this also works
        # with the platform separator that os.path.join inserted.
        config_data.data.sequence = os.path.basename(self.sub_output_dir)
        config_data.data.root_folder = Path(
            os.path.dirname(self.sub_output_dir))
        config_data.exp.output_folder = Path(self.output_tracking)

        tracker = GlobalTracker(config_data)
        tracker.optimize()

        end_time = time.time()
        torch.cuda.empty_cache()
        logger.info(
            f'Finished Optimization. Time: {end_time - start_time:.2f}s')

        return 0

    def _detect_iris_landmarks(self, image_path):
        """Detect iris landmarks and save them as ``landmark2d/iris.json``.

        The JSON file maps ``'00000.png'`` to a flat ``[x, y, ...]`` list
        holding the first iris point of each eye, multiplied by 1024. The
        eyes are visited in the reverse of the order returned by
        ``iris_roi_from_face_landmarks``. Nothing is written if face, face
        landmark or iris ROI detection does not yield the expected result;
        only a warning is logged.

        Args:
            image_path: Path of the preprocessed image.
        """
        from fdlite import face_detection_to_roi, iris_roi_from_face_landmarks

        img = Image.open(image_path)
        img_size = (1024, 1024)

        face_detections = self.iris_detect_faces(img)
        if len(face_detections) != 1:
            logger.warning('Empty iris landmarks')
            return

        try:
            face_roi = face_detection_to_roi(face_detections[0], img_size)
        except ValueError:
            logger.warning('Empty iris landmarks')
            return

        face_landmarks = self.iris_detect_face_landmarks(img, face_roi)
        if len(face_landmarks) == 0:
            logger.warning('Empty iris landmarks')
            return

        iris_rois = iris_roi_from_face_landmarks(face_landmarks, img_size)
        if len(iris_rois) != 2:
            logger.warning('Empty iris landmarks')
            return

        landmarks = []
        for iris_roi in iris_rois[::-1]:
            try:
                iris_landmarks = self.iris_detect_iris_landmarks(
                    img, iris_roi).iris[0:1]
            except np.linalg.LinAlgError:
                # Whatever was collected so far is still written below.
                logger.warning('Failed to get iris landmarks')
                break

            for landmark in iris_landmarks:
                landmarks.extend((landmark.x * 1024, landmark.y * 1024))

        iris_path = os.path.join(self.sub_output_dir, 'landmark2d',
                                 'iris.json')
        # Context manager so the file is flushed and closed deterministically.
        with open(iris_path, 'w') as iris_file:
            json.dump({'00000.png': landmarks}, iris_file)

    def export(self):
        """Export the tracking result below ``self.output_export``.

        Loads the configuration from ``self.output_tracking`` via
        ``load_config``, then runs ``NeRFDatasetWriter``,
        ``TrackedFLAMEDatasetWriter`` and ``split_json`` on the target
        folder ``<output_dir>/export/<image stem>``.

        Returns:
            ``(0, target_folder)`` on success, or
            ``(ERROR_CODE['FailedToExport'], 'Failed')`` if the tracking
            folder does not exist or ``preprocess`` has not been run.
        """
        logger.info(f'Beginning export from {self.output_tracking}')
        start_time = time.time()
        if not os.path.exists(self.output_tracking):
            logger.error(f'Failed to load {self.output_tracking}')
            return ERROR_CODE['FailedToExport'], 'Failed'

        if self.sub_output_dir is None:
            # The export folder is named after the preprocessed image.
            logger.error('Failed to export: preprocess() has not been run')
            return ERROR_CODE['FailedToExport'], 'Failed'

        src_folder = Path(self.output_tracking)
        tgt_folder = Path(self.output_export,
                          os.path.basename(self.sub_output_dir))
        src_folder, config_data = load_config(src_folder)

        nerf_writer = NeRFDatasetWriter(config_data.data, tgt_folder, None,
                                        None, 'white')
        nerf_writer.write()

        flame_writer = TrackedFLAMEDatasetWriter(config_data.model,
                                                 src_folder,
                                                 tgt_folder,
                                                 mode='param',
                                                 epoch=-1)
        flame_writer.write()

        split_json(tgt_folder)

        end_time = time.time()
        torch.cuda.empty_cache()
        logger.info(f'Finished Export. Time: {end_time - start_time:.2f}s')

        return 0, str(tgt_folder)
|
|