Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes.
- .gitattributes +24 -0
- LHM/__init__.py +15 -0
- LHM/__pycache__/__init__.cpython-310.pyc +0 -0
- LHM/__pycache__/launch.cpython-310.pyc +0 -0
- LHM/datasets/__init__.py +16 -0
- LHM/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
- LHM/datasets/__pycache__/cam_utils.cpython-310.pyc +0 -0
- LHM/datasets/__pycache__/mixer.cpython-310.pyc +0 -0
- LHM/datasets/base.py +70 -0
- LHM/datasets/cam_utils.py +205 -0
- LHM/datasets/mixer.py +75 -0
- LHM/launch.py +37 -0
- LHM/losses/__init__.py +20 -0
- LHM/losses/ball_loss.py +54 -0
- LHM/losses/offset_loss.py +52 -0
- LHM/losses/perceptual.py +70 -0
- LHM/losses/pixelwise.py +58 -0
- LHM/losses/tvloss.py +55 -0
- LHM/models/ESRGANer_utils.py +482 -0
- LHM/models/__init__.py +24 -0
- LHM/models/__pycache__/ESRGANer_utils.cpython-310.pyc +0 -0
- LHM/models/__pycache__/__init__.cpython-310.pyc +0 -0
- LHM/models/__pycache__/arcface_utils.cpython-310.pyc +0 -0
- LHM/models/__pycache__/embedder.cpython-310.pyc +0 -0
- LHM/models/__pycache__/modeling_human_lrm.cpython-310.pyc +0 -0
- LHM/models/__pycache__/transformer.cpython-310.pyc +0 -0
- LHM/models/__pycache__/transformer_dit.cpython-310.pyc +0 -0
- LHM/models/__pycache__/utils.cpython-310.pyc +0 -0
- LHM/models/arcface_utils.py +360 -0
- LHM/models/block.py +124 -0
- LHM/models/discriminator.py +120 -0
- LHM/models/embedder.py +37 -0
- LHM/models/encoders/__init__.py +15 -0
- LHM/models/encoders/__pycache__/__init__.cpython-310.pyc +0 -0
- LHM/models/encoders/__pycache__/dinov2_fusion_wrapper.cpython-310.pyc +0 -0
- LHM/models/encoders/__pycache__/sapiens_warpper.cpython-310.pyc +0 -0
- LHM/models/encoders/dino_wrapper.py +68 -0
- LHM/models/encoders/dinov2/__init__.py +15 -0
- LHM/models/encoders/dinov2/__pycache__/__init__.cpython-310.pyc +0 -0
- LHM/models/encoders/dinov2/hub/__init__.py +4 -0
- LHM/models/encoders/dinov2/hub/__pycache__/__init__.cpython-310.pyc +0 -0
- LHM/models/encoders/dinov2/hub/__pycache__/backbones.cpython-310.pyc +0 -0
- LHM/models/encoders/dinov2/hub/__pycache__/utils.cpython-310.pyc +0 -0
- LHM/models/encoders/dinov2/hub/backbones.py +166 -0
- LHM/models/encoders/dinov2/hub/classifiers.py +268 -0
- LHM/models/encoders/dinov2/hub/depth/__init__.py +7 -0
- LHM/models/encoders/dinov2/hub/depth/decode_heads.py +747 -0
- LHM/models/encoders/dinov2/hub/depth/encoder_decoder.py +351 -0
- LHM/models/encoders/dinov2/hub/depth/ops.py +28 -0
- LHM/models/encoders/dinov2/hub/depthers.py +246 -0
.gitattributes
CHANGED
@@ -33,3 +33,27 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/characters_images/000001.jpg filter=lfs diff=lfs merge=lfs -text
+assets/videos/scene_000000/bkgd_video.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/videos/scene_000000/smplx_video.mp4 filter=lfs diff=lfs merge=lfs -text
+diffsynth/models/__pycache__/sd3_text_encoder.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
+diffsynth/models/__pycache__/sd3_text_encoder.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+diffsynth/models/__pycache__/sd_unet.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+diffsynth/models/__pycache__/sdxl_unet.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
+diffsynth/models/__pycache__/sdxl_unet.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+diffsynth/models/__pycache__/sdxl_unet.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
+diffsynth/models/__pycache__/svd_unet.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
+diffsynth/models/__pycache__/svd_unet.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+diffsynth/models/__pycache__/svd_unet.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
+diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt filter=lfs diff=lfs merge=lfs -text
+engine/pose_estimation/third-party/ViTPose/figures/Throughput.png filter=lfs diff=lfs merge=lfs -text
+engine/pose_estimation/third-party/ViTPose/mmpose/.mim/demo/resources/demo.mp4 filter=lfs diff=lfs merge=lfs -text
+engine/pose_estimation/third-party/ViTPose/mmpose/.mim/demo/resources/demo_coco.gif filter=lfs diff=lfs merge=lfs -text
+pretrained_models/dense_sample_points/1_40000.ply filter=lfs diff=lfs merge=lfs -text
+pretrained_models/dense_sample_points/1_60000.ply filter=lfs diff=lfs merge=lfs -text
+pretrained_models/dense_sample_points/1_80000.ply filter=lfs diff=lfs merge=lfs -text
+pretrained_models/gagatracker/vgghead/vgg_heads_l.trcd filter=lfs diff=lfs merge=lfs -text
+pretrained_models/huggingface/models--3DAIGC--LHM-1B-HF/blobs/59dc25167d1d72d57fb068445b96e2343ab550b649e9999765200502d03171b9 filter=lfs diff=lfs merge=lfs -text
+pretrained_models/human_model_files/smplx/smplx_uv/smplx_uv.png filter=lfs diff=lfs merge=lfs -text
+pretrained_models/sapiens/pretrained/checkpoints/sapiens_1b/sapiens_1b_epoch_173_torchscript.pt2 filter=lfs diff=lfs merge=lfs -text
LHM/__init__.py
ADDED
@@ -0,0 +1,15 @@
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Empty
LHM/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (176 Bytes).
LHM/__pycache__/launch.cpython-310.pyc
ADDED
Binary file (723 Bytes).
LHM/datasets/__init__.py
ADDED
@@ -0,0 +1,16 @@
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from .mixer import MixerDataset
LHM/datasets/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (227 Bytes).
LHM/datasets/__pycache__/cam_utils.cpython-310.pyc
ADDED
Binary file (5.46 kB).
LHM/datasets/__pycache__/mixer.cpython-310.pyc
ADDED
Binary file (2.06 kB).
LHM/datasets/base.py
ADDED
@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
# @Organization : Alibaba XR-Lab
# @Author : Peihao Li & Lingteng Qiu & Xiaodong Gu & Qi Zuo
# @Email : [email protected]
# @Time : 2025-03-10 18:47:56
# @Function : dataset base

import json
import pdb
import traceback
from abc import ABC, abstractmethod

import numpy as np
import torch
from megfile import smart_exists, smart_open, smart_path_join
from PIL import Image


class BaseDataset(torch.utils.data.Dataset, ABC):
    def __init__(self, root_dirs: str, meta_path: str):
        super().__init__()
        self.root_dirs = root_dirs
        self.uids = self._load_uids(meta_path)

    def __len__(self):
        return len(self.uids)

    @abstractmethod
    def inner_get_item(self, idx):
        pass

    def __getitem__(self, idx):
        try:
            return self.inner_get_item(idx)
        except Exception as e:
            traceback.print_exc()
            print(f"[DEBUG-DATASET] Error when loading {self.uids[idx]}")
            # raise e
            return self.__getitem__((idx + 1) % self.__len__())

    @staticmethod
    def _load_uids(meta_path: str):
        # meta_path is a json file
        if meta_path == None:
            uids = []
        else:
            with open(meta_path, "r") as f:
                uids = json.load(f)

        return uids

    @staticmethod
    def _load_rgba_image(file_path, bg_color: float = 1.0):
        """Load and blend RGBA image to RGB with certain background, 0-1 scaled"""
        rgba = np.array(Image.open(smart_open(file_path, "rb")))
        rgba = torch.from_numpy(rgba).float() / 255.0
        rgba = rgba.permute(2, 0, 1).unsqueeze(0)
        rgb = rgba[:, :3, :, :] * rgba[:, 3:4, :, :] + bg_color * (
            1 - rgba[:, 3:, :, :]
        )
        # rgba[:, :3, ...] * rgba[:, 3:, ...] + (1 - rgba[:, 3:, ...])
        return rgb

    @staticmethod
    def _locate_datadir(root_dirs, uid, locator: str):
        for root_dir in root_dirs:
            datadir = smart_path_join(root_dir, uid, locator)
            if smart_exists(datadir):
                return root_dir
        raise FileNotFoundError(f"Cannot find valid data directory for uid {uid}")
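A minimal sketch of how a concrete dataset could sit on top of BaseDataset above. The class name, the rgb/000.png locator and the returned dict are made-up illustrations (the real subclasses live elsewhere in the repo); only inner_get_item, _load_rgba_image and _locate_datadir come from the file itself.

# Hypothetical subclass for illustration; file layout and keys are assumptions.
from megfile import smart_path_join

from LHM.datasets.base import BaseDataset


class ToyRGBDataset(BaseDataset):
    """Yields one white-background-blended RGB tensor per uid listed in meta.json."""

    def inner_get_item(self, idx):
        uid = self.uids[idx]
        # pick the first root_dir that actually contains uid/<locator>
        root_dir = self._locate_datadir(self.root_dirs, uid, locator="rgb/000.png")
        image_path = smart_path_join(root_dir, uid, "rgb", "000.png")
        rgb = self._load_rgba_image(image_path, bg_color=1.0)  # (1, 3, H, W), 0-1 scaled
        return {"uid": uid, "rgb": rgb}


# dataset = ToyRGBDataset(root_dirs=["/data/human"], meta_path="/data/human/meta.json")
# a failed load is caught in __getitem__ and falls through to the next index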
LHM/datasets/cam_utils.py
ADDED
@@ -0,0 +1,205 @@
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import math
import torch

"""
R: (N, 3, 3)
T: (N, 3)
E: (N, 4, 4)
vector: (N, 3)
"""


def compose_extrinsic_R_T(R: torch.Tensor, T: torch.Tensor):
    """
    Compose the standard form extrinsic matrix from R and T.
    Batched I/O.
    """
    RT = torch.cat((R, T.unsqueeze(-1)), dim=-1)
    return compose_extrinsic_RT(RT)


def compose_extrinsic_RT(RT: torch.Tensor):
    """
    Compose the standard form extrinsic matrix from RT.
    Batched I/O.
    """
    return torch.cat([
        RT,
        torch.tensor([[[0, 0, 0, 1]]], dtype=RT.dtype, device=RT.device).repeat(RT.shape[0], 1, 1)
    ], dim=1)


def decompose_extrinsic_R_T(E: torch.Tensor):
    """
    Decompose the standard extrinsic matrix into R and T.
    Batched I/O.
    """
    RT = decompose_extrinsic_RT(E)
    return RT[:, :, :3], RT[:, :, 3]


def decompose_extrinsic_RT(E: torch.Tensor):
    """
    Decompose the standard extrinsic matrix into RT.
    Batched I/O.
    """
    return E[:, :3, :]


def camera_normalization_objaverse(normed_dist_to_center, poses: torch.Tensor, ret_transform: bool = False):
    assert normed_dist_to_center is not None
    pivotal_pose = compose_extrinsic_RT(poses[:1])
    dist_to_center = pivotal_pose[:, :3, 3].norm(dim=-1, keepdim=True).item() \
        if normed_dist_to_center == 'auto' else normed_dist_to_center

    # compute camera norm (new version)
    canonical_camera_extrinsics = torch.tensor([[
        [1, 0, 0, 0],
        [0, 0, -1, -dist_to_center],
        [0, 1, 0, 0],
        [0, 0, 0, 1],
    ]], dtype=torch.float32)
    pivotal_pose_inv = torch.inverse(pivotal_pose)
    camera_norm_matrix = torch.bmm(canonical_camera_extrinsics, pivotal_pose_inv)

    # normalize all views
    poses = compose_extrinsic_RT(poses)
    poses = torch.bmm(camera_norm_matrix.repeat(poses.shape[0], 1, 1), poses)
    poses = decompose_extrinsic_RT(poses)

    if ret_transform:
        return poses, camera_norm_matrix.squeeze(dim=0)
    return poses


def get_normalized_camera_intrinsics(intrinsics: torch.Tensor):
    """
    intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
    Return batched fx, fy, cx, cy
    """
    fx, fy = intrinsics[:, 0, 0], intrinsics[:, 0, 1]
    cx, cy = intrinsics[:, 1, 0], intrinsics[:, 1, 1]
    width, height = intrinsics[:, 2, 0], intrinsics[:, 2, 1]
    fx, fy = fx / width, fy / height
    cx, cy = cx / width, cy / height
    return fx, fy, cx, cy


def build_camera_principle(RT: torch.Tensor, intrinsics: torch.Tensor):
    """
    RT: (N, 3, 4)
    intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
    """
    fx, fy, cx, cy = get_normalized_camera_intrinsics(intrinsics)
    return torch.cat([
        RT.reshape(-1, 12),
        fx.unsqueeze(-1), fy.unsqueeze(-1), cx.unsqueeze(-1), cy.unsqueeze(-1),
    ], dim=-1)


def build_camera_standard(RT: torch.Tensor, intrinsics: torch.Tensor):
    """
    RT: (N, 3, 4)
    intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
    """
    E = compose_extrinsic_RT(RT)
    fx, fy, cx, cy = get_normalized_camera_intrinsics(intrinsics)
    I = torch.stack([
        torch.stack([fx, torch.zeros_like(fx), cx], dim=-1),
        torch.stack([torch.zeros_like(fy), fy, cy], dim=-1),
        torch.tensor([[0, 0, 1]], dtype=torch.float32, device=RT.device).repeat(RT.shape[0], 1),
    ], dim=1)
    return torch.cat([
        E.reshape(-1, 16),
        I.reshape(-1, 9),
    ], dim=-1)


def center_looking_at_camera_pose(
    camera_position: torch.Tensor, look_at: torch.Tensor = None, up_world: torch.Tensor = None,
    device: torch.device = torch.device('cpu'),
):
    """
    camera_position: (M, 3)
    look_at: (3)
    up_world: (3)
    return: (M, 3, 4)
    """
    # by default, looking at the origin and world up is pos-z
    if look_at is None:
        look_at = torch.tensor([0, 0, 0], dtype=torch.float32, device=device)
    if up_world is None:
        up_world = torch.tensor([0, 0, 1], dtype=torch.float32, device=device)
    look_at = look_at.unsqueeze(0).repeat(camera_position.shape[0], 1)
    up_world = up_world.unsqueeze(0).repeat(camera_position.shape[0], 1)

    z_axis = camera_position - look_at
    z_axis = z_axis / z_axis.norm(dim=-1, keepdim=True)
    x_axis = torch.cross(up_world, z_axis)
    x_axis = x_axis / x_axis.norm(dim=-1, keepdim=True)
    y_axis = torch.cross(z_axis, x_axis)
    y_axis = y_axis / y_axis.norm(dim=-1, keepdim=True)
    extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1)
    return extrinsics


def surrounding_views_linspace(n_views: int, radius: float = 2.0, height: float = 0.8, device: torch.device = torch.device('cpu')):
    """
    n_views: number of surrounding views
    radius: camera dist to center
    height: height of the camera
    return: (M, 3, 4)
    """
    assert n_views > 0
    assert radius > 0

    theta = torch.linspace(-torch.pi / 2, 3 * torch.pi / 2, n_views, device=device)
    projected_radius = math.sqrt(radius ** 2 - height ** 2)
    x = torch.cos(theta) * projected_radius
    y = torch.sin(theta) * projected_radius
    z = torch.full((n_views,), height, device=device)

    camera_positions = torch.stack([x, y, z], dim=1)
    extrinsics = center_looking_at_camera_pose(camera_positions, device=device)

    return extrinsics


def create_intrinsics(
    f: float,
    c: float = None, cx: float = None, cy: float = None,
    w: float = 1., h: float = 1.,
    dtype: torch.dtype = torch.float32,
    device: torch.device = torch.device('cpu'),
):
    """
    return: (3, 2)
    """
    fx = fy = f
    if c is not None:
        assert cx is None and cy is None, "c and cx/cy cannot be used together"
        cx = cy = c
    else:
        assert cx is not None and cy is not None, "cx/cy must be provided when c is not provided"
    fx, fy, cx, cy, w, h = fx/w, fy/h, cx/w, cy/h, 1., 1.
    intrinsics = torch.tensor([
        [fx, fy],
        [cx, cy],
        [w, h],
    ], dtype=dtype, device=device)
    return intrinsics
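A short usage sketch of the conventions above, with arbitrary numbers: extrinsics are (N, 3, 4) RT matrices, intrinsics are stored as (3, 2) tensors [[fx, fy], [cx, cy], [width, height]], and build_camera_principle flattens both into a 16-dim per-view embedding.

from LHM.datasets.cam_utils import (
    build_camera_principle,
    create_intrinsics,
    surrounding_views_linspace,
)

# 8 cameras on a ring (radius 2.0, height 0.8), all looking at the origin
poses = surrounding_views_linspace(n_views=8)                       # (8, 3, 4)

# a single shared pinhole camera, normalized by its 512x512 image size
intrinsics = create_intrinsics(f=384.0, c=256.0, w=512.0, h=512.0)  # (3, 2)
intrinsics = intrinsics.unsqueeze(0).repeat(8, 1, 1)                # (8, 3, 2)

embedding = build_camera_principle(poses, intrinsics)               # (8, 16) = 12 RT values + fx, fy, cx, cy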
LHM/datasets/mixer.py
ADDED
@@ -0,0 +1,75 @@
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");:
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import math
import pdb
from functools import partial

import torch

__all__ = ["MixerDataset"]


class MixerDataset(torch.utils.data.Dataset):
    """Reference"""

    def __init__(
        self,
        split: str,
        subsets: dict,
        **dataset_kwargs,
    ):

        self.subsets = [
            self._dataset_fn(subset, split)(
                use_flame=subset["use_flame"],
                src_head_size=subset.get("src_head_size", 448),
                **dataset_kwargs,
            )
            for subset in subsets
        ]
        self.virtual_lens = [
            math.ceil(subset_config["sample_rate"] * len(subset_obj))
            for subset_config, subset_obj in zip(subsets, self.subsets)
        ]

    @staticmethod
    def _dataset_fn(subset_config: dict, split: str):
        name = subset_config["name"]

        dataset_cls = None
        if name == "video_human":
            from .video_human import VideoHumanDataset

        else:
            raise NotImplementedError(f"Dataset {name} not implemented")

        return partial(
            dataset_cls,
            root_dirs=subset_config["root_dirs"],
            meta_path=subset_config["meta_path"][split],
        )

    def __len__(self):
        return sum(self.virtual_lens)

    def __getitem__(self, idx):
        subset_idx = 0
        virtual_idx = idx
        while virtual_idx >= self.virtual_lens[subset_idx]:
            virtual_idx -= self.virtual_lens[subset_idx]
            subset_idx += 1
        real_idx = virtual_idx % len(self.subsets[subset_idx])
        return self.subsets[subset_idx][real_idx]
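MixerDataset reads a handful of keys from each subsets entry: name (resolved by _dataset_fn), root_dirs, a per-split meta_path, a sample_rate that scales the subset's virtual length, plus use_flame and an optional src_head_size. A sketch of such a config with invented values; nothing below ships with this upload.

# Hypothetical subsets config; keys mirror what MixerDataset reads, values are invented.
subsets = [
    {
        "name": "video_human",
        "root_dirs": ["/data/video_human"],
        "meta_path": {
            "train": "/data/video_human/train.json",
            "val": "/data/video_human/val.json",
        },
        "sample_rate": 1.0,   # virtual length = ceil(sample_rate * len(dataset))
        "use_flame": False,
        "src_head_size": 448,
    },
]

# dataset = MixerDataset(split="train", subsets=subsets)
# len(dataset) == sum(dataset.virtual_lens)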
LHM/launch.py
ADDED
@@ -0,0 +1,37 @@
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import pdb

from LHM.runners import REGISTRY_RUNNERS


def main():

    parser = argparse.ArgumentParser(description="OpenLRM launcher")
    parser.add_argument("runner", type=str, help="Runner to launch")
    args, unknown = parser.parse_known_args()

    if args.runner not in REGISTRY_RUNNERS:
        raise ValueError("Runner {} not found".format(args.runner))

    RunnerClass = REGISTRY_RUNNERS[args.runner]
    with RunnerClass() as runner:
        runner.run()


if __name__ == "__main__":
    main()
LHM/losses/__init__.py
ADDED
@@ -0,0 +1,20 @@
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from .ball_loss import *
from .offset_loss import *
from .perceptual import *
from .pixelwise import *
from .tvloss import *
LHM/losses/ball_loss.py
ADDED
@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
# @Organization : Alibaba XR-Lab
# @Author : Lingteng Qiu
# @Email : [email protected]
# @Time : 2025-03-10 19:08:35
# @Function : ASAP loss
import pdb

import torch
import torch.nn as nn

__all__ = ["ASAP_Loss", "Heuristic_ASAP_Loss"]


class ASAP_Loss(nn.Module):

    def forward(self, scaling, r=1, **params):
        """where r is the radius of the ball between max-axis and min-axis."""
        raise NotImplementedError(
            "ASAP_Loss is not implemented yet in Inference version"
        )


class Heuristic_ASAP_Loss(nn.Module):
    def __init__(self, group_dict, group_body_mapping):
        super(Heuristic_ASAP_Loss, self).__init__()

        self.group_dict = group_dict  # register weights fro different body parts
        self.group_body_mapping = group_body_mapping  # mapping of body parts to group

    def _heurisitic_loss(self, _ball_loss):

        _loss = 0.0
        for key in self.group_dict.keys():
            key_weights = self.group_dict[key]
            group_mapping_idx = self.group_body_mapping[key]
            _loss += key_weights * _ball_loss[:, group_mapping_idx].mean()

        return _loss

    def forward(self, scaling, r=5, **params):
        """where r is the radius of the ball between max-axis and min-axis."""
        "human motion or rotation is very different in each body parts, for example, the head is more stable than the leg and hand, so we use heuristic_ball_loss"

        _scale = scaling

        _scale_min = torch.min(_scale, dim=-1)[0]
        _scale_max = torch.max(_scale, dim=-1)[0]

        scale_ratio = _scale_max / (_scale_min + 1e-6)

        _ball_loss = torch.clamp(scale_ratio, min=r) - r

        return self._heurisitic_loss(_ball_loss)
LHM/losses/offset_loss.py
ADDED
@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
# @Organization : Alibaba XR-Lab
# @Author : Lingteng Qiu
# @Email : [email protected]
# @Time : 2025-03-10 19:08:56
# @Function : ACAP Loss
import pdb

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ["ACAP_Loss", "Heuristic_ACAP_Loss"]


class ACAP_Loss(nn.Module):
    """As close as possibel loss"""

    def forward(self, offset, d=0.05625, **params):
        """Empirically, where d is the thresold of distance points leave from 1.8/32 = 0.0562."""

        offset_loss = torch.clamp(offset.norm(p=2, dim=-1), min=d) - d

        return offset_loss.mean()


class Heuristic_ACAP_Loss(nn.Module):
    """As close as possibel loss"""

    def __init__(self, group_dict, group_body_mapping):
        super(Heuristic_ACAP_Loss, self).__init__()

        self.group_dict = group_dict  # register weights fro different body parts
        self.group_body_mapping = group_body_mapping  # mapping of body parts to group

    def _heurisitic_loss(self, _offset_loss):

        _loss = 0.0
        for key in self.group_dict.keys():
            key_weights = self.group_dict[key]
            group_mapping_idx = self.group_body_mapping[key]
            _loss += key_weights * _offset_loss[:, group_mapping_idx].mean()

        return _loss

    def forward(self, offset, d=0.05625, **params):
        """Empirically, where d is the thresold of distance points leave from human prior model, 1.8/32 = 0.0562."""
        "human motion or rotation is very different in each body parts, for example, the head is more stable than the leg and hand, so we use heuristic_ball_loss"

        _offset_loss = torch.clamp(offset.norm(p=2, dim=-1), min=d) - d

        return self._heurisitic_loss(_offset_loss)
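Both heuristic losses above expect the same two dictionaries: a scalar weight per body-part group and a mapping from group name to the Gaussian-point indices it covers. A minimal sketch with invented groups, indices and tensor sizes (the real grouping comes from the model configuration, not from these files):

import torch

from LHM.losses import Heuristic_ACAP_Loss, Heuristic_ASAP_Loss

# Hypothetical grouping: a weight per group, and the point indices belonging to each group.
group_dict = {"head": 2.0, "body": 1.0}
group_body_mapping = {"head": [0, 1, 2], "body": [3, 4, 5, 6, 7]}

asap = Heuristic_ASAP_Loss(group_dict, group_body_mapping)
acap = Heuristic_ACAP_Loss(group_dict, group_body_mapping)

scaling = torch.rand(4, 8, 3)        # (batch, points, 3): per-axis Gaussian scales
offset = torch.rand(4, 8, 3) * 0.1   # (batch, points, 3): offsets from the human prior model

loss = asap(scaling, r=5) + acap(offset, d=0.05625)   # both reduce to scalars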
LHM/losses/perceptual.py
ADDED
@@ -0,0 +1,70 @@
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import torch.nn as nn

__all__ = ['LPIPSLoss']


class LPIPSLoss(nn.Module):
    """
    Compute LPIPS loss between two images.
    """

    def __init__(self, device, prefech: bool = False):
        super().__init__()
        self.device = device
        self.cached_models = {}
        if prefech:
            self.prefetch_models()

    def _get_model(self, model_name: str):
        if model_name not in self.cached_models:
            import warnings
            with warnings.catch_warnings():
                warnings.filterwarnings('ignore', category=UserWarning)
                import lpips
                _model = lpips.LPIPS(net=model_name, eval_mode=True, verbose=False).to(self.device)
            _model = torch.compile(_model)
            self.cached_models[model_name] = _model
        return self.cached_models[model_name]

    def prefetch_models(self):
        _model_names = ['alex', 'vgg']
        for model_name in _model_names:
            self._get_model(model_name)

    def forward(self, x, y, is_training: bool = True):
        """
        Assume images are 0-1 scaled and channel first.

        Args:
            x: [N, M, C, H, W]
            y: [N, M, C, H, W]
            is_training: whether to use VGG or AlexNet.

        Returns:
            Mean-reduced LPIPS loss across batch.
        """
        model_name = 'vgg' if is_training else 'alex'
        loss_fn = self._get_model(model_name)
        N, M, C, H, W = x.shape
        x = x.reshape(N*M, C, H, W)
        y = y.reshape(N*M, C, H, W)
        image_loss = loss_fn(x, y, normalize=True).mean(dim=[1, 2, 3])
        batch_loss = image_loss.reshape(N, M).mean(dim=1)
        all_loss = batch_loss.mean()
        return all_loss
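A short sketch of the input layout LPIPSLoss expects (PixelLoss and TVLoss below reuse the same [N, M, C, H, W] convention): N batch items with M views each, 0-1 scaled. It assumes the lpips package is installed; the device choice is arbitrary.

import torch

from LHM.losses.perceptual import LPIPSLoss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loss_fn = LPIPSLoss(device, prefech=True)   # pre-builds both the VGG and AlexNet backbones

pred = torch.rand(2, 4, 3, 256, 256, device=device)     # N=2 items, M=4 views each
target = torch.rand(2, 4, 3, 256, 256, device=device)

train_loss = loss_fn(pred, target, is_training=True)    # VGG backbone
eval_loss = loss_fn(pred, target, is_training=False)    # AlexNet backbone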
LHM/losses/pixelwise.py
ADDED
@@ -0,0 +1,58 @@
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import torch.nn as nn

__all__ = ['PixelLoss']


class PixelLoss(nn.Module):
    """
    Pixel-wise loss between two images.
    """

    def __init__(self, option: str = 'mse'):
        super().__init__()
        self.loss_fn = self._build_from_option(option)

    @staticmethod
    def _build_from_option(option: str, reduction: str = 'none'):
        if option == 'mse':
            return nn.MSELoss(reduction=reduction)
        elif option == 'l1':
            return nn.L1Loss(reduction=reduction)
        else:
            raise NotImplementedError(f'Unknown pixel loss option: {option}')

    @torch.compile
    def forward(self, x, y):
        """
        Assume images are channel first.

        Args:
            x: [N, M, C, H, W]
            y: [N, M, C, H, W]

        Returns:
            Mean-reduced pixel loss across batch.
        """
        N, M, C, H, W = x.shape
        x = x.reshape(N*M, C, H, W)
        y = y.reshape(N*M, C, H, W)
        image_loss = self.loss_fn(x, y).mean(dim=[1, 2, 3])
        batch_loss = image_loss.reshape(N, M).mean(dim=1)
        all_loss = batch_loss.mean()
        return all_loss
LHM/losses/tvloss.py
ADDED
@@ -0,0 +1,55 @@
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import torch.nn as nn

__all__ = ['TVLoss']


class TVLoss(nn.Module):
    """
    Total variance loss.
    """

    def __init__(self):
        super().__init__()

    def numel_excluding_first_dim(self, x):
        return x.numel() // x.shape[0]

    @torch.compile
    def forward(self, x):
        """
        Assume batched and channel first with inner sizes.

        Args:
            x: [N, M, C, H, W]

        Returns:
            Mean-reduced TV loss with element-level scaling.
        """
        N, M, C, H, W = x.shape
        x = x.reshape(N*M, C, H, W)
        diff_i = x[..., 1:, :] - x[..., :-1, :]
        diff_j = x[..., :, 1:] - x[..., :, :-1]
        div_i = self.numel_excluding_first_dim(diff_i)
        div_j = self.numel_excluding_first_dim(diff_j)
        tv_i = diff_i.pow(2).sum(dim=[1,2,3]) / div_i
        tv_j = diff_j.pow(2).sum(dim=[1,2,3]) / div_j
        tv = tv_i + tv_j
        batch_tv = tv.reshape(N, M).mean(dim=1)
        all_tv = batch_tv.mean()
        return all_tv
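Restated as a formula, for each of the N*M images x (C channels, H x W) the forward above computes a size-normalized total variation and then averages over the M views and over the batch:

\mathrm{TV}(x) = \frac{1}{C\,(H-1)\,W}\sum_{c,i,j}\bigl(x_{c,i+1,j}-x_{c,i,j}\bigr)^{2} + \frac{1}{C\,H\,(W-1)}\sum_{c,i,j}\bigl(x_{c,i,j+1}-x_{c,i,j}\bigr)^{2}

The two normalizers are exactly div_i and div_j (the per-image element counts of diff_i and diff_j), so images of different sizes contribute on a comparable scale.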
LHM/models/ESRGANer_utils.py
ADDED
@@ -0,0 +1,482 @@
# -*- coding: utf-8 -*-
# @Organization : Alibaba XR-Lab
# @Author : Lingteng Qiu
# @Email : [email protected]
# @Time : 2025-03-1 17:39:52
# @Function : Function to improve face quality when training.

import math
import os
import queue
import sys

sys.path.append("./")
import threading

import cv2
import numpy as np
import torch
from basicsr.utils.download_util import load_file_from_url
from torch.nn import functional as F

ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
import pdb

import torch
from basicsr.archs.rrdbnet_arch import RRDBNet


def avaliable_device():
    if torch.cuda.is_available():
        current_device_id = torch.cuda.current_device()
        device = f"cuda:{current_device_id}"
    else:
        device = "cpu"

    return device


class RealESRGANer:
    """A helper class for upsampling images with RealESRGAN.

    Args:
        scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
        model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
        model (nn.Module): The defined network. Default: None.
        tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop
            input images into tiles, and then process each of them. Finally, they will be merged into one image.
            0 denotes for do not use tile. Default: 0.
        tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
        pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
        half (float): Whether to use half precision during inference. Default: False.
    """

    def __init__(
        self,
        scale,
        model_path,
        dni_weight=None,
        model=None,
        tile=0,
        tile_pad=10,
        pre_pad=10,
        half=False,
        device=None,
        gpu_id=None,
    ):
        self.scale = scale
        self.tile_size = tile
        self.tile_pad = tile_pad
        self.pre_pad = pre_pad
        self.mod_scale = None
        self.half = half

        # initialize model
        if gpu_id:
            self.device = (
                torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
                if device is None
                else device
            )
        else:
            self.device = (
                torch.device("cuda" if torch.cuda.is_available() else "cpu")
                if device is None
                else device
            )

        if isinstance(model_path, list):
            # dni
            assert len(model_path) == len(
                dni_weight
            ), "model_path and dni_weight should have the save length."
            loadnet = self.dni(model_path[0], model_path[1], dni_weight)
        else:
            # if the model_path starts with https, it will first download models to the folder: weights
            if model_path.startswith("https://"):
                model_path = load_file_from_url(
                    url=model_path,
                    model_dir=os.path.join(ROOT_DIR, "weights"),
                    progress=True,
                    file_name=None,
                )
            loadnet = torch.load(model_path, map_location=torch.device("cpu"))

        # prefer to use params_ema
        if "params_ema" in loadnet:
            keyname = "params_ema"
        else:
            keyname = "params"
        model.load_state_dict(loadnet[keyname], strict=True)

        model.eval()
        self.model = model.to(self.device)
        if self.half:
            self.model = self.model.half()

    def dni(self, net_a, net_b, dni_weight, key="params", loc="cpu"):
        """Deep network interpolation.

        ``Paper: Deep Network Interpolation for Continuous Imagery Effect Transition``
        """
        net_a = torch.load(net_a, map_location=torch.device(loc))
        net_b = torch.load(net_b, map_location=torch.device(loc))
        for k, v_a in net_a[key].items():
            net_a[key][k] = dni_weight[0] * v_a + dni_weight[1] * net_b[key][k]
        return net_a

    def pre_process(self, img):
        """Pre-process, such as pre-pad and mod pad, so that the images can be divisible"""
        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
        self.img = img.unsqueeze(0).to(self.device)
        if self.half:
            self.img = self.img.half()

        # pre_pad
        if self.pre_pad != 0:
            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), "reflect")
        # mod pad for divisible borders
        if self.scale == 2:
            self.mod_scale = 2
        elif self.scale == 1:
            self.mod_scale = 4
        if self.mod_scale is not None:
            self.mod_pad_h, self.mod_pad_w = 0, 0
            _, _, h, w = self.img.size()
            if h % self.mod_scale != 0:
                self.mod_pad_h = self.mod_scale - h % self.mod_scale
            if w % self.mod_scale != 0:
                self.mod_pad_w = self.mod_scale - w % self.mod_scale
            self.img = F.pad(
                self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), "reflect"
            )

    def process(self):
        # model inference
        self.output = self.model(self.img)

    def tile_process(self):
        """It will first crop input images to tiles, and then process each tile.
        Finally, all the processed tiles are merged into one images.

        Modified from: https://github.com/ata4/esrgan-launcher
        """
        batch, channel, height, width = self.img.shape
        output_height = height * self.scale
        output_width = width * self.scale
        output_shape = (batch, channel, output_height, output_width)

        # start with black image
        self.output = self.img.new_zeros(output_shape)
        tiles_x = math.ceil(width / self.tile_size)
        tiles_y = math.ceil(height / self.tile_size)

        # loop over all tiles
        for y in range(tiles_y):
            for x in range(tiles_x):
                # extract tile from input image
                ofs_x = x * self.tile_size
                ofs_y = y * self.tile_size
                # input tile area on total image
                input_start_x = ofs_x
                input_end_x = min(ofs_x + self.tile_size, width)
                input_start_y = ofs_y
                input_end_y = min(ofs_y + self.tile_size, height)

                # input tile area on total image with padding
                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
                input_end_x_pad = min(input_end_x + self.tile_pad, width)
                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
                input_end_y_pad = min(input_end_y + self.tile_pad, height)

                # input tile dimensions
                input_tile_width = input_end_x - input_start_x
                input_tile_height = input_end_y - input_start_y
                tile_idx = y * tiles_x + x + 1
                input_tile = self.img[
                    :,
                    :,
                    input_start_y_pad:input_end_y_pad,
                    input_start_x_pad:input_end_x_pad,
                ]

                # upscale tile
                try:
                    with torch.no_grad():
                        output_tile = self.model(input_tile)
                except RuntimeError as error:
                    print("Error", error)
                    print(f"\tTile {tile_idx}/{tiles_x * tiles_y}")

                # output tile area on total image
                output_start_x = input_start_x * self.scale
                output_end_x = input_end_x * self.scale
                output_start_y = input_start_y * self.scale
                output_end_y = input_end_y * self.scale

                # output tile area without padding
                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale

                # put tile into output image
                self.output[
                    :, :, output_start_y:output_end_y, output_start_x:output_end_x
                ] = output_tile[
                    :,
                    :,
                    output_start_y_tile:output_end_y_tile,
                    output_start_x_tile:output_end_x_tile,
                ]

    def post_process(self):
        # remove extra pad
        if self.mod_scale is not None:
            _, _, h, w = self.output.size()
            self.output = self.output[
                :,
                :,
                0 : h - self.mod_pad_h * self.scale,
                0 : w - self.mod_pad_w * self.scale,
            ]
        # remove prepad
        if self.pre_pad != 0:
            _, _, h, w = self.output.size()
            self.output = self.output[
                :,
                :,
                0 : h - self.pre_pad * self.scale,
                0 : w - self.pre_pad * self.scale,
            ]
        return self.output

    @torch.no_grad()
    def enhance(self, img, outscale=None, alpha_upsampler="realesrgan"):
        h_input, w_input = img.shape[0:2]
        # img: numpy
        img = img.astype(np.float32)
        if np.max(img) > 256:  # 16-bit image
            max_range = 65535
            print("\tInput is a 16-bit image")
        else:
            max_range = 255
        img = img / max_range
        if len(img.shape) == 2:  # gray image
            img_mode = "L"
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:  # RGBA image with alpha channel
            img_mode = "RGBA"
            alpha = img[:, :, 3]
            img = img[:, :, 0:3]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if alpha_upsampler == "realesrgan":
                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
        else:
            img_mode = "RGB"
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # ------------------- process image (without the alpha channel) ------------------- #
        self.pre_process(img)
        if self.tile_size > 0:
            self.tile_process()
        else:
            self.process()
        output_img = self.post_process()
        output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
        if img_mode == "L":
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)

        # ------------------- process the alpha channel if necessary ------------------- #
        if img_mode == "RGBA":
            if alpha_upsampler == "realesrgan":
                self.pre_process(alpha)
                if self.tile_size > 0:
                    self.tile_process()
                else:
                    self.process()
                output_alpha = self.post_process()
                output_alpha = (
                    output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
                )
                output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
            else:  # use the cv2 resize for alpha channel
                h, w = alpha.shape[0:2]
                output_alpha = cv2.resize(
                    alpha,
                    (w * self.scale, h * self.scale),
                    interpolation=cv2.INTER_LINEAR,
                )

            # merge the alpha channel
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
            output_img[:, :, 3] = output_alpha

        # ------------------------------ return ------------------------------ #
        if max_range == 65535:  # 16-bit image
            output = (output_img * 65535.0).round().astype(np.uint16)
        else:
            output = (output_img * 255.0).round().astype(np.uint8)

        if outscale is not None and outscale != float(self.scale):
            output = cv2.resize(
                output,
                (
                    int(w_input * outscale),
                    int(h_input * outscale),
                ),
                interpolation=cv2.INTER_LANCZOS4,
            )

        return output, img_mode


class PrefetchReader(threading.Thread):
    """Prefetch images.

    Args:
        img_list (list[str]): A image list of image paths to be read.
        num_prefetch_queue (int): Number of prefetch queue.
    """

    def __init__(self, img_list, num_prefetch_queue):
        super().__init__()
        self.que = queue.Queue(num_prefetch_queue)
        self.img_list = img_list

    def run(self):
        for img_path in self.img_list:
            img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
            self.que.put(img)

        self.que.put(None)

    def __next__(self):
        next_item = self.que.get()
        if next_item is None:
            raise StopIteration
        return next_item

    def __iter__(self):
        return self


class IOConsumer(threading.Thread):

    def __init__(self, opt, que, qid):
        super().__init__()
        self._queue = que
        self.qid = qid
        self.opt = opt

    def run(self):
        while True:
            msg = self._queue.get()
            if isinstance(msg, str) and msg == "quit":
                break

            output = msg["output"]
            save_path = msg["save_path"]
            cv2.imwrite(save_path, output)
        print(f"IO worker {self.qid} is done.")


class ESRGANEasyModel:
    def __init__(
        self, model_path="./pretrained_models/RealESRGAN_x4plus.pth", face_enhance=True
    ):
        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=4,
        )
        self.net_scale = 4
        file_url = [
            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
        ]
        if model_path is None:
            model_path = os.path.join("weights", args.model_name + ".pth")
            if not os.path.isfile(model_path):
                ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
                for url in file_url:
                    # model_path will be updated
                    model_path = load_file_from_url(
                        url=url,
                        model_dir=os.path.join("./", "pretrained_models"),
                        progress=True,
                        file_name=None,
                    )
        self.face_enhance = face_enhance

        dni_weight = None

        self.upsampler = RealESRGANer(
            scale=self.net_scale,
            model_path=model_path,
            dni_weight=dni_weight,
            model=model,
            tile=0,
            tile_pad=10,
            pre_pad=0,
            half=False,
        )

        self.upsampler.model.to(avaliable_device())
        if face_enhance:  # Use GFPGAN for face enhancement
            from gfpgan import GFPGANer

            self.face_enhancer = GFPGANer(
                model_path="https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
                upscale=4,
                arch="clean",
                channel_multiplier=2,
                bg_upsampler=self.upsampler,
            )
        else:
            self.face_enhancer = None

    @torch.no_grad()
    def __call__(self, img):
        if self.face_enhancer is not None:
            _, _, output = self.face_enhancer.enhance(
                img, has_aligned=False, only_center_face=False, paste_back=True
            )
        else:
            output, _ = self.upsampler.enhance(img, outscale=4)
        return output

    def __repr__(self):
        return f"ESRGANEasyModel:\n {self.upsampler}"


if __name__ == "__main__":

    import time

    model = ESRGANEasyModel(face_enhance=True)
    input_img = "./debug/face_debug/gt/head_gt_0.png"

    img_np = cv2.imread(input_img)
    set1 = [
        "./debug/face_debug/gt/head_gt_0.png",
        "./debug/face_debug/gt/head_gt_1.png",
        "./debug/face_debug/gt/head_gt_2.png",
        "./debug/face_debug/gt/head_gt_3.png",
        "./debug/face_debug/gt/head_gt_4.png",
        "./debug/face_debug/gt/head_gt_5.png",
        "./debug/face_debug/gt/head_gt_6.png",
        "./debug/face_debug/gt/head_gt_0.png",
    ]
    img_set1 = [cv2.imread(img_path) for img_path in set1]
+
sr = model(img_set1[0])
|
| 478 |
+
|
| 479 |
+
s0 = time.time()
|
| 480 |
+
for img in img_set1:
|
| 481 |
+
|
| 482 |
+
sr = model(img)
|
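A minimal usage sketch for the `ESRGANEasyModel` wrapper added above, mirroring its own `__main__` block; the input image path is hypothetical, and the checkpoint paths are the file's defaults:

```python
import cv2

from LHM.models.ESRGANer_utils import ESRGANEasyModel

# face_enhance=True routes crops through GFPGAN, with Real-ESRGAN as the background upsampler
model = ESRGANEasyModel(face_enhance=True)

img = cv2.imread("path/to/low_res_face.png")  # hypothetical BGR uint8 input
sr = model(img)                               # 4x upscaled BGR uint8 output
cv2.imwrite("face_sr.png", sr)
```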
LHM/models/__init__.py
ADDED
@@ -0,0 +1,24 @@
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from .modeling_human_lrm import (
    ModelHumanLRM,
    ModelHumanLRMSapdinoBodyHeadSD3_5,
)

model_dict = {
    "human_lrm": ModelHumanLRM,
    "human_lrm_sapdino_bh_sd3_5": ModelHumanLRMSapdinoBodyHeadSD3_5,
}
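A sketch of how the `model_dict` registry exported here is typically consumed; the `build_model` helper and the constructor keyword arguments are assumptions for illustration, not part of this diff:

```python
from LHM.models import model_dict


def build_model(model_name: str, **model_kwargs):
    """Instantiate a registered LHM model by name (kwargs are model-specific, assumed interface)."""
    try:
        model_cls = model_dict[model_name]
    except KeyError:
        raise ValueError(f"Unknown model {model_name!r}; choices: {list(model_dict)}")
    return model_cls(**model_kwargs)
```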
LHM/models/__pycache__/ESRGANer_utils.cpython-310.pyc
ADDED
Binary file (11.9 kB)
LHM/models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (352 Bytes)
LHM/models/__pycache__/arcface_utils.cpython-310.pyc
ADDED
Binary file (9.73 kB)
LHM/models/__pycache__/embedder.cpython-310.pyc
ADDED
Binary file (1.03 kB)
LHM/models/__pycache__/modeling_human_lrm.cpython-310.pyc
ADDED
Binary file (21.7 kB)
LHM/models/__pycache__/transformer.cpython-310.pyc
ADDED
Binary file (6.89 kB)
LHM/models/__pycache__/transformer_dit.cpython-310.pyc
ADDED
Binary file (15 kB)
LHM/models/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (1.44 kB)
LHM/models/arcface_utils.py
ADDED
@@ -0,0 +1,360 @@
# -*- coding: utf-8 -*-
# @Organization : Alibaba XR-Lab
# @Author       : Lingteng Qiu
# @Email        : [email protected]
# @Time         : 2025-03-10 17:38:29
# @Function     : Arc-Similarity Loss
import sys

sys.path.append(".")

import pdb
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DataParallel, DistributedDataParallel


def conv3x3(inplanes, outplanes, stride=1):
    """A simple wrapper for 3x3 convolution with padding.

    Args:
        inplanes (int): Channel number of inputs.
        outplanes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.
    """
    return nn.Conv2d(
        inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False
    )


class BasicBlock(nn.Module):
    """Basic residual block used in the ResNetArcFace architecture.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.
        downsample (nn.Module): The downsample module. Default: None.
    """

    expansion = 1  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class IRBlock(nn.Module):
    """Improved residual block (IR Block) used in the ResNetArcFace architecture.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.
        downsample (nn.Module): The downsample module. Default: None.
        use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
    """

    expansion = 1  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
        super(IRBlock, self).__init__()
        self.bn0 = nn.BatchNorm2d(inplanes)
        self.conv1 = conv3x3(inplanes, inplanes)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.prelu = nn.PReLU()
        self.conv2 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.use_se = use_se
        if self.use_se:
            self.se = SEBlock(planes)

    def forward(self, x):
        residual = x
        out = self.bn0(x)
        out = self.conv1(out)
        out = self.bn1(out)
        out = self.prelu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        if self.use_se:
            out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.prelu(out)

        return out


class Bottleneck(nn.Module):
    """Bottleneck block used in the ResNetArcFace architecture.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.
        downsample (nn.Module): The downsample module. Default: None.
    """

    expansion = 4  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False
        )
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class SEBlock(nn.Module):
    """The squeeze-and-excitation block (SEBlock) used in the IRBlock.

    Args:
        channel (int): Channel number of inputs.
        reduction (int): Channel reduction ration. Default: 16.
    """

    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(
            1
        )  # pool to 1x1 without spatial information
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.PReLU(),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y


class ResNetArcFace(nn.Module):
    """ArcFace with ResNet architectures.

    Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.

    Args:
        block (str): Block used in the ArcFace architecture.
        layers (tuple(int)): Block numbers in each layer.
        use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
    """

    def __init__(
        self,
        block="IRBlock",
        layers=[2, 2, 2, 2],
        use_se=False,
        pretrain_model="./pretrained_models/arcface_resnet18.pth",
    ):
        if block == "IRBlock":
            block = IRBlock
        self.inplanes = 64
        self.use_se = use_se
        super(ResNetArcFace, self).__init__()

        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.prelu = nn.PReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.bn4 = nn.BatchNorm2d(512)
        self.dropout = nn.Dropout()
        self.fc5 = nn.Linear(512 * 8 * 8, 512)
        self.bn5 = nn.BatchNorm1d(512)

        # initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

        if pretrain_model is not None:
            self.load_network(self, pretrain_model, strict=True, param_key=None)
        else:
            raise ValueError("Please specify the pretrain model path.")

        self.freeze()

    @staticmethod
    def load_network(net, load_path, strict=True, param_key=None):

        def get_bare_model(net):
            if isinstance(net, (DataParallel, DistributedDataParallel)):
                net = net.module
            return net

        net = get_bare_model(net)
        load_net = torch.load(load_path, map_location=lambda storage, loc: storage)
        if param_key is not None:
            if param_key not in load_net and "params" in load_net:
                param_key = "params"
            load_net = load_net[param_key]
        # remove unnecessary 'module.'
        for k, v in deepcopy(load_net).items():
            if k.startswith("module."):
                load_net[k[7:]] = v
                load_net.pop(k)
        ret = net.load_state_dict(load_net, strict=strict)
        print(ret)

    def _make_layer(self, block, planes, num_blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, use_se=self.use_se)
        )
        self.inplanes = planes
        for _ in range(1, num_blocks):
            layers.append(block(self.inplanes, planes, use_se=self.use_se))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.bn4(x)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = self.fc5(x)
        x = self.bn5(x)

        return x

    def freeze(self):
        self.eval()
        for param in self.parameters():
            param.requires_grad = False


if __name__ == "__main__":
    model = ResNetArcFace()
    model.cuda()
    model.eval()
    # model.eval()

    set1 = [
        "./debug/face_debug/gt/head_gt_0.png",
        "./debug/face_debug/gt/head_gt_1.png",
        "./debug/face_debug/gt/head_gt_2.png",
        "./debug/face_debug/gt/head_gt_3.png",
        "./debug/face_debug/gt/head_gt_4.png",
        "./debug/face_debug/gt/head_gt_5.png",
        "./debug/face_debug/gt/head_gt_6.png",
    ]
    import cv2

    img_set1 = [cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) for img_path in set1]

    F1_list = []

    f1_scores = []
    for img in img_set1:
        img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0) / 255.0
        img = img.cuda()
        F1 = model(img)
        F1_list.append(F1)
    for i in range(len(F1_list)):
        for j in range(len(F1_list)):
            f1_scores.append(F.l1_loss(F1_list[i], F1_list[j]))

    print(len(f1_scores))

    f1_scores = torch.tensor(f1_scores)
    print(f1_scores)
    f1_scores = f1_scores.view(len(F1_list), len(F1_list))
    print(f1_scores)
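Following the `__main__` block above, a sketch of how this frozen ArcFace backbone can be used as an identity-similarity term between rendered and ground-truth face crops; the 128x128 grayscale crop size is an assumption implied by the `512 * 8 * 8` fully connected layer:

```python
import torch
import torch.nn.functional as F

from LHM.models.arcface_utils import ResNetArcFace

# Loads ./pretrained_models/arcface_resnet18.pth and freezes all parameters in __init__
arcface = ResNetArcFace().cuda()


def face_id_loss(pred_gray: torch.Tensor, gt_gray: torch.Tensor) -> torch.Tensor:
    # pred_gray / gt_gray: [B, 1, 128, 128] grayscale crops in [0, 1]
    feat_pred = arcface(pred_gray)
    feat_gt = arcface(gt_gray).detach()
    return F.l1_loss(feat_pred, feat_gt)
```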
LHM/models/block.py
ADDED
@@ -0,0 +1,124 @@
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch.nn as nn

from .modulate import ModLN


class BasicBlock(nn.Module):
    """
    Transformer block that is in its simplest form.
    Designed for PF-LRM architecture.
    """
    # Block contains a self-attention layer and an MLP
    def __init__(self, inner_dim: int, num_heads: int, eps: float,
                 attn_drop: float = 0., attn_bias: bool = False,
                 mlp_ratio: float = 4., mlp_drop: float = 0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(inner_dim, eps=eps)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=inner_dim, num_heads=num_heads,
            dropout=attn_drop, bias=attn_bias, batch_first=True)
        self.norm2 = nn.LayerNorm(inner_dim, eps=eps)
        self.mlp = nn.Sequential(
            nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(mlp_drop),
            nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
            nn.Dropout(mlp_drop),
        )

    def forward(self, x):
        # x: [N, L, D]
        before_sa = self.norm1(x)
        x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
        x = x + self.mlp(self.norm2(x))
        return x


class ConditionBlock(nn.Module):
    """
    Transformer block that takes in a cross-attention condition.
    Designed for SparseLRM architecture.
    """
    # Block contains a cross-attention layer, a self-attention layer, and an MLP
    def __init__(self, inner_dim: int, cond_dim: int, num_heads: int, eps: float,
                 attn_drop: float = 0., attn_bias: bool = False,
                 mlp_ratio: float = 4., mlp_drop: float = 0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(inner_dim, eps=eps)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
            dropout=attn_drop, bias=attn_bias, batch_first=True)
        self.norm2 = nn.LayerNorm(inner_dim, eps=eps)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=inner_dim, num_heads=num_heads,
            dropout=attn_drop, bias=attn_bias, batch_first=True)
        self.norm3 = nn.LayerNorm(inner_dim, eps=eps)
        self.mlp = nn.Sequential(
            nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(mlp_drop),
            nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
            nn.Dropout(mlp_drop),
        )

    def forward(self, x, cond):
        # x: [N, L, D]
        # cond: [N, L_cond, D_cond]
        x = x + self.cross_attn(self.norm1(x), cond, cond, need_weights=False)[0]
        before_sa = self.norm2(x)
        x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
        x = x + self.mlp(self.norm3(x))
        return x


class ConditionModulationBlock(nn.Module):
    """
    Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks.
    Designed for raw LRM architecture.
    """
    # Block contains a cross-attention layer, a self-attention layer, and an MLP
    def __init__(self, inner_dim: int, cond_dim: int, mod_dim: int, num_heads: int, eps: float,
                 attn_drop: float = 0., attn_bias: bool = False,
                 mlp_ratio: float = 4., mlp_drop: float = 0.):
        super().__init__()
        self.norm1 = ModLN(inner_dim, mod_dim, eps)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
            dropout=attn_drop, bias=attn_bias, batch_first=True)
        self.norm2 = ModLN(inner_dim, mod_dim, eps)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=inner_dim, num_heads=num_heads,
            dropout=attn_drop, bias=attn_bias, batch_first=True)
        self.norm3 = ModLN(inner_dim, mod_dim, eps)
        self.mlp = nn.Sequential(
            nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(mlp_drop),
            nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
            nn.Dropout(mlp_drop),
        )

    def forward(self, x, cond, mod):
        # x: [N, L, D]
        # cond: [N, L_cond, D_cond]
        # mod: [N, D_mod]
        x = x + self.cross_attn(self.norm1(x, mod), cond, cond, need_weights=False)[0]
        before_sa = self.norm2(x, mod)
        x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
        x = x + self.mlp(self.norm3(x, mod))
        return x
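A quick shape-check sketch for the `ConditionBlock` defined above; the dimensions are arbitrary example values, not LHM's configuration:

```python
import torch

from LHM.models.block import ConditionBlock

block = ConditionBlock(inner_dim=512, cond_dim=768, num_heads=8, eps=1e-6)

tokens = torch.randn(2, 1024, 512)  # [N, L, D] latent tokens
cond = torch.randn(2, 197, 768)     # [N, L_cond, D_cond] image-encoder tokens

out = block(tokens, cond)           # cross-attn -> self-attn -> MLP, shape preserved
assert out.shape == tokens.shape
```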
LHM/models/discriminator.py
ADDED
@@ -0,0 +1,120 @@
"""
Ported from Paella
"""

import torch
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

import functools
# import torch.nn as nn
from taming.modules.util import ActNorm


# Discriminator model ported from Paella https://github.com/dome272/Paella/blob/main/src_distributed/vqgan.py
class Discriminator(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, in_channels=3, cond_channels=0, hidden_channels=512, depth=6):
        super().__init__()
        d = max(depth - 3, 3)
        layers = [
            nn.utils.spectral_norm(
                nn.Conv2d(in_channels, hidden_channels // (2**d), kernel_size=3, stride=2, padding=1)
            ),
            nn.LeakyReLU(0.2),
        ]
        for i in range(depth - 1):
            c_in = hidden_channels // (2 ** max((d - i), 0))
            c_out = hidden_channels // (2 ** max((d - 1 - i), 0))
            layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
            layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2))
        self.encoder = nn.Sequential(*layers)
        self.shuffle = nn.Conv2d(
            (hidden_channels + cond_channels) if cond_channels > 0 else hidden_channels, 1, kernel_size=1
        )
        # self.logits = nn.Sigmoid()


    def forward(self, x, cond=None):
        x = self.encoder(x)
        if cond is not None:
            cond = cond.view(
                cond.size(0),
                cond.size(1),
                1,
                1,
            ).expand(-1, -1, x.size(-2), x.size(-1))
            x = torch.cat([x, cond], dim=1)
        x = self.shuffle(x)
        # x = self.logits(x)
        return x




def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator as in Pix2Pix
        --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """
    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        if not use_actnorm:
            # norm_layer = nn.BatchNorm2d
            norm_layer = nn.InstanceNorm2d
        else:
            norm_layer = ActNorm
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            # use_bias = norm_layer.func != nn.BatchNorm2d
            use_bias = norm_layer.func != nn.InstanceNorm2d
        else:
            # use_bias = norm_layer != nn.BatchNorm2d
            use_bias = norm_layer != nn.InstanceNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, False)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, False)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, False)
        ]

        sequence += [
            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.main = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.main(input)
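A minimal sketch of driving the PatchGAN discriminator above with a hinge loss; the hinge formulation is a common choice assumed here for illustration, not something this file prescribes:

```python
import torch
import torch.nn.functional as F

from LHM.models.discriminator import NLayerDiscriminator, weights_init

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3).apply(weights_init)

real = torch.rand(2, 3, 256, 256)  # stand-in ground-truth renders
fake = torch.rand(2, 3, 256, 256)  # stand-in generated renders

# Hinge GAN losses on the patch-wise logit maps
d_loss = F.relu(1.0 - disc(real)).mean() + F.relu(1.0 + disc(fake.detach())).mean()
g_loss = -disc(fake).mean()
```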
LHM/models/embedder.py
ADDED
@@ -0,0 +1,37 @@
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import torch.nn as nn


class CameraEmbedder(nn.Module):
    """
    Embed camera features to a high-dimensional vector.

    Reference:
    DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L27
    """
    def __init__(self, raw_dim: int, embed_dim: int):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(raw_dim, embed_dim),
            nn.SiLU(),
            nn.Linear(embed_dim, embed_dim),
        )

    @torch.compile
    def forward(self, x):
        return self.mlp(x)
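A usage sketch for `CameraEmbedder`; the 16-dimensional raw camera vector (a flattened 3x4 camera matrix plus `fx, fy, cx, cy`) is an assumption for illustration rather than a value fixed by this file:

```python
import torch

from LHM.models.embedder import CameraEmbedder

embedder = CameraEmbedder(raw_dim=16, embed_dim=1024)

cam = torch.randn(4, 16)   # [N, raw_dim] flattened extrinsics + intrinsics (assumed layout)
mod = embedder(cam)        # [N, embed_dim] camera modulation vector
```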
LHM/models/encoders/__init__.py
ADDED
@@ -0,0 +1,15 @@
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Empty
LHM/models/encoders/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (192 Bytes)
LHM/models/encoders/__pycache__/dinov2_fusion_wrapper.cpython-310.pyc
ADDED
Binary file (5 kB)
LHM/models/encoders/__pycache__/sapiens_warpper.cpython-310.pyc
ADDED
Binary file (9.24 kB)
LHM/models/encoders/dino_wrapper.py
ADDED
@@ -0,0 +1,68 @@
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import torch.nn as nn
from transformers import ViTImageProcessor, ViTModel
from accelerate.logging import get_logger


logger = get_logger(__name__)


class DinoWrapper(nn.Module):
    """
    Dino v1 wrapper using huggingface transformer implementation.
    """
    def __init__(self, model_name: str, freeze: bool = True, encoder_feat_dim: int = 384):
        super().__init__()
        self.model, self.processor = self._build_dino(model_name)
        if freeze:
            self._freeze()

    @torch.compile
    def forward_model(self, inputs):
        return self.model(**inputs, interpolate_pos_encoding=True)

    def forward(self, image):
        # image: [N, C, H, W], on cpu
        # RGB image with [0,1] scale and properly sized
        inputs = self.processor(images=image, return_tensors="pt", do_rescale=False, do_resize=False).to(self.model.device)
        # This resampling of positional embedding uses bicubic interpolation
        outputs = self.forward_model(inputs)
        last_hidden_states = outputs.last_hidden_state
        return last_hidden_states

    def _freeze(self):
        logger.warning(f"======== Freezing DinoWrapper ========")
        self.model.eval()
        for name, param in self.model.named_parameters():
            param.requires_grad = False

    @staticmethod
    def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5):
        import requests
        try:
            model = ViTModel.from_pretrained(model_name, add_pooling_layer=False)
            processor = ViTImageProcessor.from_pretrained(model_name)
            return model, processor
        except requests.exceptions.ProxyError as err:
            if proxy_error_retries > 0:
                print(f"Huggingface ProxyError: Retrying ({proxy_error_retries}) in {proxy_error_cooldown} seconds...")
                import time
                time.sleep(proxy_error_cooldown)
                return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown)
            else:
                raise err
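A sketch of calling the `DinoWrapper` above; the `facebook/dino-vits16` checkpoint name is an example choice, and inputs are expected as [0, 1] RGB tensors already resized appropriately (the wrapper disables rescaling and resizing in the processor):

```python
import torch

from LHM.models.encoders.dino_wrapper import DinoWrapper

encoder = DinoWrapper(model_name="facebook/dino-vits16", freeze=True)

images = torch.rand(2, 3, 224, 224)  # [N, C, H, W] RGB in [0, 1]
tokens = encoder(images)             # [N, 1 + num_patches, hidden_dim] last hidden states
print(tokens.shape)
```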
LHM/models/encoders/dinov2/__init__.py
ADDED
@@ -0,0 +1,15 @@
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Empty
LHM/models/encoders/dinov2/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (199 Bytes)
LHM/models/encoders/dinov2/hub/__init__.py
ADDED
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
LHM/models/encoders/dinov2/hub/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (203 Bytes)
LHM/models/encoders/dinov2/hub/__pycache__/backbones.cpython-310.pyc
ADDED
Binary file (4.47 kB)
LHM/models/encoders/dinov2/hub/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (1.82 kB)
LHM/models/encoders/dinov2/hub/backbones.py
ADDED
@@ -0,0 +1,166 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from enum import Enum
from typing import Union

import torch

from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name


class Weights(Enum):
    LVD142M = "LVD142M"


def _make_dinov2_model(
    *,
    arch_name: str = "vit_large",
    img_size: int = 518,
    patch_size: int = 14,
    init_values: float = 1.0,
    ffn_layer: str = "mlp",
    block_chunks: int = 0,
    num_register_tokens: int = 0,
    interpolate_antialias: bool = False,
    interpolate_offset: float = 0.1,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.LVD142M,
    **kwargs,
):
    from ..models import vision_transformer as vits

    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    model_base_name = _make_dinov2_model_name(arch_name, patch_size)
    vit_kwargs = dict(
        img_size=img_size,
        patch_size=patch_size,
        init_values=init_values,
        ffn_layer=ffn_layer,
        block_chunks=block_chunks,
        num_register_tokens=num_register_tokens,
        interpolate_antialias=interpolate_antialias,
        interpolate_offset=interpolate_offset,
    )
    vit_kwargs.update(**kwargs)
    model = vits.__dict__[arch_name](**vit_kwargs)

    if pretrained:
        model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
        url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        # ********** Modified by Zexin He in 2023-2024 **********
        state_dict = {k: v for k, v in state_dict.items() if 'mask_token' not in k}  # DDP concern
        if vit_kwargs.get("modulation_dim") is not None:
            state_dict = {
                k.replace('norm1', 'norm1.norm').replace('norm2', 'norm2.norm'): v
                for k, v in state_dict.items()
            }
            model.load_state_dict(state_dict, strict=False)
        else:
            model.load_state_dict(state_dict, strict=True)
        # ********************************************************

    return model


def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        weights=weights,
        pretrained=pretrained,
        **kwargs,
    )


def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_small",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_base",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_large",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        weights=weights,
        pretrained=pretrained,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
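A sketch of how these hub entry points are typically used to pull a DINOv2 backbone and extract patch features; the input resolution must be a multiple of the 14-pixel patch size, and 518x518 simply mirrors the default `img_size`:

```python
import torch

from LHM.models.encoders.dinov2.hub.backbones import dinov2_vitb14

backbone = dinov2_vitb14(pretrained=True)  # downloads LVD-142M weights via torch.hub
backbone.eval()

with torch.no_grad():
    x = torch.rand(1, 3, 518, 518)
    feats = backbone.forward_features(x)
    patch_tokens = feats["x_norm_patchtokens"]  # [1, (518 // 14) ** 2, embed_dim]
```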
LHM/models/encoders/dinov2/hub/classifiers.py
ADDED
@@ -0,0 +1,268 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from enum import Enum
from typing import Union

import torch
import torch.nn as nn

from .backbones import _make_dinov2_model
from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name


class Weights(Enum):
    IMAGENET1K = "IMAGENET1K"


def _make_dinov2_linear_classification_head(
    *,
    arch_name: str = "vit_large",
    patch_size: int = 14,
    embed_dim: int = 1024,
    layers: int = 4,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.IMAGENET1K,
    num_register_tokens: int = 0,
    **kwargs,
):
    if layers not in (1, 4):
        raise AssertionError(f"Unsupported number of layers: {layers}")
    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    linear_head = nn.Linear((1 + layers) * embed_dim, 1_000)

    if pretrained:
        model_base_name = _make_dinov2_model_name(arch_name, patch_size)
        model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
        layers_str = str(layers) if layers == 4 else ""
        url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_linear{layers_str}_head.pth"
        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        linear_head.load_state_dict(state_dict, strict=True)

    return linear_head


class _LinearClassifierWrapper(nn.Module):
    def __init__(self, *, backbone: nn.Module, linear_head: nn.Module, layers: int = 4):
        super().__init__()
        self.backbone = backbone
        self.linear_head = linear_head
        self.layers = layers

    def forward(self, x):
        if self.layers == 1:
            x = self.backbone.forward_features(x)
            cls_token = x["x_norm_clstoken"]
            patch_tokens = x["x_norm_patchtokens"]
            # fmt: off
            linear_input = torch.cat([
                cls_token,
                patch_tokens.mean(dim=1),
            ], dim=1)
            # fmt: on
        elif self.layers == 4:
            x = self.backbone.get_intermediate_layers(x, n=4, return_class_token=True)
            # fmt: off
            linear_input = torch.cat([
                x[0][1],
                x[1][1],
                x[2][1],
                x[3][1],
                x[3][0].mean(dim=1),
            ], dim=1)
            # fmt: on
        else:
            assert False, f"Unsupported number of layers: {self.layers}"
        return self.linear_head(linear_input)


def _make_dinov2_linear_classifier(
    *,
    arch_name: str = "vit_large",
    layers: int = 4,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.IMAGENET1K,
    num_register_tokens: int = 0,
    interpolate_antialias: bool = False,
    interpolate_offset: float = 0.1,
    **kwargs,
):
    backbone = _make_dinov2_model(
        arch_name=arch_name,
        pretrained=pretrained,
        num_register_tokens=num_register_tokens,
        interpolate_antialias=interpolate_antialias,
        interpolate_offset=interpolate_offset,
        **kwargs,
    )

    embed_dim = backbone.embed_dim
    patch_size = backbone.patch_size
    linear_head = _make_dinov2_linear_classification_head(
        arch_name=arch_name,
        patch_size=patch_size,
        embed_dim=embed_dim,
        layers=layers,
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=num_register_tokens,
    )

    return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head, layers=layers)


def dinov2_vits14_lc(
    *,
    layers: int = 4,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.IMAGENET1K,
    **kwargs,
):
    """
    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
    """
    return _make_dinov2_linear_classifier(
        arch_name="vit_small",
        layers=layers,
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )


def dinov2_vitb14_lc(
    *,
    layers: int = 4,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.IMAGENET1K,
    **kwargs,
):
    """
    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
    """
    return _make_dinov2_linear_classifier(
        arch_name="vit_base",
        layers=layers,
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )


def dinov2_vitl14_lc(
    *,
    layers: int = 4,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.IMAGENET1K,
    **kwargs,
):
    """
    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
    """
    return _make_dinov2_linear_classifier(
        arch_name="vit_large",
        layers=layers,
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )


def dinov2_vitg14_lc(
    *,
    layers: int = 4,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.IMAGENET1K,
    **kwargs,
):
    """
    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
    """
    return _make_dinov2_linear_classifier(
        arch_name="vit_giant2",
        layers=layers,
        ffn_layer="swiglufused",
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )


def dinov2_vits14_reg_lc(
    *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
):
    """
    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
    """
    return _make_dinov2_linear_classifier(
        arch_name="vit_small",
        layers=layers,
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitb14_reg_lc(
    *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
):
    """
    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
    """
    return _make_dinov2_linear_classifier(
        arch_name="vit_base",
        layers=layers,
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitl14_reg_lc(
    *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
):
    """
    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
    """
    return _make_dinov2_linear_classifier(
        arch_name="vit_large",
        layers=layers,
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitg14_reg_lc(
    *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
):
    """
    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
    """
    return _make_dinov2_linear_classifier(
        arch_name="vit_giant2",
        layers=layers,
        ffn_layer="swiglufused",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
|
| 267 |
+
**kwargs,
|
| 268 |
+
)
|
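A minimal usage sketch (not part of the uploaded file): the `*_lc` factories above build the backbone plus linear head directly, so they can be smoke-tested without downloading weights. The import path below is assumed from this upload's layout; upstream DINOv2 exposes the same entrypoints via torch.hub.

import torch

# Assumed import path, matching the file location in this upload.
from LHM.models.encoders.dinov2.hub.classifiers import dinov2_vitl14_lc

# pretrained=False skips the weight download; the architecture is still built.
model = dinov2_vitl14_lc(layers=4, pretrained=False).eval()

with torch.no_grad():
    # Spatial size must be a multiple of the 14-pixel patch size.
    logits = model(torch.randn(1, 3, 224, 224))

print(logits.shape)  # expected: torch.Size([1, 1000]) for the ImageNet-1k head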
LHM/models/encoders/dinov2/hub/depth/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .decode_heads import BNHead, DPTHead
from .encoder_decoder import DepthEncoderDecoder
|
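The three exports above are meant to be composed: a backbone that yields (feature map, class token) pairs feeds a decode head inside `DepthEncoderDecoder`. A minimal sketch with a toy, purely hypothetical backbone (import path assumed from this upload's layout):

import torch
import torch.nn as nn

from LHM.models.encoders.dinov2.hub.depth import BNHead, DepthEncoderDecoder  # assumed path


class ToyBackbone(nn.Module):
    """Hypothetical stand-in mimicking get_intermediate_layers-style output:
    four (patch_feature_map, class_token) pairs."""

    def forward(self, x):
        b = x.shape[0]
        return [(torch.randn(b, 32, 16, 16), torch.randn(b, 32)) for _ in range(4)]


# BNHead concatenates each feature map with its class token (32 + 32 channels),
# then resize-concats the four levels -> 4 * 64 = 256 channels.
head = BNHead(in_channels=[64] * 4, channels=256, min_depth=1e-3, max_depth=10.0)
model = DepthEncoderDecoder(backbone=ToyBackbone(), decode_head=head)

depth = model.forward_dummy(torch.randn(2, 3, 224, 224))
print(depth.shape)  # torch.Size([2, 1, 224, 224])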
LHM/models/encoders/dinov2/hub/depth/decode_heads.py
ADDED
|
@@ -0,0 +1,747 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import copy
|
| 7 |
+
from functools import partial
|
| 8 |
+
import math
|
| 9 |
+
import warnings
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
|
| 14 |
+
from .ops import resize
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# XXX: (Untested) replacement for mmcv.imdenormalize()
|
| 18 |
+
def _imdenormalize(img, mean, std, to_bgr=True):
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
mean = mean.reshape(1, -1).astype(np.float64)
|
| 22 |
+
std = std.reshape(1, -1).astype(np.float64)
|
| 23 |
+
img = (img * std) + mean
|
| 24 |
+
if to_bgr:
|
| 25 |
+
img = img[::-1]
|
| 26 |
+
return img
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class DepthBaseDecodeHead(nn.Module):
|
| 30 |
+
"""Base class for BaseDecodeHead.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
in_channels (List): Input channels.
|
| 34 |
+
channels (int): Channels after modules, before conv_depth.
|
| 35 |
+
conv_layer (nn.Module): Conv layers. Default: None.
|
| 36 |
+
act_layer (nn.Module): Activation layers. Default: nn.ReLU.
|
| 37 |
+
loss_decode (dict): Config of decode loss.
|
| 38 |
+
Default: ().
|
| 39 |
+
sampler (dict|None): The config of depth map sampler.
|
| 40 |
+
Default: None.
|
| 41 |
+
align_corners (bool): align_corners argument of F.interpolate.
|
| 42 |
+
Default: False.
|
| 43 |
+
min_depth (int): Min depth in dataset setting.
|
| 44 |
+
Default: 1e-3.
|
| 45 |
+
max_depth (int): Max depth in dataset setting.
|
| 46 |
+
Default: None.
|
| 47 |
+
norm_layer (dict|None): Norm layers.
|
| 48 |
+
Default: None.
|
| 49 |
+
classify (bool): Whether predict depth in a cls.-reg. manner.
|
| 50 |
+
Default: False.
|
| 51 |
+
n_bins (int): The number of bins used in cls. step.
|
| 52 |
+
Default: 256.
|
| 53 |
+
bins_strategy (str): The discrete strategy used in cls. step.
|
| 54 |
+
Default: 'UD'.
|
| 55 |
+
norm_strategy (str): The norm strategy on cls. probability
|
| 56 |
+
distribution. Default: 'linear'
|
| 57 |
+
scale_up (str): Whether predict depth in a scale-up manner.
|
| 58 |
+
Default: False.
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
def __init__(
|
| 62 |
+
self,
|
| 63 |
+
in_channels,
|
| 64 |
+
conv_layer=None,
|
| 65 |
+
act_layer=nn.ReLU,
|
| 66 |
+
channels=96,
|
| 67 |
+
loss_decode=(),
|
| 68 |
+
sampler=None,
|
| 69 |
+
align_corners=False,
|
| 70 |
+
min_depth=1e-3,
|
| 71 |
+
max_depth=None,
|
| 72 |
+
norm_layer=None,
|
| 73 |
+
classify=False,
|
| 74 |
+
n_bins=256,
|
| 75 |
+
bins_strategy="UD",
|
| 76 |
+
norm_strategy="linear",
|
| 77 |
+
scale_up=False,
|
| 78 |
+
):
|
| 79 |
+
super(DepthBaseDecodeHead, self).__init__()
|
| 80 |
+
|
| 81 |
+
self.in_channels = in_channels
|
| 82 |
+
self.channels = channels
|
| 83 |
+
self.conv_layer = conv_layer
|
| 84 |
+
self.act_layer = act_layer
|
| 85 |
+
self.loss_decode = loss_decode
|
| 86 |
+
self.align_corners = align_corners
|
| 87 |
+
self.min_depth = min_depth
|
| 88 |
+
self.max_depth = max_depth
|
| 89 |
+
self.norm_layer = norm_layer
|
| 90 |
+
self.classify = classify
|
| 91 |
+
self.n_bins = n_bins
|
| 92 |
+
self.scale_up = scale_up
|
| 93 |
+
|
| 94 |
+
if self.classify:
|
| 95 |
+
assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID"
|
| 96 |
+
assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid"
|
| 97 |
+
|
| 98 |
+
self.bins_strategy = bins_strategy
|
| 99 |
+
self.norm_strategy = norm_strategy
|
| 100 |
+
self.softmax = nn.Softmax(dim=1)
|
| 101 |
+
self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1)
|
| 102 |
+
else:
|
| 103 |
+
self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1)
|
| 104 |
+
|
| 105 |
+
self.relu = nn.ReLU()
|
| 106 |
+
self.sigmoid = nn.Sigmoid()
|
| 107 |
+
|
| 108 |
+
def forward(self, inputs, img_metas):
|
| 109 |
+
"""Placeholder of forward function."""
|
| 110 |
+
pass
|
| 111 |
+
|
| 112 |
+
def forward_train(self, img, inputs, img_metas, depth_gt):
|
| 113 |
+
"""Forward function for training.
|
| 114 |
+
Args:
|
| 115 |
+
inputs (list[Tensor]): List of multi-level img features.
|
| 116 |
+
img_metas (list[dict]): List of image info dict where each dict
|
| 117 |
+
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
| 118 |
+
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
| 119 |
+
For details on the values of these keys see
|
| 120 |
+
`depth/datasets/pipelines/formatting.py:Collect`.
|
| 121 |
+
depth_gt (Tensor): GT depth
|
| 122 |
+
|
| 123 |
+
Returns:
|
| 124 |
+
dict[str, Tensor]: a dictionary of loss components
|
| 125 |
+
"""
|
| 126 |
+
depth_pred = self.forward(inputs, img_metas)
|
| 127 |
+
losses = self.losses(depth_pred, depth_gt)
|
| 128 |
+
|
| 129 |
+
log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0])
|
| 130 |
+
losses.update(**log_imgs)
|
| 131 |
+
|
| 132 |
+
return losses
|
| 133 |
+
|
| 134 |
+
def forward_test(self, inputs, img_metas):
|
| 135 |
+
"""Forward function for testing.
|
| 136 |
+
Args:
|
| 137 |
+
inputs (list[Tensor]): List of multi-level img features.
|
| 138 |
+
img_metas (list[dict]): List of image info dict where each dict
|
| 139 |
+
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
| 140 |
+
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
| 141 |
+
For details on the values of these keys see
|
| 142 |
+
`depth/datasets/pipelines/formatting.py:Collect`.
|
| 143 |
+
|
| 144 |
+
Returns:
|
| 145 |
+
Tensor: Output depth map.
|
| 146 |
+
"""
|
| 147 |
+
return self.forward(inputs, img_metas)
|
| 148 |
+
|
| 149 |
+
def depth_pred(self, feat):
|
| 150 |
+
"""Prediction each pixel."""
|
| 151 |
+
if self.classify:
|
| 152 |
+
logit = self.conv_depth(feat)
|
| 153 |
+
|
| 154 |
+
if self.bins_strategy == "UD":
|
| 155 |
+
bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
|
| 156 |
+
elif self.bins_strategy == "SID":
|
| 157 |
+
bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
|
| 158 |
+
|
| 159 |
+
# following Adabins, default linear
|
| 160 |
+
if self.norm_strategy == "linear":
|
| 161 |
+
logit = torch.relu(logit)
|
| 162 |
+
eps = 0.1
|
| 163 |
+
logit = logit + eps
|
| 164 |
+
logit = logit / logit.sum(dim=1, keepdim=True)
|
| 165 |
+
elif self.norm_strategy == "softmax":
|
| 166 |
+
logit = torch.softmax(logit, dim=1)
|
| 167 |
+
elif self.norm_strategy == "sigmoid":
|
| 168 |
+
logit = torch.sigmoid(logit)
|
| 169 |
+
logit = logit / logit.sum(dim=1, keepdim=True)
|
| 170 |
+
|
| 171 |
+
output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1)
|
| 172 |
+
|
| 173 |
+
else:
|
| 174 |
+
if self.scale_up:
|
| 175 |
+
output = self.sigmoid(self.conv_depth(feat)) * self.max_depth
|
| 176 |
+
else:
|
| 177 |
+
output = self.relu(self.conv_depth(feat)) + self.min_depth
|
| 178 |
+
return output
|
| 179 |
+
|
| 180 |
+
def losses(self, depth_pred, depth_gt):
|
| 181 |
+
"""Compute depth loss."""
|
| 182 |
+
loss = dict()
|
| 183 |
+
depth_pred = resize(
|
| 184 |
+
input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False
|
| 185 |
+
)
|
| 186 |
+
if not isinstance(self.loss_decode, nn.ModuleList):
|
| 187 |
+
losses_decode = [self.loss_decode]
|
| 188 |
+
else:
|
| 189 |
+
losses_decode = self.loss_decode
|
| 190 |
+
for loss_decode in losses_decode:
|
| 191 |
+
if loss_decode.loss_name not in loss:
|
| 192 |
+
loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt)
|
| 193 |
+
else:
|
| 194 |
+
loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt)
|
| 195 |
+
return loss
|
| 196 |
+
|
| 197 |
+
def log_images(self, img_path, depth_pred, depth_gt, img_meta):
|
| 198 |
+
import numpy as np
|
| 199 |
+
|
| 200 |
+
show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0))
|
| 201 |
+
show_img = show_img.numpy().astype(np.float32)
|
| 202 |
+
show_img = _imdenormalize(
|
| 203 |
+
show_img,
|
| 204 |
+
img_meta["img_norm_cfg"]["mean"],
|
| 205 |
+
img_meta["img_norm_cfg"]["std"],
|
| 206 |
+
img_meta["img_norm_cfg"]["to_rgb"],
|
| 207 |
+
)
|
| 208 |
+
show_img = np.clip(show_img, 0, 255)
|
| 209 |
+
show_img = show_img.astype(np.uint8)
|
| 210 |
+
show_img = show_img[:, :, ::-1]
|
| 211 |
+
show_img = show_img.transpose(0, 2, 1)
|
| 212 |
+
show_img = show_img.transpose(1, 0, 2)
|
| 213 |
+
|
| 214 |
+
depth_pred = depth_pred / torch.max(depth_pred)
|
| 215 |
+
depth_gt = depth_gt / torch.max(depth_gt)
|
| 216 |
+
|
| 217 |
+
depth_pred_color = copy.deepcopy(depth_pred.detach().cpu())
|
| 218 |
+
depth_gt_color = copy.deepcopy(depth_gt.detach().cpu())
|
| 219 |
+
|
| 220 |
+
return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color}
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class BNHead(DepthBaseDecodeHead):
|
| 224 |
+
"""Just a batchnorm."""
|
| 225 |
+
|
| 226 |
+
def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs):
|
| 227 |
+
super().__init__(**kwargs)
|
| 228 |
+
self.input_transform = input_transform
|
| 229 |
+
self.in_index = in_index
|
| 230 |
+
self.upsample = upsample
|
| 231 |
+
# self.bn = nn.SyncBatchNorm(self.in_channels)
|
| 232 |
+
if self.classify:
|
| 233 |
+
self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1)
|
| 234 |
+
else:
|
| 235 |
+
self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1)
|
| 236 |
+
|
| 237 |
+
def _transform_inputs(self, inputs):
|
| 238 |
+
"""Transform inputs for decoder.
|
| 239 |
+
Args:
|
| 240 |
+
inputs (list[Tensor]): List of multi-level img features.
|
| 241 |
+
Returns:
|
| 242 |
+
Tensor: The transformed inputs
|
| 243 |
+
"""
|
| 244 |
+
|
| 245 |
+
if "concat" in self.input_transform:
|
| 246 |
+
inputs = [inputs[i] for i in self.in_index]
|
| 247 |
+
if "resize" in self.input_transform:
|
| 248 |
+
inputs = [
|
| 249 |
+
resize(
|
| 250 |
+
input=x,
|
| 251 |
+
size=[s * self.upsample for s in inputs[0].shape[2:]],
|
| 252 |
+
mode="bilinear",
|
| 253 |
+
align_corners=self.align_corners,
|
| 254 |
+
)
|
| 255 |
+
for x in inputs
|
| 256 |
+
]
|
| 257 |
+
inputs = torch.cat(inputs, dim=1)
|
| 258 |
+
elif self.input_transform == "multiple_select":
|
| 259 |
+
inputs = [inputs[i] for i in self.in_index]
|
| 260 |
+
else:
|
| 261 |
+
inputs = inputs[self.in_index]
|
| 262 |
+
|
| 263 |
+
return inputs
|
| 264 |
+
|
| 265 |
+
def _forward_feature(self, inputs, img_metas=None, **kwargs):
|
| 266 |
+
"""Forward function for feature maps before classifying each pixel with
|
| 267 |
+
``self.cls_seg`` fc.
|
| 268 |
+
Args:
|
| 269 |
+
inputs (list[Tensor]): List of multi-level img features.
|
| 270 |
+
Returns:
|
| 271 |
+
feats (Tensor): A tensor of shape (batch_size, self.channels,
|
| 272 |
+
H, W) which is feature map for last layer of decoder head.
|
| 273 |
+
"""
|
| 274 |
+
# accept lists (for cls token)
|
| 275 |
+
inputs = list(inputs)
|
| 276 |
+
for i, x in enumerate(inputs):
|
| 277 |
+
if len(x) == 2:
|
| 278 |
+
x, cls_token = x[0], x[1]
|
| 279 |
+
if len(x.shape) == 2:
|
| 280 |
+
x = x[:, :, None, None]
|
| 281 |
+
cls_token = cls_token[:, :, None, None].expand_as(x)
|
| 282 |
+
inputs[i] = torch.cat((x, cls_token), 1)
|
| 283 |
+
else:
|
| 284 |
+
x = x[0]
|
| 285 |
+
if len(x.shape) == 2:
|
| 286 |
+
x = x[:, :, None, None]
|
| 287 |
+
inputs[i] = x
|
| 288 |
+
x = self._transform_inputs(inputs)
|
| 289 |
+
# feats = self.bn(x)
|
| 290 |
+
return x
|
| 291 |
+
|
| 292 |
+
def forward(self, inputs, img_metas=None, **kwargs):
|
| 293 |
+
"""Forward function."""
|
| 294 |
+
output = self._forward_feature(inputs, img_metas=img_metas, **kwargs)
|
| 295 |
+
output = self.depth_pred(output)
|
| 296 |
+
return output
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
class ConvModule(nn.Module):
|
| 300 |
+
"""A conv block that bundles conv/norm/activation layers.
|
| 301 |
+
|
| 302 |
+
This block simplifies the usage of convolution layers, which are commonly
|
| 303 |
+
used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
|
| 304 |
+
It is based upon three build methods: `build_conv_layer()`,
|
| 305 |
+
`build_norm_layer()` and `build_activation_layer()`.
|
| 306 |
+
|
| 307 |
+
Besides, we add some additional features in this module.
|
| 308 |
+
1. Automatically set `bias` of the conv layer.
|
| 309 |
+
2. Spectral norm is supported.
|
| 310 |
+
3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only
|
| 311 |
+
supports zero and circular padding, and we add "reflect" padding mode.
|
| 312 |
+
|
| 313 |
+
Args:
|
| 314 |
+
in_channels (int): Number of channels in the input feature map.
|
| 315 |
+
Same as that in ``nn._ConvNd``.
|
| 316 |
+
out_channels (int): Number of channels produced by the convolution.
|
| 317 |
+
Same as that in ``nn._ConvNd``.
|
| 318 |
+
kernel_size (int | tuple[int]): Size of the convolving kernel.
|
| 319 |
+
Same as that in ``nn._ConvNd``.
|
| 320 |
+
stride (int | tuple[int]): Stride of the convolution.
|
| 321 |
+
Same as that in ``nn._ConvNd``.
|
| 322 |
+
padding (int | tuple[int]): Zero-padding added to both sides of
|
| 323 |
+
the input. Same as that in ``nn._ConvNd``.
|
| 324 |
+
dilation (int | tuple[int]): Spacing between kernel elements.
|
| 325 |
+
Same as that in ``nn._ConvNd``.
|
| 326 |
+
groups (int): Number of blocked connections from input channels to
|
| 327 |
+
output channels. Same as that in ``nn._ConvNd``.
|
| 328 |
+
bias (bool | str): If specified as `auto`, it will be decided by the
|
| 329 |
+
norm_layer. Bias will be set as True if `norm_layer` is None, otherwise
|
| 330 |
+
False. Default: "auto".
|
| 331 |
+
conv_layer (nn.Module): Convolution layer. Default: None,
|
| 332 |
+
which means using conv2d.
|
| 333 |
+
norm_layer (nn.Module): Normalization layer. Default: None.
|
| 334 |
+
act_layer (nn.Module): Activation layer. Default: nn.ReLU.
|
| 335 |
+
inplace (bool): Whether to use inplace mode for activation.
|
| 336 |
+
Default: True.
|
| 337 |
+
with_spectral_norm (bool): Whether use spectral norm in conv module.
|
| 338 |
+
Default: False.
|
| 339 |
+
padding_mode (str): If the `padding_mode` has not been supported by
|
| 340 |
+
current `Conv2d` in PyTorch, we will use our own padding layer
|
| 341 |
+
instead. Currently, we support ['zeros', 'circular'] with official
|
| 342 |
+
implementation and ['reflect'] with our own implementation.
|
| 343 |
+
Default: 'zeros'.
|
| 344 |
+
order (tuple[str]): The order of conv/norm/activation layers. It is a
|
| 345 |
+
sequence of "conv", "norm" and "act". Common examples are
|
| 346 |
+
("conv", "norm", "act") and ("act", "conv", "norm").
|
| 347 |
+
Default: ('conv', 'norm', 'act').
|
| 348 |
+
"""
|
| 349 |
+
|
| 350 |
+
_abbr_ = "conv_block"
|
| 351 |
+
|
| 352 |
+
def __init__(
|
| 353 |
+
self,
|
| 354 |
+
in_channels,
|
| 355 |
+
out_channels,
|
| 356 |
+
kernel_size,
|
| 357 |
+
stride=1,
|
| 358 |
+
padding=0,
|
| 359 |
+
dilation=1,
|
| 360 |
+
groups=1,
|
| 361 |
+
bias="auto",
|
| 362 |
+
conv_layer=nn.Conv2d,
|
| 363 |
+
norm_layer=None,
|
| 364 |
+
act_layer=nn.ReLU,
|
| 365 |
+
inplace=True,
|
| 366 |
+
with_spectral_norm=False,
|
| 367 |
+
padding_mode="zeros",
|
| 368 |
+
order=("conv", "norm", "act"),
|
| 369 |
+
):
|
| 370 |
+
super(ConvModule, self).__init__()
|
| 371 |
+
official_padding_mode = ["zeros", "circular"]
|
| 372 |
+
self.conv_layer = conv_layer
|
| 373 |
+
self.norm_layer = norm_layer
|
| 374 |
+
self.act_layer = act_layer
|
| 375 |
+
self.inplace = inplace
|
| 376 |
+
self.with_spectral_norm = with_spectral_norm
|
| 377 |
+
self.with_explicit_padding = padding_mode not in official_padding_mode
|
| 378 |
+
self.order = order
|
| 379 |
+
assert isinstance(self.order, tuple) and len(self.order) == 3
|
| 380 |
+
assert set(order) == set(["conv", "norm", "act"])
|
| 381 |
+
|
| 382 |
+
self.with_norm = norm_layer is not None
|
| 383 |
+
self.with_activation = act_layer is not None
|
| 384 |
+
# if the conv layer is before a norm layer, bias is unnecessary.
|
| 385 |
+
if bias == "auto":
|
| 386 |
+
bias = not self.with_norm
|
| 387 |
+
self.with_bias = bias
|
| 388 |
+
|
| 389 |
+
if self.with_explicit_padding:
|
| 390 |
+
if padding_mode == "zeros":
|
| 391 |
+
padding_layer = nn.ZeroPad2d
|
| 392 |
+
else:
|
| 393 |
+
raise AssertionError(f"Unsupported padding mode: {padding_mode}")
|
| 394 |
+
self.pad = padding_layer(padding)
|
| 395 |
+
|
| 396 |
+
# reset padding to 0 for conv module
|
| 397 |
+
conv_padding = 0 if self.with_explicit_padding else padding
|
| 398 |
+
# build convolution layer
|
| 399 |
+
self.conv = self.conv_layer(
|
| 400 |
+
in_channels,
|
| 401 |
+
out_channels,
|
| 402 |
+
kernel_size,
|
| 403 |
+
stride=stride,
|
| 404 |
+
padding=conv_padding,
|
| 405 |
+
dilation=dilation,
|
| 406 |
+
groups=groups,
|
| 407 |
+
bias=bias,
|
| 408 |
+
)
|
| 409 |
+
# export the attributes of self.conv to a higher level for convenience
|
| 410 |
+
self.in_channels = self.conv.in_channels
|
| 411 |
+
self.out_channels = self.conv.out_channels
|
| 412 |
+
self.kernel_size = self.conv.kernel_size
|
| 413 |
+
self.stride = self.conv.stride
|
| 414 |
+
self.padding = padding
|
| 415 |
+
self.dilation = self.conv.dilation
|
| 416 |
+
self.transposed = self.conv.transposed
|
| 417 |
+
self.output_padding = self.conv.output_padding
|
| 418 |
+
self.groups = self.conv.groups
|
| 419 |
+
|
| 420 |
+
if self.with_spectral_norm:
|
| 421 |
+
self.conv = nn.utils.spectral_norm(self.conv)
|
| 422 |
+
|
| 423 |
+
# build normalization layers
|
| 424 |
+
if self.with_norm:
|
| 425 |
+
# norm layer is after conv layer
|
| 426 |
+
if order.index("norm") > order.index("conv"):
|
| 427 |
+
norm_channels = out_channels
|
| 428 |
+
else:
|
| 429 |
+
norm_channels = in_channels
|
| 430 |
+
norm = partial(norm_layer, num_features=norm_channels)
|
| 431 |
+
self.add_module("norm", norm)
|
| 432 |
+
if self.with_bias:
|
| 433 |
+
from torch.nn.modules.batchnorm import _BatchNorm
|
| 434 |
+
from torch.nn.modules.instancenorm import _InstanceNorm
|
| 435 |
+
|
| 436 |
+
if isinstance(norm, (_BatchNorm, _InstanceNorm)):
|
| 437 |
+
warnings.warn("Unnecessary conv bias before batch/instance norm")
|
| 438 |
+
else:
|
| 439 |
+
self.norm_name = None
|
| 440 |
+
|
| 441 |
+
# build activation layer
|
| 442 |
+
if self.with_activation:
|
| 443 |
+
# nn.Tanh has no 'inplace' argument
|
| 444 |
+
# (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.HSigmoid, nn.Swish, nn.GELU)
|
| 445 |
+
if not isinstance(act_layer, (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.GELU)):
|
| 446 |
+
act_layer = partial(act_layer, inplace=inplace)
|
| 447 |
+
self.activate = act_layer()
|
| 448 |
+
|
| 449 |
+
# Use msra init by default
|
| 450 |
+
self.init_weights()
|
| 451 |
+
|
| 452 |
+
@property
|
| 453 |
+
def norm(self):
|
| 454 |
+
if self.norm_name:
|
| 455 |
+
return getattr(self, self.norm_name)
|
| 456 |
+
else:
|
| 457 |
+
return None
|
| 458 |
+
|
| 459 |
+
def init_weights(self):
|
| 460 |
+
# 1. It is mainly for customized conv layers with their own
|
| 461 |
+
# initialization manners by calling their own ``init_weights()``,
|
| 462 |
+
# and we do not want ConvModule to override the initialization.
|
| 463 |
+
# 2. For customized conv layers without their own initialization
|
| 464 |
+
# manners (that is, they don't have their own ``init_weights()``)
|
| 465 |
+
# and PyTorch's conv layers, they will be initialized by
|
| 466 |
+
# this method with default ``kaiming_init``.
|
| 467 |
+
# Note: For PyTorch's conv layers, they will be overwritten by our
|
| 468 |
+
# initialization implementation using default ``kaiming_init``.
|
| 469 |
+
if not hasattr(self.conv, "init_weights"):
|
| 470 |
+
if self.with_activation and isinstance(self.act_layer, nn.LeakyReLU):
|
| 471 |
+
nonlinearity = "leaky_relu"
|
| 472 |
+
a = 0.01 # XXX: default negative_slope
|
| 473 |
+
else:
|
| 474 |
+
nonlinearity = "relu"
|
| 475 |
+
a = 0
|
| 476 |
+
if hasattr(self.conv, "weight") and self.conv.weight is not None:
|
| 477 |
+
nn.init.kaiming_normal_(self.conv.weight, a=a, mode="fan_out", nonlinearity=nonlinearity)
|
| 478 |
+
if hasattr(self.conv, "bias") and self.conv.bias is not None:
|
| 479 |
+
nn.init.constant_(self.conv.bias, 0)
|
| 480 |
+
if self.with_norm:
|
| 481 |
+
if hasattr(self.norm, "weight") and self.norm.weight is not None:
|
| 482 |
+
nn.init.constant_(self.norm.weight, 1)
|
| 483 |
+
if hasattr(self.norm, "bias") and self.norm.bias is not None:
|
| 484 |
+
nn.init.constant_(self.norm.bias, 0)
|
| 485 |
+
|
| 486 |
+
def forward(self, x, activate=True, norm=True):
|
| 487 |
+
for layer in self.order:
|
| 488 |
+
if layer == "conv":
|
| 489 |
+
if self.with_explicit_padding:
|
| 490 |
+
x = self.pad(x)
|
| 491 |
+
x = self.conv(x)
|
| 492 |
+
elif layer == "norm" and norm and self.with_norm:
|
| 493 |
+
x = self.norm(x)
|
| 494 |
+
elif layer == "act" and activate and self.with_activation:
|
| 495 |
+
x = self.activate(x)
|
| 496 |
+
return x
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
class Interpolate(nn.Module):
|
| 500 |
+
def __init__(self, scale_factor, mode, align_corners=False):
|
| 501 |
+
super(Interpolate, self).__init__()
|
| 502 |
+
self.interp = nn.functional.interpolate
|
| 503 |
+
self.scale_factor = scale_factor
|
| 504 |
+
self.mode = mode
|
| 505 |
+
self.align_corners = align_corners
|
| 506 |
+
|
| 507 |
+
def forward(self, x):
|
| 508 |
+
x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
|
| 509 |
+
return x
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
class HeadDepth(nn.Module):
|
| 513 |
+
def __init__(self, features):
|
| 514 |
+
super(HeadDepth, self).__init__()
|
| 515 |
+
self.head = nn.Sequential(
|
| 516 |
+
nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
|
| 517 |
+
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
|
| 518 |
+
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
|
| 519 |
+
nn.ReLU(),
|
| 520 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
def forward(self, x):
|
| 524 |
+
x = self.head(x)
|
| 525 |
+
return x
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
class ReassembleBlocks(nn.Module):
|
| 529 |
+
"""ViTPostProcessBlock, process cls_token in ViT backbone output and
|
| 530 |
+
rearrange the feature vector to feature map.
|
| 531 |
+
Args:
|
| 532 |
+
in_channels (int): ViT feature channels. Default: 768.
|
| 533 |
+
out_channels (List): output channels of each stage.
|
| 534 |
+
Default: [96, 192, 384, 768].
|
| 535 |
+
readout_type (str): Type of readout operation. Default: 'ignore'.
|
| 536 |
+
patch_size (int): The patch size. Default: 16.
|
| 537 |
+
"""
|
| 538 |
+
|
| 539 |
+
def __init__(self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16):
|
| 540 |
+
super(ReassembleBlocks, self).__init__()
|
| 541 |
+
|
| 542 |
+
assert readout_type in ["ignore", "add", "project"]
|
| 543 |
+
self.readout_type = readout_type
|
| 544 |
+
self.patch_size = patch_size
|
| 545 |
+
|
| 546 |
+
self.projects = nn.ModuleList(
|
| 547 |
+
[
|
| 548 |
+
ConvModule(
|
| 549 |
+
in_channels=in_channels,
|
| 550 |
+
out_channels=out_channel,
|
| 551 |
+
kernel_size=1,
|
| 552 |
+
act_layer=None,
|
| 553 |
+
)
|
| 554 |
+
for out_channel in out_channels
|
| 555 |
+
]
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
self.resize_layers = nn.ModuleList(
|
| 559 |
+
[
|
| 560 |
+
nn.ConvTranspose2d(
|
| 561 |
+
in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
|
| 562 |
+
),
|
| 563 |
+
nn.ConvTranspose2d(
|
| 564 |
+
in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
|
| 565 |
+
),
|
| 566 |
+
nn.Identity(),
|
| 567 |
+
nn.Conv2d(
|
| 568 |
+
in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
|
| 569 |
+
),
|
| 570 |
+
]
|
| 571 |
+
)
|
| 572 |
+
if self.readout_type == "project":
|
| 573 |
+
self.readout_projects = nn.ModuleList()
|
| 574 |
+
for _ in range(len(self.projects)):
|
| 575 |
+
self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU()))
|
| 576 |
+
|
| 577 |
+
def forward(self, inputs):
|
| 578 |
+
assert isinstance(inputs, list)
|
| 579 |
+
out = []
|
| 580 |
+
for i, x in enumerate(inputs):
|
| 581 |
+
assert len(x) == 2
|
| 582 |
+
x, cls_token = x[0], x[1]
|
| 583 |
+
feature_shape = x.shape
|
| 584 |
+
if self.readout_type == "project":
|
| 585 |
+
x = x.flatten(2).permute((0, 2, 1))
|
| 586 |
+
readout = cls_token.unsqueeze(1).expand_as(x)
|
| 587 |
+
x = self.readout_projects[i](torch.cat((x, readout), -1))
|
| 588 |
+
x = x.permute(0, 2, 1).reshape(feature_shape)
|
| 589 |
+
elif self.readout_type == "add":
|
| 590 |
+
x = x.flatten(2) + cls_token.unsqueeze(-1)
|
| 591 |
+
x = x.reshape(feature_shape)
|
| 592 |
+
else:
|
| 593 |
+
pass
|
| 594 |
+
x = self.projects[i](x)
|
| 595 |
+
x = self.resize_layers[i](x)
|
| 596 |
+
out.append(x)
|
| 597 |
+
return out
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
class PreActResidualConvUnit(nn.Module):
|
| 601 |
+
"""ResidualConvUnit, pre-activate residual unit.
|
| 602 |
+
Args:
|
| 603 |
+
in_channels (int): number of channels in the input feature map.
|
| 604 |
+
act_layer (nn.Module): activation layer.
|
| 605 |
+
norm_layer (nn.Module): norm layer.
|
| 606 |
+
stride (int): stride of the first block. Default: 1
|
| 607 |
+
dilation (int): dilation rate for convs layers. Default: 1.
|
| 608 |
+
"""
|
| 609 |
+
|
| 610 |
+
def __init__(self, in_channels, act_layer, norm_layer, stride=1, dilation=1):
|
| 611 |
+
super(PreActResidualConvUnit, self).__init__()
|
| 612 |
+
|
| 613 |
+
self.conv1 = ConvModule(
|
| 614 |
+
in_channels,
|
| 615 |
+
in_channels,
|
| 616 |
+
3,
|
| 617 |
+
stride=stride,
|
| 618 |
+
padding=dilation,
|
| 619 |
+
dilation=dilation,
|
| 620 |
+
norm_layer=norm_layer,
|
| 621 |
+
act_layer=act_layer,
|
| 622 |
+
bias=False,
|
| 623 |
+
order=("act", "conv", "norm"),
|
| 624 |
+
)
|
| 625 |
+
|
| 626 |
+
self.conv2 = ConvModule(
|
| 627 |
+
in_channels,
|
| 628 |
+
in_channels,
|
| 629 |
+
3,
|
| 630 |
+
padding=1,
|
| 631 |
+
norm_layer=norm_layer,
|
| 632 |
+
act_layer=act_layer,
|
| 633 |
+
bias=False,
|
| 634 |
+
order=("act", "conv", "norm"),
|
| 635 |
+
)
|
| 636 |
+
|
| 637 |
+
def forward(self, inputs):
|
| 638 |
+
inputs_ = inputs.clone()
|
| 639 |
+
x = self.conv1(inputs)
|
| 640 |
+
x = self.conv2(x)
|
| 641 |
+
return x + inputs_
|
| 642 |
+
|
| 643 |
+
|
| 644 |
+
class FeatureFusionBlock(nn.Module):
|
| 645 |
+
"""FeatureFusionBlock, merge feature map from different stages.
|
| 646 |
+
Args:
|
| 647 |
+
in_channels (int): Input channels.
|
| 648 |
+
act_layer (nn.Module): activation layer for ResidualConvUnit.
|
| 649 |
+
norm_layer (nn.Module): normalization layer.
|
| 650 |
+
expand (bool): Whether expand the channels in post process block.
|
| 651 |
+
Default: False.
|
| 652 |
+
align_corners (bool): align_corner setting for bilinear upsample.
|
| 653 |
+
Default: True.
|
| 654 |
+
"""
|
| 655 |
+
|
| 656 |
+
def __init__(self, in_channels, act_layer, norm_layer, expand=False, align_corners=True):
|
| 657 |
+
super(FeatureFusionBlock, self).__init__()
|
| 658 |
+
|
| 659 |
+
self.in_channels = in_channels
|
| 660 |
+
self.expand = expand
|
| 661 |
+
self.align_corners = align_corners
|
| 662 |
+
|
| 663 |
+
self.out_channels = in_channels
|
| 664 |
+
if self.expand:
|
| 665 |
+
self.out_channels = in_channels // 2
|
| 666 |
+
|
| 667 |
+
self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_layer=None, bias=True)
|
| 668 |
+
|
| 669 |
+
self.res_conv_unit1 = PreActResidualConvUnit(
|
| 670 |
+
in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer
|
| 671 |
+
)
|
| 672 |
+
self.res_conv_unit2 = PreActResidualConvUnit(
|
| 673 |
+
in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer
|
| 674 |
+
)
|
| 675 |
+
|
| 676 |
+
def forward(self, *inputs):
|
| 677 |
+
x = inputs[0]
|
| 678 |
+
if len(inputs) == 2:
|
| 679 |
+
if x.shape != inputs[1].shape:
|
| 680 |
+
res = resize(inputs[1], size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False)
|
| 681 |
+
else:
|
| 682 |
+
res = inputs[1]
|
| 683 |
+
x = x + self.res_conv_unit1(res)
|
| 684 |
+
x = self.res_conv_unit2(x)
|
| 685 |
+
x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners)
|
| 686 |
+
x = self.project(x)
|
| 687 |
+
return x
|
| 688 |
+
|
| 689 |
+
|
| 690 |
+
class DPTHead(DepthBaseDecodeHead):
|
| 691 |
+
"""Vision Transformers for Dense Prediction.
|
| 692 |
+
This head is an implementation of `DPT <https://arxiv.org/abs/2103.13413>`_.
|
| 693 |
+
Args:
|
| 694 |
+
embed_dims (int): The embed dimension of the ViT backbone.
|
| 695 |
+
Default: 768.
|
| 696 |
+
post_process_channels (List): Out channels of post process conv
|
| 697 |
+
layers. Default: [96, 192, 384, 768].
|
| 698 |
+
readout_type (str): Type of readout operation. Default: 'ignore'.
|
| 699 |
+
patch_size (int): The patch size. Default: 16.
|
| 700 |
+
expand_channels (bool): Whether expand the channels in post process
|
| 701 |
+
block. Default: False.
|
| 702 |
+
"""
|
| 703 |
+
|
| 704 |
+
def __init__(
|
| 705 |
+
self,
|
| 706 |
+
embed_dims=768,
|
| 707 |
+
post_process_channels=[96, 192, 384, 768],
|
| 708 |
+
readout_type="ignore",
|
| 709 |
+
patch_size=16,
|
| 710 |
+
expand_channels=False,
|
| 711 |
+
**kwargs,
|
| 712 |
+
):
|
| 713 |
+
super(DPTHead, self).__init__(**kwargs)
|
| 714 |
+
|
| 715 |
+
self.in_channels = self.in_channels
|
| 716 |
+
self.expand_channels = expand_channels
|
| 717 |
+
self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size)
|
| 718 |
+
|
| 719 |
+
self.post_process_channels = [
|
| 720 |
+
channel * math.pow(2, i) if expand_channels else channel for i, channel in enumerate(post_process_channels)
|
| 721 |
+
]
|
| 722 |
+
self.convs = nn.ModuleList()
|
| 723 |
+
for channel in self.post_process_channels:
|
| 724 |
+
self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_layer=None, bias=False))
|
| 725 |
+
self.fusion_blocks = nn.ModuleList()
|
| 726 |
+
for _ in range(len(self.convs)):
|
| 727 |
+
self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_layer, self.norm_layer))
|
| 728 |
+
self.fusion_blocks[0].res_conv_unit1 = None
|
| 729 |
+
self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_layer=self.norm_layer)
|
| 730 |
+
self.num_fusion_blocks = len(self.fusion_blocks)
|
| 731 |
+
self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
|
| 732 |
+
self.num_post_process_channels = len(self.post_process_channels)
|
| 733 |
+
assert self.num_fusion_blocks == self.num_reassemble_blocks
|
| 734 |
+
assert self.num_reassemble_blocks == self.num_post_process_channels
|
| 735 |
+
self.conv_depth = HeadDepth(self.channels)
|
| 736 |
+
|
| 737 |
+
def forward(self, inputs, img_metas):
|
| 738 |
+
assert len(inputs) == self.num_reassemble_blocks
|
| 739 |
+
x = [inp for inp in inputs]
|
| 740 |
+
x = self.reassemble_blocks(x)
|
| 741 |
+
x = [self.convs[i](feature) for i, feature in enumerate(x)]
|
| 742 |
+
out = self.fusion_blocks[0](x[-1])
|
| 743 |
+
for i in range(1, len(self.fusion_blocks)):
|
| 744 |
+
out = self.fusion_blocks[i](out, x[-(i + 1)])
|
| 745 |
+
out = self.project(out)
|
| 746 |
+
out = self.depth_pred(out)
|
| 747 |
+
return out
|
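For reference, the classification branch of `DepthBaseDecodeHead.depth_pred` above converts per-pixel bin probabilities into metric depth via an einsum with the bin centers. A self-contained numeric sketch (toy shapes, "UD" bins, "linear" normalisation, eps = 0.1 as in the code above):

import torch

n_bins, min_depth, max_depth = 8, 1e-3, 10.0
logit = torch.randn(2, n_bins, 4, 4)  # (B, n_bins, H, W), as produced by conv_depth

# "UD" strategy: uniformly discretised depth bins between min_depth and max_depth.
bins = torch.linspace(min_depth, max_depth, n_bins)

# "linear" normalisation strategy: ReLU, add eps, normalise over the bin axis.
prob = torch.relu(logit) + 0.1
prob = prob / prob.sum(dim=1, keepdim=True)

# Expected depth per pixel = sum_k prob_k * bin_k, then restore the channel axis.
depth = torch.einsum("ikmn,k->imn", prob, bins).unsqueeze(dim=1)
print(depth.shape)  # torch.Size([2, 1, 4, 4])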
LHM/models/encoders/dinov2/hub/depth/encoder_decoder.py
ADDED
|
@@ -0,0 +1,351 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from collections import OrderedDict
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
from .ops import resize
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def add_prefix(inputs, prefix):
|
| 16 |
+
"""Add prefix for dict.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
inputs (dict): The input dict with str keys.
|
| 20 |
+
prefix (str): The prefix to add.
|
| 21 |
+
|
| 22 |
+
Returns:
|
| 23 |
+
|
| 24 |
+
dict: The dict with keys updated with ``prefix``.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
outputs = dict()
|
| 28 |
+
for name, value in inputs.items():
|
| 29 |
+
outputs[f"{prefix}.{name}"] = value
|
| 30 |
+
|
| 31 |
+
return outputs
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class DepthEncoderDecoder(nn.Module):
|
| 35 |
+
"""Encoder Decoder depther.
|
| 36 |
+
|
| 37 |
+
EncoderDecoder typically consists of backbone and decode_head.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def __init__(self, backbone, decode_head):
|
| 41 |
+
super(DepthEncoderDecoder, self).__init__()
|
| 42 |
+
|
| 43 |
+
self.backbone = backbone
|
| 44 |
+
self.decode_head = decode_head
|
| 45 |
+
self.align_corners = self.decode_head.align_corners
|
| 46 |
+
|
| 47 |
+
def extract_feat(self, img):
|
| 48 |
+
"""Extract features from images."""
|
| 49 |
+
return self.backbone(img)
|
| 50 |
+
|
| 51 |
+
def encode_decode(self, img, img_metas, rescale=True, size=None):
|
| 52 |
+
"""Encode images with backbone and decode into a depth estimation
|
| 53 |
+
map of the same size as input."""
|
| 54 |
+
x = self.extract_feat(img)
|
| 55 |
+
out = self._decode_head_forward_test(x, img_metas)
|
| 56 |
+
# crop the pred depth to the certain range.
|
| 57 |
+
out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth)
|
| 58 |
+
if rescale:
|
| 59 |
+
if size is None:
|
| 60 |
+
if img_metas is not None:
|
| 61 |
+
size = img_metas[0]["ori_shape"][:2]
|
| 62 |
+
else:
|
| 63 |
+
size = img.shape[2:]
|
| 64 |
+
out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners)
|
| 65 |
+
return out
|
| 66 |
+
|
| 67 |
+
def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs):
|
| 68 |
+
"""Run forward function and calculate loss for decode head in
|
| 69 |
+
training."""
|
| 70 |
+
losses = dict()
|
| 71 |
+
loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, **kwargs)
|
| 72 |
+
losses.update(add_prefix(loss_decode, "decode"))
|
| 73 |
+
return losses
|
| 74 |
+
|
| 75 |
+
def _decode_head_forward_test(self, x, img_metas):
|
| 76 |
+
"""Run forward function and calculate loss for decode head in
|
| 77 |
+
inference."""
|
| 78 |
+
depth_pred = self.decode_head.forward_test(x, img_metas)
|
| 79 |
+
return depth_pred
|
| 80 |
+
|
| 81 |
+
def forward_dummy(self, img):
|
| 82 |
+
"""Dummy forward function."""
|
| 83 |
+
depth = self.encode_decode(img, None)
|
| 84 |
+
|
| 85 |
+
return depth
|
| 86 |
+
|
| 87 |
+
def forward_train(self, img, img_metas, depth_gt, **kwargs):
|
| 88 |
+
"""Forward function for training.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
img (Tensor): Input images.
|
| 92 |
+
img_metas (list[dict]): List of image info dict where each dict
|
| 93 |
+
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
| 94 |
+
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
| 95 |
+
For details on the values of these keys see
|
| 96 |
+
`depth/datasets/pipelines/formatting.py:Collect`.
|
| 97 |
+
depth_gt (Tensor): Depth gt
|
| 98 |
+
used if the architecture supports depth estimation task.
|
| 99 |
+
|
| 100 |
+
Returns:
|
| 101 |
+
dict[str, Tensor]: a dictionary of loss components
|
| 102 |
+
"""
|
| 103 |
+
|
| 104 |
+
x = self.extract_feat(img)
|
| 105 |
+
|
| 106 |
+
losses = dict()
|
| 107 |
+
|
| 108 |
+
# the last of x saves the info from neck
|
| 109 |
+
loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs)
|
| 110 |
+
|
| 111 |
+
losses.update(loss_decode)
|
| 112 |
+
|
| 113 |
+
return losses
|
| 114 |
+
|
| 115 |
+
def whole_inference(self, img, img_meta, rescale, size=None):
|
| 116 |
+
"""Inference with full image."""
|
| 117 |
+
return self.encode_decode(img, img_meta, rescale, size=size)
|
| 118 |
+
|
| 119 |
+
def slide_inference(self, img, img_meta, rescale, stride, crop_size):
|
| 120 |
+
"""Inference by sliding-window with overlap.
|
| 121 |
+
|
| 122 |
+
If h_crop > h_img or w_crop > w_img, the small patch will be used to
|
| 123 |
+
decode without padding.
|
| 124 |
+
"""
|
| 125 |
+
|
| 126 |
+
h_stride, w_stride = stride
|
| 127 |
+
h_crop, w_crop = crop_size
|
| 128 |
+
batch_size, _, h_img, w_img = img.size()
|
| 129 |
+
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
|
| 130 |
+
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
|
| 131 |
+
preds = img.new_zeros((batch_size, 1, h_img, w_img))
|
| 132 |
+
count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
|
| 133 |
+
for h_idx in range(h_grids):
|
| 134 |
+
for w_idx in range(w_grids):
|
| 135 |
+
y1 = h_idx * h_stride
|
| 136 |
+
x1 = w_idx * w_stride
|
| 137 |
+
y2 = min(y1 + h_crop, h_img)
|
| 138 |
+
x2 = min(x1 + w_crop, w_img)
|
| 139 |
+
y1 = max(y2 - h_crop, 0)
|
| 140 |
+
x1 = max(x2 - w_crop, 0)
|
| 141 |
+
crop_img = img[:, :, y1:y2, x1:x2]
|
| 142 |
+
depth_pred = self.encode_decode(crop_img, img_meta, rescale)
|
| 143 |
+
preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2)))
|
| 144 |
+
|
| 145 |
+
count_mat[:, :, y1:y2, x1:x2] += 1
|
| 146 |
+
assert (count_mat == 0).sum() == 0
|
| 147 |
+
if torch.onnx.is_in_onnx_export():
|
| 148 |
+
# cast count_mat to constant while exporting to ONNX
|
| 149 |
+
count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device)
|
| 150 |
+
preds = preds / count_mat
|
| 151 |
+
return preds
|
| 152 |
+
|
| 153 |
+
def inference(self, img, img_meta, rescale, size=None, mode="whole"):
|
| 154 |
+
"""Inference with slide/whole style.
|
| 155 |
+
|
| 156 |
+
Args:
|
| 157 |
+
img (Tensor): The input image of shape (N, 3, H, W).
|
| 158 |
+
img_meta (dict): Image info dict where each dict has: 'img_shape',
|
| 159 |
+
'scale_factor', 'flip', and may also contain
|
| 160 |
+
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
| 161 |
+
For details on the values of these keys see
|
| 162 |
+
`depth/datasets/pipelines/formatting.py:Collect`.
|
| 163 |
+
rescale (bool): Whether rescale back to original shape.
|
| 164 |
+
|
| 165 |
+
Returns:
|
| 166 |
+
Tensor: The output depth map.
|
| 167 |
+
"""
|
| 168 |
+
|
| 169 |
+
assert mode in ["slide", "whole"]
|
| 170 |
+
ori_shape = img_meta[0]["ori_shape"]
|
| 171 |
+
assert all(_["ori_shape"] == ori_shape for _ in img_meta)
|
| 172 |
+
if mode == "slide":
|
| 173 |
+
depth_pred = self.slide_inference(img, img_meta, rescale)
|
| 174 |
+
else:
|
| 175 |
+
depth_pred = self.whole_inference(img, img_meta, rescale, size=size)
|
| 176 |
+
output = depth_pred
|
| 177 |
+
flip = img_meta[0]["flip"]
|
| 178 |
+
if flip:
|
| 179 |
+
flip_direction = img_meta[0]["flip_direction"]
|
| 180 |
+
assert flip_direction in ["horizontal", "vertical"]
|
| 181 |
+
if flip_direction == "horizontal":
|
| 182 |
+
output = output.flip(dims=(3,))
|
| 183 |
+
elif flip_direction == "vertical":
|
| 184 |
+
output = output.flip(dims=(2,))
|
| 185 |
+
|
| 186 |
+
return output
|
| 187 |
+
|
| 188 |
+
def simple_test(self, img, img_meta, rescale=True):
|
| 189 |
+
"""Simple test with single image."""
|
| 190 |
+
depth_pred = self.inference(img, img_meta, rescale)
|
| 191 |
+
if torch.onnx.is_in_onnx_export():
|
| 192 |
+
# our inference backend only support 4D output
|
| 193 |
+
depth_pred = depth_pred.unsqueeze(0)
|
| 194 |
+
return depth_pred
|
| 195 |
+
depth_pred = depth_pred.cpu().numpy()
|
| 196 |
+
# unravel batch dim
|
| 197 |
+
depth_pred = list(depth_pred)
|
| 198 |
+
return depth_pred
|
| 199 |
+
|
| 200 |
+
def aug_test(self, imgs, img_metas, rescale=True):
|
| 201 |
+
"""Test with augmentations.
|
| 202 |
+
|
| 203 |
+
Only rescale=True is supported.
|
| 204 |
+
"""
|
| 205 |
+
# aug_test rescale all imgs back to ori_shape for now
|
| 206 |
+
assert rescale
|
| 207 |
+
# to save memory, we get augmented depth logit inplace
|
| 208 |
+
depth_pred = self.inference(imgs[0], img_metas[0], rescale)
|
| 209 |
+
for i in range(1, len(imgs)):
|
| 210 |
+
cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:])
|
| 211 |
+
depth_pred += cur_depth_pred
|
| 212 |
+
depth_pred /= len(imgs)
|
| 213 |
+
depth_pred = depth_pred.cpu().numpy()
|
| 214 |
+
# unravel batch dim
|
| 215 |
+
depth_pred = list(depth_pred)
|
| 216 |
+
return depth_pred
|
| 217 |
+
|
| 218 |
+
def forward_test(self, imgs, img_metas, **kwargs):
|
| 219 |
+
"""
|
| 220 |
+
Args:
|
| 221 |
+
imgs (List[Tensor]): the outer list indicates test-time
|
| 222 |
+
augmentations and inner Tensor should have a shape NxCxHxW,
|
| 223 |
+
which contains all images in the batch.
|
| 224 |
+
img_metas (List[List[dict]]): the outer list indicates test-time
|
| 225 |
+
augs (multiscale, flip, etc.) and the inner list indicates
|
| 226 |
+
images in a batch.
|
| 227 |
+
"""
|
| 228 |
+
for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]:
|
| 229 |
+
if not isinstance(var, list):
|
| 230 |
+
raise TypeError(f"{name} must be a list, but got " f"{type(var)}")
|
| 231 |
+
num_augs = len(imgs)
|
| 232 |
+
if num_augs != len(img_metas):
|
| 233 |
+
raise ValueError(f"num of augmentations ({len(imgs)}) != " f"num of image meta ({len(img_metas)})")
|
| 234 |
+
# all images in the same aug batch all of the same ori_shape and pad
|
| 235 |
+
# shape
|
| 236 |
+
for img_meta in img_metas:
|
| 237 |
+
ori_shapes = [_["ori_shape"] for _ in img_meta]
|
| 238 |
+
assert all(shape == ori_shapes[0] for shape in ori_shapes)
|
| 239 |
+
img_shapes = [_["img_shape"] for _ in img_meta]
|
| 240 |
+
assert all(shape == img_shapes[0] for shape in img_shapes)
|
| 241 |
+
pad_shapes = [_["pad_shape"] for _ in img_meta]
|
| 242 |
+
assert all(shape == pad_shapes[0] for shape in pad_shapes)
|
| 243 |
+
|
| 244 |
+
if num_augs == 1:
|
| 245 |
+
return self.simple_test(imgs[0], img_metas[0], **kwargs)
|
| 246 |
+
else:
|
| 247 |
+
return self.aug_test(imgs, img_metas, **kwargs)
|
| 248 |
+
|
| 249 |
+
def forward(self, img, img_metas, return_loss=True, **kwargs):
|
| 250 |
+
"""Calls either :func:`forward_train` or :func:`forward_test` depending
|
| 251 |
+
on whether ``return_loss`` is ``True``.
|
| 252 |
+
|
| 253 |
+
Note this setting will change the expected inputs. When
|
| 254 |
+
``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
|
| 255 |
+
and List[dict]), and when ``resturn_loss=False``, img and img_meta
|
| 256 |
+
should be double nested (i.e. List[Tensor], List[List[dict]]), with
|
| 257 |
+
the outer list indicating test time augmentations.
|
| 258 |
+
"""
|
| 259 |
+
if return_loss:
|
| 260 |
+
return self.forward_train(img, img_metas, **kwargs)
|
| 261 |
+
else:
|
| 262 |
+
return self.forward_test(img, img_metas, **kwargs)
|
| 263 |
+
|
| 264 |
+
def train_step(self, data_batch, optimizer, **kwargs):
|
| 265 |
+
"""The iteration step during training.
|
| 266 |
+
|
| 267 |
+
This method defines an iteration step during training, except for the
|
| 268 |
+
back propagation and optimizer updating, which are done in an optimizer
|
| 269 |
+
hook. Note that in some complicated cases or models, the whole process
|
| 270 |
+
including back propagation and optimizer updating is also defined in
|
| 271 |
+
this method, such as GAN.
|
| 272 |
+
|
| 273 |
+
Args:
|
| 274 |
+
data (dict): The output of dataloader.
|
| 275 |
+
optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
|
| 276 |
+
runner is passed to ``train_step()``. This argument is unused
|
| 277 |
+
and reserved.
|
| 278 |
+
|
| 279 |
+
Returns:
|
| 280 |
+
dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
|
| 281 |
+
``num_samples``.
|
| 282 |
+
``loss`` is a tensor for back propagation, which can be a
|
| 283 |
+
weighted sum of multiple losses.
|
| 284 |
+
``log_vars`` contains all the variables to be sent to the
|
| 285 |
+
logger.
|
| 286 |
+
``num_samples`` indicates the batch size (when the model is
|
| 287 |
+
DDP, it means the batch size on each GPU), which is used for
|
| 288 |
+
averaging the logs.
|
| 289 |
+
"""
|
| 290 |
+
losses = self(**data_batch)
|
| 291 |
+
|
| 292 |
+
# split losses and images
|
| 293 |
+
real_losses = {}
|
| 294 |
+
log_imgs = {}
|
| 295 |
+
for k, v in losses.items():
|
| 296 |
+
if "img" in k:
|
| 297 |
+
log_imgs[k] = v
|
| 298 |
+
else:
|
| 299 |
+
real_losses[k] = v
|
| 300 |
+
|
| 301 |
+
loss, log_vars = self._parse_losses(real_losses)
|
| 302 |
+
|
| 303 |
+
outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]), log_imgs=log_imgs)
|
| 304 |
+
|
| 305 |
+
return outputs
|
| 306 |
+
|
| 307 |
+
def val_step(self, data_batch, **kwargs):
|
| 308 |
+
"""The iteration step during validation.
|
| 309 |
+
|
| 310 |
+
This method shares the same signature as :func:`train_step`, but used
|
| 311 |
+
during val epochs. Note that the evaluation after training epochs is
|
| 312 |
+
not implemented with this method, but an evaluation hook.
|
| 313 |
+
"""
|
| 314 |
+
output = self(**data_batch, **kwargs)
|
| 315 |
+
return output
|
| 316 |
+
|
| 317 |
+
@staticmethod
|
| 318 |
+
def _parse_losses(losses):
|
| 319 |
+
import torch.distributed as dist
|
| 320 |
+
|
| 321 |
+
"""Parse the raw outputs (losses) of the network.
|
| 322 |
+
|
| 323 |
+
Args:
|
| 324 |
+
losses (dict): Raw output of the network, which usually contain
|
| 325 |
+
losses and other necessary information.
|
| 326 |
+
|
| 327 |
+
Returns:
|
| 328 |
+
tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
|
| 329 |
+
which may be a weighted sum of all losses, log_vars contains
|
| 330 |
+
all the variables to be sent to the logger.
|
| 331 |
+
"""
|
| 332 |
+
log_vars = OrderedDict()
|
| 333 |
+
for loss_name, loss_value in losses.items():
|
| 334 |
+
if isinstance(loss_value, torch.Tensor):
|
| 335 |
+
log_vars[loss_name] = loss_value.mean()
|
| 336 |
+
elif isinstance(loss_value, list):
|
| 337 |
+
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
|
| 338 |
+
else:
|
| 339 |
+
raise TypeError(f"{loss_name} is not a tensor or list of tensors")
|
| 340 |
+
|
| 341 |
+
loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key)
|
| 342 |
+
|
| 343 |
+
log_vars["loss"] = loss
|
| 344 |
+
for loss_name, loss_value in log_vars.items():
|
| 345 |
+
# reduce loss when distributed training
|
| 346 |
+
if dist.is_available() and dist.is_initialized():
|
| 347 |
+
loss_value = loss_value.data.clone()
|
| 348 |
+
dist.all_reduce(loss_value.div_(dist.get_world_size()))
|
| 349 |
+
log_vars[loss_name] = loss_value.item()
|
| 350 |
+
|
| 351 |
+
return loss, log_vars
|
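The training contract above hinges on two conventions: ``train_step`` routes every output whose key contains "img" to image logging and everything else to loss parsing, and ``_parse_losses`` sums every ``log_vars`` entry whose key contains "loss" into the scalar used for backpropagation. The standalone sketch below (not part of the diff) mirrors that aggregation with a hypothetical loss dict; the distributed all_reduce step is omitted for brevity.

from collections import OrderedDict

import torch


def parse_losses_sketch(losses):
    # Mean-reduce each entry (summing over list entries), then sum every key
    # containing "loss" into the scalar used for backpropagation.
    log_vars = OrderedDict()
    for name, value in losses.items():
        log_vars[name] = value.mean() if isinstance(value, torch.Tensor) else sum(v.mean() for v in value)
    loss = sum(v for k, v in log_vars.items() if "loss" in k)
    log_vars["loss"] = loss
    return loss, {k: v.item() for k, v in log_vars.items()}


# Hypothetical raw outputs of the network.
raw = {
    "loss_depth": torch.tensor([0.8, 1.2]),                 # mean 1.0, summed into "loss"
    "loss_smooth": [torch.tensor(0.1), torch.tensor(0.3)],  # list -> 0.1 + 0.3 = 0.4
    "abs_rel": torch.tensor(0.05),                          # no "loss" in the key: logged only
}
loss, log_vars = parse_losses_sketch(raw)
print(loss)      # tensor(1.4000) = 1.0 + 0.4
print(log_vars)  # {'loss_depth': 1.0, 'loss_smooth': ~0.4, 'abs_rel': 0.05, 'loss': ~1.4}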
LHM/models/encoders/dinov2/hub/depth/ops.py
ADDED
@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import warnings

import torch.nn.functional as F


def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False):
    # Thin wrapper around F.interpolate that can warn when align_corners=True
    # is used with output/input sizes that do not satisfy the `nx+1` relation.
    if warning:
        if size is not None and align_corners:
            input_h, input_w = tuple(int(x) for x in input.shape[2:])
            output_h, output_w = tuple(int(x) for x in size)
            if output_h > input_h or output_w > input_w:
                if (
                    (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1)
                    and (output_h - 1) % (input_h - 1)
                    and (output_w - 1) % (input_w - 1)
                ):
                    warnings.warn(
                        f"When align_corners={align_corners}, "
                        "the output would be more aligned if "
                        f"input size {(input_h, input_w)} is `x+1` and "
                        f"out size {(output_h, output_w)} is `nx+1`"
                    )
    return F.interpolate(input, size, scale_factor, mode, align_corners)
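As a quick illustration of the check above (not part of the diff; shapes chosen arbitrarily): with align_corners=True, upsampling is only exactly "aligned" when each output side length equals n*(input - 1) + 1, which is what the modulo test encodes before the wrapper falls through to F.interpolate.

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 33, 33)  # (N, C, H, W)

# 33 -> 65 satisfies (65 - 1) % (33 - 1) == 0, so resize(..., warning=True) stays silent.
y = F.interpolate(x, size=(65, 65), mode="bilinear", align_corners=True)

# 33 -> 64 does not, so the wrapper would emit the alignment warning before interpolating.
z = F.interpolate(x, size=(64, 64), mode="bilinear", align_corners=True)

print(y.shape, z.shape)  # torch.Size([1, 3, 65, 65]) torch.Size([1, 3, 64, 64])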
LHM/models/encoders/dinov2/hub/depthers.py
ADDED
@@ -0,0 +1,246 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from enum import Enum
from functools import partial
from typing import Optional, Tuple, Union

import torch

from .backbones import _make_dinov2_model
from .depth import BNHead, DepthEncoderDecoder, DPTHead
from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, CenterPadding


class Weights(Enum):
    NYU = "NYU"
    KITTI = "KITTI"


def _get_depth_range(pretrained: bool, weights: Weights = Weights.NYU) -> Tuple[float, float]:
    if not pretrained:  # Default
        return (0.001, 10.0)

    # Pretrained, set according to the training dataset for the provided weights
    if weights == Weights.KITTI:
        return (0.001, 80.0)

    if weights == Weights.NYU:
        return (0.001, 10.0)

    return (0.001, 10.0)


def _make_dinov2_linear_depth_head(
    *,
    embed_dim: int,
    layers: int,
    min_depth: float,
    max_depth: float,
    **kwargs,
):
    if layers not in (1, 4):
        raise AssertionError(f"Unsupported number of layers: {layers}")

    if layers == 1:
        in_index = [0]
    else:
        assert layers == 4
        in_index = [0, 1, 2, 3]

    return BNHead(
        classify=True,
        n_bins=256,
        bins_strategy="UD",
        norm_strategy="linear",
        upsample=4,
        in_channels=[embed_dim] * len(in_index),
        in_index=in_index,
        input_transform="resize_concat",
        channels=embed_dim * len(in_index) * 2,
        align_corners=False,
        # NB: the head's depth range is hardcoded here; the min_depth/max_depth
        # arguments are accepted for API symmetry but not forwarded.
        min_depth=0.001,
        max_depth=80,
        loss_decode=(),
    )


def _make_dinov2_linear_depther(
    *,
    arch_name: str = "vit_large",
    layers: int = 4,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.NYU,
    depth_range: Optional[Tuple[float, float]] = None,
    **kwargs,
):
    if layers not in (1, 4):
        raise AssertionError(f"Unsupported number of layers: {layers}")
    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    if depth_range is None:
        depth_range = _get_depth_range(pretrained, weights)
    min_depth, max_depth = depth_range

    backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)

    embed_dim = backbone.embed_dim
    patch_size = backbone.patch_size
    model_name = _make_dinov2_model_name(arch_name, patch_size)
    linear_depth_head = _make_dinov2_linear_depth_head(
        embed_dim=embed_dim,
        layers=layers,
        min_depth=min_depth,
        max_depth=max_depth,
    )

    layer_count = {
        "vit_small": 12,
        "vit_base": 12,
        "vit_large": 24,
        "vit_giant2": 40,
    }[arch_name]

    if layers == 4:
        out_index = {
            "vit_small": [2, 5, 8, 11],
            "vit_base": [2, 5, 8, 11],
            "vit_large": [4, 11, 17, 23],
            "vit_giant2": [9, 19, 29, 39],
        }[arch_name]
    else:
        assert layers == 1
        out_index = [layer_count - 1]

    model = DepthEncoderDecoder(backbone=backbone, decode_head=linear_depth_head)
    model.backbone.forward = partial(
        backbone.get_intermediate_layers,
        n=out_index,
        reshape=True,
        return_class_token=True,
        norm=False,
    )
    model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(patch_size)(x[0]))

    if pretrained:
        layers_str = str(layers) if layers == 4 else ""
        weights_str = weights.value.lower()
        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_linear{layers_str}_head.pth"
        checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        state_dict = checkpoint.get("state_dict", checkpoint)
        model.load_state_dict(state_dict, strict=False)

    return model


def dinov2_vits14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    return _make_dinov2_linear_depther(
        arch_name="vit_small", layers=layers, pretrained=pretrained, weights=weights, **kwargs
    )


def dinov2_vitb14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    return _make_dinov2_linear_depther(
        arch_name="vit_base", layers=layers, pretrained=pretrained, weights=weights, **kwargs
    )


def dinov2_vitl14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    return _make_dinov2_linear_depther(
        arch_name="vit_large", layers=layers, pretrained=pretrained, weights=weights, **kwargs
    )


def dinov2_vitg14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    return _make_dinov2_linear_depther(
        arch_name="vit_giant2", layers=layers, ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs
    )


def _make_dinov2_dpt_depth_head(*, embed_dim: int, min_depth: float, max_depth: float):
    return DPTHead(
        in_channels=[embed_dim] * 4,
        channels=256,
        embed_dims=embed_dim,
        post_process_channels=[embed_dim // 2 ** (3 - i) for i in range(4)],
        readout_type="project",
        min_depth=min_depth,
        max_depth=max_depth,
        loss_decode=(),
    )


def _make_dinov2_dpt_depther(
    *,
    arch_name: str = "vit_large",
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.NYU,
    depth_range: Optional[Tuple[float, float]] = None,
    **kwargs,
):
    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    if depth_range is None:
        depth_range = _get_depth_range(pretrained, weights)
    min_depth, max_depth = depth_range

    backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)

    model_name = _make_dinov2_model_name(arch_name, backbone.patch_size)
    dpt_depth_head = _make_dinov2_dpt_depth_head(embed_dim=backbone.embed_dim, min_depth=min_depth, max_depth=max_depth)

    out_index = {
        "vit_small": [2, 5, 8, 11],
        "vit_base": [2, 5, 8, 11],
        "vit_large": [4, 11, 17, 23],
        "vit_giant2": [9, 19, 29, 39],
    }[arch_name]

    model = DepthEncoderDecoder(backbone=backbone, decode_head=dpt_depth_head)
    model.backbone.forward = partial(
        backbone.get_intermediate_layers,
        n=out_index,
        reshape=True,
        return_class_token=True,
        norm=False,
    )
    model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(backbone.patch_size)(x[0]))

    if pretrained:
        weights_str = weights.value.lower()
        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_dpt_head.pth"
        checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        state_dict = checkpoint.get("state_dict", checkpoint)
        model.load_state_dict(state_dict, strict=False)

    return model


def dinov2_vits14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    return _make_dinov2_dpt_depther(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitb14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    return _make_dinov2_dpt_depther(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitl14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    return _make_dinov2_dpt_depther(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitg14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    return _make_dinov2_dpt_depther(
        arch_name="vit_giant2", ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs
    )
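A minimal usage sketch for these entry points, assuming the module is importable from the LHM package path shown in this diff; the import path and the whole_inference call follow the upstream dinov2 depth-estimation example and are assumptions here, not part of the upload.

import torch

from LHM.models.encoders.dinov2.hub.depthers import dinov2_vits14_ld

# Build the ViT-S/14 linear depther without downloading weights; pretrained=True
# would also fetch the NYU linear head checkpoint from _DINOV2_BASE_URL.
depther = dinov2_vits14_ld(layers=4, pretrained=False)
depther.eval()

# CenterPadding pads inputs to a multiple of the 14-pixel patch size.
x = torch.randn(1, 3, 518, 518)
with torch.no_grad():
    depth = depther.whole_inference(x, img_meta=None, rescale=True)
print(depth.shape)  # expected: torch.Size([1, 1, 518, 518]), one metric-depth map per image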