Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes.
- .gitattributes +24 -0
- LHM/__init__.py +15 -0
- LHM/__pycache__/__init__.cpython-310.pyc +0 -0
- LHM/__pycache__/launch.cpython-310.pyc +0 -0
- LHM/datasets/__init__.py +16 -0
- LHM/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
- LHM/datasets/__pycache__/cam_utils.cpython-310.pyc +0 -0
- LHM/datasets/__pycache__/mixer.cpython-310.pyc +0 -0
- LHM/datasets/base.py +70 -0
- LHM/datasets/cam_utils.py +205 -0
- LHM/datasets/mixer.py +75 -0
- LHM/launch.py +37 -0
- LHM/losses/__init__.py +20 -0
- LHM/losses/ball_loss.py +54 -0
- LHM/losses/offset_loss.py +52 -0
- LHM/losses/perceptual.py +70 -0
- LHM/losses/pixelwise.py +58 -0
- LHM/losses/tvloss.py +55 -0
- LHM/models/ESRGANer_utils.py +482 -0
- LHM/models/__init__.py +24 -0
- LHM/models/__pycache__/ESRGANer_utils.cpython-310.pyc +0 -0
- LHM/models/__pycache__/__init__.cpython-310.pyc +0 -0
- LHM/models/__pycache__/arcface_utils.cpython-310.pyc +0 -0
- LHM/models/__pycache__/embedder.cpython-310.pyc +0 -0
- LHM/models/__pycache__/modeling_human_lrm.cpython-310.pyc +0 -0
- LHM/models/__pycache__/transformer.cpython-310.pyc +0 -0
- LHM/models/__pycache__/transformer_dit.cpython-310.pyc +0 -0
- LHM/models/__pycache__/utils.cpython-310.pyc +0 -0
- LHM/models/arcface_utils.py +360 -0
- LHM/models/block.py +124 -0
- LHM/models/discriminator.py +120 -0
- LHM/models/embedder.py +37 -0
- LHM/models/encoders/__init__.py +15 -0
- LHM/models/encoders/__pycache__/__init__.cpython-310.pyc +0 -0
- LHM/models/encoders/__pycache__/dinov2_fusion_wrapper.cpython-310.pyc +0 -0
- LHM/models/encoders/__pycache__/sapiens_warpper.cpython-310.pyc +0 -0
- LHM/models/encoders/dino_wrapper.py +68 -0
- LHM/models/encoders/dinov2/__init__.py +15 -0
- LHM/models/encoders/dinov2/__pycache__/__init__.cpython-310.pyc +0 -0
- LHM/models/encoders/dinov2/hub/__init__.py +4 -0
- LHM/models/encoders/dinov2/hub/__pycache__/__init__.cpython-310.pyc +0 -0
- LHM/models/encoders/dinov2/hub/__pycache__/backbones.cpython-310.pyc +0 -0
- LHM/models/encoders/dinov2/hub/__pycache__/utils.cpython-310.pyc +0 -0
- LHM/models/encoders/dinov2/hub/backbones.py +166 -0
- LHM/models/encoders/dinov2/hub/classifiers.py +268 -0
- LHM/models/encoders/dinov2/hub/depth/__init__.py +7 -0
- LHM/models/encoders/dinov2/hub/depth/decode_heads.py +747 -0
- LHM/models/encoders/dinov2/hub/depth/encoder_decoder.py +351 -0
- LHM/models/encoders/dinov2/hub/depth/ops.py +28 -0
- LHM/models/encoders/dinov2/hub/depthers.py +246 -0
    	
.gitattributes
CHANGED

@@ -33,3 +33,27 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/characters_images/000001.jpg filter=lfs diff=lfs merge=lfs -text
+assets/videos/scene_000000/bkgd_video.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/videos/scene_000000/smplx_video.mp4 filter=lfs diff=lfs merge=lfs -text
+diffsynth/models/__pycache__/sd3_text_encoder.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
+diffsynth/models/__pycache__/sd3_text_encoder.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+diffsynth/models/__pycache__/sd_unet.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+diffsynth/models/__pycache__/sdxl_unet.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
+diffsynth/models/__pycache__/sdxl_unet.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+diffsynth/models/__pycache__/sdxl_unet.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
+diffsynth/models/__pycache__/svd_unet.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
+diffsynth/models/__pycache__/svd_unet.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+diffsynth/models/__pycache__/svd_unet.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
+diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt filter=lfs diff=lfs merge=lfs -text
+engine/pose_estimation/third-party/ViTPose/figures/Throughput.png filter=lfs diff=lfs merge=lfs -text
+engine/pose_estimation/third-party/ViTPose/mmpose/.mim/demo/resources/demo.mp4 filter=lfs diff=lfs merge=lfs -text
+engine/pose_estimation/third-party/ViTPose/mmpose/.mim/demo/resources/demo_coco.gif filter=lfs diff=lfs merge=lfs -text
+pretrained_models/dense_sample_points/1_40000.ply filter=lfs diff=lfs merge=lfs -text
+pretrained_models/dense_sample_points/1_60000.ply filter=lfs diff=lfs merge=lfs -text
+pretrained_models/dense_sample_points/1_80000.ply filter=lfs diff=lfs merge=lfs -text
+pretrained_models/gagatracker/vgghead/vgg_heads_l.trcd filter=lfs diff=lfs merge=lfs -text
+pretrained_models/huggingface/models--3DAIGC--LHM-1B-HF/blobs/59dc25167d1d72d57fb068445b96e2343ab550b649e9999765200502d03171b9 filter=lfs diff=lfs merge=lfs -text
+pretrained_models/human_model_files/smplx/smplx_uv/smplx_uv.png filter=lfs diff=lfs merge=lfs -text
+pretrained_models/sapiens/pretrained/checkpoints/sapiens_1b/sapiens_1b_epoch_173_torchscript.pt2 filter=lfs diff=lfs merge=lfs -text
    	
LHM/__init__.py
ADDED

@@ -0,0 +1,15 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Empty
    	
LHM/__pycache__/__init__.cpython-310.pyc
ADDED

Binary file (176 Bytes).

LHM/__pycache__/launch.cpython-310.pyc
ADDED

Binary file (723 Bytes).
    	
LHM/datasets/__init__.py
ADDED

@@ -0,0 +1,16 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .mixer import MixerDataset
    	
LHM/datasets/__pycache__/__init__.cpython-310.pyc
ADDED

Binary file (227 Bytes).

LHM/datasets/__pycache__/cam_utils.cpython-310.pyc
ADDED

Binary file (5.46 kB).

LHM/datasets/__pycache__/mixer.cpython-310.pyc
ADDED

Binary file (2.06 kB).
    	
LHM/datasets/base.py
ADDED

@@ -0,0 +1,70 @@
+# -*- coding: utf-8 -*-
+# @Organization  : Alibaba XR-Lab
+# @Author        : Peihao Li & Lingteng Qiu & Xiaodong Gu & Qi Zuo
+# @Email         : [email protected]
+# @Time          : 2025-03-10 18:47:56
+# @Function      : dataset base
+
+import json
+import pdb
+import traceback
+from abc import ABC, abstractmethod
+
+import numpy as np
+import torch
+from megfile import smart_exists, smart_open, smart_path_join
+from PIL import Image
+
+
+class BaseDataset(torch.utils.data.Dataset, ABC):
+    def __init__(self, root_dirs: str, meta_path: str):
+        super().__init__()
+        self.root_dirs = root_dirs
+        self.uids = self._load_uids(meta_path)
+
+    def __len__(self):
+        return len(self.uids)
+
+    @abstractmethod
+    def inner_get_item(self, idx):
+        pass
+
+    def __getitem__(self, idx):
+        try:
+            return self.inner_get_item(idx)
+        except Exception as e:
+            traceback.print_exc()
+            print(f"[DEBUG-DATASET] Error when loading {self.uids[idx]}")
+            # raise e
+            return self.__getitem__((idx + 1) % self.__len__())
+
+    @staticmethod
+    def _load_uids(meta_path: str):
+        # meta_path is a json file
+        if meta_path == None:
+            uids = []
+        else:
+            with open(meta_path, "r") as f:
+                uids = json.load(f)
+
+        return uids
+
+    @staticmethod
+    def _load_rgba_image(file_path, bg_color: float = 1.0):
+        """Load and blend RGBA image to RGB with certain background, 0-1 scaled"""
+        rgba = np.array(Image.open(smart_open(file_path, "rb")))
+        rgba = torch.from_numpy(rgba).float() / 255.0
+        rgba = rgba.permute(2, 0, 1).unsqueeze(0)
+        rgb = rgba[:, :3, :, :] * rgba[:, 3:4, :, :] + bg_color * (
+            1 - rgba[:, 3:, :, :]
+        )
+        # rgba[:, :3, ...] * rgba[:, 3:, ...] + (1 - rgba[:, 3:, ...])
+        return rgb
+
+    @staticmethod
+    def _locate_datadir(root_dirs, uid, locator: str):
+        for root_dir in root_dirs:
+            datadir = smart_path_join(root_dir, uid, locator)
+            if smart_exists(datadir):
+                return root_dir
+        raise FileNotFoundError(f"Cannot find valid data directory for uid {uid}")
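
For orientation, a minimal sketch of how the abstract BaseDataset above is meant to be extended. The subclass name, the rgba.png layout, and the returned dict are hypothetical examples, not part of this upload; only inner_get_item has to be implemented, and any exception it raises is caught by __getitem__, which falls back to the next index.

# Hypothetical subclass for illustration; not part of the uploaded repository.
from LHM.datasets.base import BaseDataset


class ToyImageDataset(BaseDataset):
    """Assumes each <root>/<uid>/ directory contains an 'rgba.png' (made-up layout)."""

    def inner_get_item(self, idx):
        uid = self.uids[idx]
        # _locate_datadir returns the first root_dir where <root>/<uid>/rgba.png exists.
        root = self._locate_datadir(self.root_dirs, uid, locator="rgba.png")
        # _load_rgba_image blends RGBA onto a white background, returning (1, 3, H, W) in [0, 1].
        rgb = self._load_rgba_image(f"{root}/{uid}/rgba.png", bg_color=1.0)
        return {"uid": uid, "rgb": rgb.squeeze(0)}


# dataset = ToyImageDataset(root_dirs=["/data/humans"], meta_path="train_uids.json")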
    	
LHM/datasets/cam_utils.py
ADDED

@@ -0,0 +1,205 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import math
+import torch
+
+"""
+R: (N, 3, 3)
+T: (N, 3)
+E: (N, 4, 4)
+vector: (N, 3)
+"""
+
+
+def compose_extrinsic_R_T(R: torch.Tensor, T: torch.Tensor):
+    """
+    Compose the standard form extrinsic matrix from R and T.
+    Batched I/O.
+    """
+    RT = torch.cat((R, T.unsqueeze(-1)), dim=-1)
+    return compose_extrinsic_RT(RT)
+
+
+def compose_extrinsic_RT(RT: torch.Tensor):
+    """
+    Compose the standard form extrinsic matrix from RT.
+    Batched I/O.
+    """
+    return torch.cat([
+        RT,
+        torch.tensor([[[0, 0, 0, 1]]], dtype=RT.dtype, device=RT.device).repeat(RT.shape[0], 1, 1)
+        ], dim=1)
+
+
+def decompose_extrinsic_R_T(E: torch.Tensor):
+    """
+    Decompose the standard extrinsic matrix into R and T.
+    Batched I/O.
+    """
+    RT = decompose_extrinsic_RT(E)
+    return RT[:, :, :3], RT[:, :, 3]
+
+
+def decompose_extrinsic_RT(E: torch.Tensor):
+    """
+    Decompose the standard extrinsic matrix into RT.
+    Batched I/O.
+    """
+    return E[:, :3, :]
+
+
+def camera_normalization_objaverse(normed_dist_to_center, poses: torch.Tensor, ret_transform: bool = False):
+    assert normed_dist_to_center is not None
+    pivotal_pose = compose_extrinsic_RT(poses[:1])
+    dist_to_center = pivotal_pose[:, :3, 3].norm(dim=-1, keepdim=True).item() \
+        if normed_dist_to_center == 'auto' else normed_dist_to_center
+
+    # compute camera norm (new version)
+    canonical_camera_extrinsics = torch.tensor([[
+        [1, 0, 0, 0],
+        [0, 0, -1, -dist_to_center],
+        [0, 1, 0, 0],
+        [0, 0, 0, 1],
+    ]], dtype=torch.float32)
+    pivotal_pose_inv = torch.inverse(pivotal_pose)
+    camera_norm_matrix = torch.bmm(canonical_camera_extrinsics, pivotal_pose_inv)
+
+    # normalize all views
+    poses = compose_extrinsic_RT(poses)
+    poses = torch.bmm(camera_norm_matrix.repeat(poses.shape[0], 1, 1), poses)
+    poses = decompose_extrinsic_RT(poses)
+
+    if ret_transform:
+        return poses, camera_norm_matrix.squeeze(dim=0)
+    return poses
+
+
+def get_normalized_camera_intrinsics(intrinsics: torch.Tensor):
+    """
+    intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
+    Return batched fx, fy, cx, cy
+    """
+    fx, fy = intrinsics[:, 0, 0], intrinsics[:, 0, 1]
+    cx, cy = intrinsics[:, 1, 0], intrinsics[:, 1, 1]
+    width, height = intrinsics[:, 2, 0], intrinsics[:, 2, 1]
+    fx, fy = fx / width, fy / height
+    cx, cy = cx / width, cy / height
+    return fx, fy, cx, cy
+
+
+def build_camera_principle(RT: torch.Tensor, intrinsics: torch.Tensor):
+    """
+    RT: (N, 3, 4)
+    intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
+    """
+    fx, fy, cx, cy = get_normalized_camera_intrinsics(intrinsics)
+    return torch.cat([
+        RT.reshape(-1, 12),
+        fx.unsqueeze(-1), fy.unsqueeze(-1), cx.unsqueeze(-1), cy.unsqueeze(-1),
+    ], dim=-1)
+
+
+def build_camera_standard(RT: torch.Tensor, intrinsics: torch.Tensor):
+    """
+    RT: (N, 3, 4)
+    intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
+    """
+    E = compose_extrinsic_RT(RT)
+    fx, fy, cx, cy = get_normalized_camera_intrinsics(intrinsics)
+    I = torch.stack([
+        torch.stack([fx, torch.zeros_like(fx), cx], dim=-1),
+        torch.stack([torch.zeros_like(fy), fy, cy], dim=-1),
+        torch.tensor([[0, 0, 1]], dtype=torch.float32, device=RT.device).repeat(RT.shape[0], 1),
+    ], dim=1)
+    return torch.cat([
+        E.reshape(-1, 16),
+        I.reshape(-1, 9),
+    ], dim=-1)
+
+
+def center_looking_at_camera_pose(
+    camera_position: torch.Tensor, look_at: torch.Tensor = None, up_world: torch.Tensor = None,
+    device: torch.device = torch.device('cpu'),
+    ):
+    """
+    camera_position: (M, 3)
+    look_at: (3)
+    up_world: (3)
+    return: (M, 3, 4)
+    """
+    # by default, looking at the origin and world up is pos-z
+    if look_at is None:
+        look_at = torch.tensor([0, 0, 0], dtype=torch.float32, device=device)
+    if up_world is None:
+        up_world = torch.tensor([0, 0, 1], dtype=torch.float32, device=device)
+    look_at = look_at.unsqueeze(0).repeat(camera_position.shape[0], 1)
+    up_world = up_world.unsqueeze(0).repeat(camera_position.shape[0], 1)
+
+    z_axis = camera_position - look_at
+    z_axis = z_axis / z_axis.norm(dim=-1, keepdim=True)
+    x_axis = torch.cross(up_world, z_axis)
+    x_axis = x_axis / x_axis.norm(dim=-1, keepdim=True)
+    y_axis = torch.cross(z_axis, x_axis)
+    y_axis = y_axis / y_axis.norm(dim=-1, keepdim=True)
+    extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1)
+    return extrinsics
+
+
+def surrounding_views_linspace(n_views: int, radius: float = 2.0, height: float = 0.8, device: torch.device = torch.device('cpu')):
+    """
+    n_views: number of surrounding views
+    radius: camera dist to center
+    height: height of the camera
+    return: (M, 3, 4)
+    """
+    assert n_views > 0
+    assert radius > 0
+
+    theta = torch.linspace(-torch.pi / 2, 3 * torch.pi / 2, n_views, device=device)
+    projected_radius = math.sqrt(radius ** 2 - height ** 2)
+    x = torch.cos(theta) * projected_radius
+    y = torch.sin(theta) * projected_radius
+    z = torch.full((n_views,), height, device=device)
+
+    camera_positions = torch.stack([x, y, z], dim=1)
+    extrinsics = center_looking_at_camera_pose(camera_positions, device=device)
+
+    return extrinsics
+
+
+def create_intrinsics(
+    f: float,
+    c: float = None, cx: float = None, cy: float = None,
+    w: float = 1., h: float = 1.,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device = torch.device('cpu'),
+    ):
+    """
+    return: (3, 2)
+    """
+    fx = fy = f
+    if c is not None:
+        assert cx is None and cy is None, "c and cx/cy cannot be used together"
+        cx = cy = c
+    else:
+        assert cx is not None and cy is not None, "cx/cy must be provided when c is not provided"
+    fx, fy, cx, cy, w, h = fx/w, fy/h, cx/w, cy/h, 1., 1.
+    intrinsics = torch.tensor([
+        [fx, fy],
+        [cx, cy],
+        [w, h],
+    ], dtype=dtype, device=device)
+    return intrinsics
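
A short usage sketch for the helpers above (the numbers are arbitrary): surrounding_views_linspace places cameras on a ring looking at the origin, create_intrinsics builds the (3, 2) intrinsics layout the other functions expect, and build_camera_standard flattens both into a 25-dim conditioning vector per view (16 extrinsic + 9 intrinsic values).

# Illustrative values only.
from LHM.datasets.cam_utils import (
    build_camera_standard,
    create_intrinsics,
    surrounding_views_linspace,
)

poses = surrounding_views_linspace(n_views=8, radius=2.0, height=0.8)  # (8, 3, 4) extrinsics
intrinsics = create_intrinsics(f=0.75, c=0.5)                          # (3, 2), already normalized
cameras = build_camera_standard(poses, intrinsics.unsqueeze(0).repeat(8, 1, 1))  # (8, 25)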
    	
LHM/datasets/mixer.py
ADDED

@@ -0,0 +1,75 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import math
+import pdb
+from functools import partial
+
+import torch
+
+__all__ = ["MixerDataset"]
+
+
+class MixerDataset(torch.utils.data.Dataset):
+    """Reference"""
+
+    def __init__(
+        self,
+        split: str,
+        subsets: dict,
+        **dataset_kwargs,
+    ):
+
+        self.subsets = [
+            self._dataset_fn(subset, split)(
+                use_flame=subset["use_flame"],
+                src_head_size=subset.get("src_head_size", 448),
+                **dataset_kwargs,
+            )
+            for subset in subsets
+        ]
+        self.virtual_lens = [
+            math.ceil(subset_config["sample_rate"] * len(subset_obj))
+            for subset_config, subset_obj in zip(subsets, self.subsets)
+        ]
+
+    @staticmethod
+    def _dataset_fn(subset_config: dict, split: str):
+        name = subset_config["name"]
+
+        dataset_cls = None
+        if name == "video_human":
+            from .video_human import VideoHumanDataset
+
+        else:
+            raise NotImplementedError(f"Dataset {name} not implemented")
+
+        return partial(
+            dataset_cls,
+            root_dirs=subset_config["root_dirs"],
+            meta_path=subset_config["meta_path"][split],
+        )
+
+    def __len__(self):
+        return sum(self.virtual_lens)
+
+    def __getitem__(self, idx):
+        subset_idx = 0
+        virtual_idx = idx
+        while virtual_idx >= self.virtual_lens[subset_idx]:
+            virtual_idx -= self.virtual_lens[subset_idx]
+            subset_idx += 1
+        real_idx = virtual_idx % len(self.subsets[subset_idx])
+        return self.subsets[subset_idx][real_idx]
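
MixerDataset.__getitem__ above first attributes a global index to a subset by walking the virtual_lens table (each subset's real length scaled by its sample_rate), then wraps the remainder modulo the subset's real length. A standalone sketch of that arithmetic with toy numbers, independent of any concrete subset class:

import math

real_lens = [100, 40]        # len() of each underlying subset (toy numbers)
sample_rates = [1.0, 2.5]    # "sample_rate" from each subset config
virtual_lens = [math.ceil(r * n) for r, n in zip(sample_rates, real_lens)]  # [100, 100]

def map_index(idx):
    subset_idx, virtual_idx = 0, idx
    while virtual_idx >= virtual_lens[subset_idx]:
        virtual_idx -= virtual_lens[subset_idx]
        subset_idx += 1
    return subset_idx, virtual_idx % real_lens[subset_idx]

assert map_index(150) == (1, 10)  # lands in the second subset, wrapped past its 40 real samples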
    	
LHM/launch.py
ADDED

@@ -0,0 +1,37 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import pdb
+
+from LHM.runners import REGISTRY_RUNNERS
+
+
+def main():
+
+    parser = argparse.ArgumentParser(description="OpenLRM launcher")
+    parser.add_argument("runner", type=str, help="Runner to launch")
+    args, unknown = parser.parse_known_args()
+
+    if args.runner not in REGISTRY_RUNNERS:
+        raise ValueError("Runner {} not found".format(args.runner))
+
+    RunnerClass = REGISTRY_RUNNERS[args.runner]
+    with RunnerClass() as runner:
+        runner.run()
+
+
+if __name__ == "__main__":
+    main()
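
launch.py resolves the positional runner argument against REGISTRY_RUNNERS and drives the selected runner as a context manager, so a registered runner needs __enter__/__exit__ and a run() method. A hypothetical shape of such a runner (the name DemoRunner and the registration line are illustrative; the real entries are defined in LHM/runners):

# Hypothetical sketch; actual runners live in LHM/runners and are registered there.
class DemoRunner:
    def __enter__(self):
        return self              # launch.py uses "with RunnerClass() as runner:"

    def __exit__(self, exc_type, exc_value, traceback):
        return False             # do not swallow exceptions raised by run()

    def run(self):
        print("running")

# REGISTRY_RUNNERS["demo"] = DemoRunner  # would make e.g. "python -m LHM.launch demo" resolve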
    	
LHM/losses/__init__.py
ADDED

@@ -0,0 +1,20 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .ball_loss import *
+from .offset_loss import *
+from .perceptual import *
+from .pixelwise import *
+from .tvloss import *
    	
LHM/losses/ball_loss.py
ADDED

@@ -0,0 +1,54 @@
+# -*- coding: utf-8 -*-
+# @Organization  : Alibaba XR-Lab
+# @Author        : Lingteng Qiu
+# @Email         : [email protected]
+# @Time          : 2025-03-10 19:08:35
+# @Function      : ASAP loss
+import pdb
+
+import torch
+import torch.nn as nn
+
+__all__ = ["ASAP_Loss", "Heuristic_ASAP_Loss"]
+
+
+class ASAP_Loss(nn.Module):
+
+    def forward(self, scaling, r=1, **params):
+        """where r is the radius of the ball between max-axis and min-axis."""
+        raise NotImplementedError(
+            "ASAP_Loss is not implemented yet in Inference version"
+        )
+
+
+class Heuristic_ASAP_Loss(nn.Module):
+    def __init__(self, group_dict, group_body_mapping):
+        super(Heuristic_ASAP_Loss, self).__init__()
+
+        self.group_dict = group_dict  # register weights fro different body parts
+        self.group_body_mapping = group_body_mapping  # mapping of body parts to group
+
+    def _heurisitic_loss(self, _ball_loss):
+
+        _loss = 0.0
+        for key in self.group_dict.keys():
+            key_weights = self.group_dict[key]
+            group_mapping_idx = self.group_body_mapping[key]
+            _loss += key_weights * _ball_loss[:, group_mapping_idx].mean()
+
+        return _loss
+
+    def forward(self, scaling, r=5, **params):
+        """where r is the radius of the ball between max-axis and min-axis."""
+        "human motion or rotation is very different in each body parts, for example, the head is more stable than the leg and hand, so we use heuristic_ball_loss"
+
+        _scale = scaling
+
+        _scale_min = torch.min(_scale, dim=-1)[0]
+        _scale_max = torch.max(_scale, dim=-1)[0]
+
+        scale_ratio = _scale_max / (_scale_min + 1e-6)
+
+        _ball_loss = torch.clamp(scale_ratio, min=r) - r
+
+        return self._heurisitic_loss(_ball_loss)
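
Heuristic_ASAP_Loss above penalizes only the anisotropy in excess of the radius r: per point it takes the max/min ratio of the scale values along the last axis, clamps at r, and averages each body-part group with its own weight. A toy invocation (the group names, weights, and index lists are made up):

import torch
from LHM.losses.ball_loss import Heuristic_ASAP_Loss

# Two hypothetical groups over 6 points, with per-group weights.
loss_fn = Heuristic_ASAP_Loss(
    group_dict={"head": 2.0, "body": 1.0},
    group_body_mapping={"head": [0, 1], "body": [2, 3, 4, 5]},
)
scaling = torch.rand(4, 6, 3) * 0.1  # (batch, num_points, 3 axis scales)
loss = loss_fn(scaling, r=5)         # zero unless a max/min scale ratio exceeds r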
    	
LHM/losses/offset_loss.py
ADDED

@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+# @Organization  : Alibaba XR-Lab
+# @Author        : Lingteng Qiu
+# @Email         : [email protected]
+# @Time          : 2025-03-10 19:08:56
+# @Function      : ACAP Loss
+import pdb
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+__all__ = ["ACAP_Loss", "Heuristic_ACAP_Loss"]
+
+
+class ACAP_Loss(nn.Module):
+    """As close as possibel loss"""
+
+    def forward(self, offset, d=0.05625, **params):
+        """Empirically, where d is the thresold of distance points leave from 1.8/32 = 0.0562."""
+
+        offset_loss = torch.clamp(offset.norm(p=2, dim=-1), min=d) - d
+
+        return offset_loss.mean()
+
+
+class Heuristic_ACAP_Loss(nn.Module):
+    """As close as possibel loss"""
+
+    def __init__(self, group_dict, group_body_mapping):
+        super(Heuristic_ACAP_Loss, self).__init__()
+
+        self.group_dict = group_dict  # register weights fro different body parts
+        self.group_body_mapping = group_body_mapping  # mapping of body parts to group
+
+    def _heurisitic_loss(self, _offset_loss):
+
+        _loss = 0.0
+        for key in self.group_dict.keys():
+            key_weights = self.group_dict[key]
+            group_mapping_idx = self.group_body_mapping[key]
+            _loss += key_weights * _offset_loss[:, group_mapping_idx].mean()
+
+        return _loss
+
+    def forward(self, offset, d=0.05625, **params):
+        """Empirically, where d is the thresold of distance points leave from human prior model, 1.8/32 = 0.0562."""
+        "human motion or rotation is very different in each body parts, for example, the head is more stable than the leg and hand, so we use heuristic_ball_loss"
+
+        _offset_loss = torch.clamp(offset.norm(p=2, dim=-1), min=d) - d
+
+        return self._heurisitic_loss(_offset_loss)
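
ACAP_Loss above is the unweighted variant: offsets whose L2 norm stays below d (1.8 / 32 ≈ 0.056 of the canonical human scale) cost nothing, and only the excess beyond d is averaged. A quick check with toy offsets:

import torch
from LHM.losses.offset_loss import ACAP_Loss

offset = torch.zeros(2, 10, 3)                 # (batch, num_points, 3)
offset[0, 0] = torch.tensor([0.2, 0.0, 0.0])   # one point drifts 0.2 from the prior surface
loss = ACAP_Loss()(offset)                     # equals (0.2 - 0.05625) / 20 over the 20 points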
    	
        LHM/losses/perceptual.py
    ADDED
    
@@ -0,0 +1,70 @@
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import torch.nn as nn

__all__ = ['LPIPSLoss']


class LPIPSLoss(nn.Module):
    """
    Compute LPIPS loss between two images.
    """

    def __init__(self, device, prefech: bool = False):
        super().__init__()
        self.device = device
        self.cached_models = {}
        if prefech:
            self.prefetch_models()

    def _get_model(self, model_name: str):
        if model_name not in self.cached_models:
            import warnings
            with warnings.catch_warnings():
                warnings.filterwarnings('ignore', category=UserWarning)
                import lpips
                _model = lpips.LPIPS(net=model_name, eval_mode=True, verbose=False).to(self.device)
            _model = torch.compile(_model)
            self.cached_models[model_name] = _model
        return self.cached_models[model_name]

    def prefetch_models(self):
        _model_names = ['alex', 'vgg']
        for model_name in _model_names:
            self._get_model(model_name)

    def forward(self, x, y, is_training: bool = True):
        """
        Assume images are 0-1 scaled and channel first.

        Args:
            x: [N, M, C, H, W]
            y: [N, M, C, H, W]
            is_training: whether to use VGG or AlexNet.

        Returns:
            Mean-reduced LPIPS loss across batch.
        """
        model_name = 'vgg' if is_training else 'alex'
        loss_fn = self._get_model(model_name)
        N, M, C, H, W = x.shape
        x = x.reshape(N*M, C, H, W)
        y = y.reshape(N*M, C, H, W)
        image_loss = loss_fn(x, y, normalize=True).mean(dim=[1, 2, 3])
        batch_loss = image_loss.reshape(N, M).mean(dim=1)
        all_loss = batch_loss.mean()
        return all_loss
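A hedged usage sketch for the module above, assuming the lpips package and PyTorch 2.x (for torch.compile) are available and that inputs follow the documented [N, M, C, H, W] layout in [0, 1], where M is the number of views per sample; the shapes and device choice here are illustrative:

import torch
from LHM.losses.perceptual import LPIPSLoss  # path added in this commit

device = "cuda" if torch.cuda.is_available() else "cpu"
criterion = LPIPSLoss(device=device)

pred = torch.rand(2, 4, 3, 256, 256, device=device)    # N=2 samples, M=4 views
target = torch.rand(2, 4, 3, 256, 256, device=device)

loss = criterion(pred, target, is_training=True)  # VGG backbone; 'alex' when False
print(float(loss))

The first call is slow because the backbone is fetched, moved to the device, and compiled; later calls reuse the cached model.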
    	
        LHM/losses/pixelwise.py
    ADDED
    
@@ -0,0 +1,58 @@
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import torch.nn as nn

__all__ = ['PixelLoss']


class PixelLoss(nn.Module):
    """
    Pixel-wise loss between two images.
    """

    def __init__(self, option: str = 'mse'):
        super().__init__()
        self.loss_fn = self._build_from_option(option)

    @staticmethod
    def _build_from_option(option: str, reduction: str = 'none'):
        if option == 'mse':
            return nn.MSELoss(reduction=reduction)
        elif option == 'l1':
            return nn.L1Loss(reduction=reduction)
        else:
            raise NotImplementedError(f'Unknown pixel loss option: {option}')

    @torch.compile
    def forward(self, x, y):
        """
        Assume images are channel first.

        Args:
            x: [N, M, C, H, W]
            y: [N, M, C, H, W]

        Returns:
            Mean-reduced pixel loss across batch.
        """
        N, M, C, H, W = x.shape
        x = x.reshape(N*M, C, H, W)
        y = y.reshape(N*M, C, H, W)
        image_loss = self.loss_fn(x, y).mean(dim=[1, 2, 3])
        batch_loss = image_loss.reshape(N, M).mean(dim=1)
        all_loss = batch_loss.mean()
        return all_loss
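A short usage sketch, with the same caveats as above (PyTorch 2.x for the @torch.compile decorator; shapes are illustrative):

import torch
from LHM.losses.pixelwise import PixelLoss  # path added in this commit

criterion = PixelLoss(option='l1')  # or the default 'mse'

pred = torch.rand(2, 4, 3, 128, 128, requires_grad=True)
target = torch.rand(2, 4, 3, 128, 128)

loss = criterion(pred, target)  # per-image mean, then per-sample mean, then batch mean
loss.backward()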
    	
        LHM/losses/tvloss.py
    ADDED
    
@@ -0,0 +1,55 @@
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import torch.nn as nn

__all__ = ['TVLoss']


class TVLoss(nn.Module):
    """
    Total variance loss.
    """

    def __init__(self):
        super().__init__()

    def numel_excluding_first_dim(self, x):
        return x.numel() // x.shape[0]

    @torch.compile
    def forward(self, x):
        """
        Assume batched and channel first with inner sizes.

        Args:
            x: [N, M, C, H, W]

        Returns:
            Mean-reduced TV loss with element-level scaling.
        """
        N, M, C, H, W = x.shape
        x = x.reshape(N*M, C, H, W)
        diff_i = x[..., 1:, :] - x[..., :-1, :]
        diff_j = x[..., :, 1:] - x[..., :, :-1]
        div_i = self.numel_excluding_first_dim(diff_i)
        div_j = self.numel_excluding_first_dim(diff_j)
        tv_i = diff_i.pow(2).sum(dim=[1,2,3]) / div_i
        tv_j = diff_j.pow(2).sum(dim=[1,2,3]) / div_j
        tv = tv_i + tv_j
        batch_tv = tv.reshape(N, M).mean(dim=1)
        all_tv = batch_tv.mean()
        return all_tv
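The loss sums squared differences between vertically and horizontally adjacent pixels and divides each term by its element count, so values stay comparable across resolutions. A quick sketch (shapes illustrative):

import torch
from LHM.losses.tvloss import TVLoss  # path added in this commit

tv = TVLoss()

smooth = torch.ones(1, 2, 3, 64, 64)   # constant image -> TV is 0
noisy = torch.rand(1, 2, 3, 64, 64)    # random image -> much larger TV
print(float(tv(smooth)), float(tv(noisy)))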
    	
        LHM/models/ESRGANer_utils.py
    ADDED
    
    | @@ -0,0 +1,482 @@ | |
| 1 | 
            +
            # -*- coding: utf-8 -*-
         | 
| 2 | 
            +
            # @Organization  : Alibaba XR-Lab
         | 
| 3 | 
            +
            # @Author        : Lingteng Qiu
         | 
| 4 | 
            +
            # @Email         : [email protected]
         | 
| 5 | 
            +
            # @Time          : 2025-03-1 17:39:52
         | 
| 6 | 
            +
            # @Function      : Function to improve face quality when training.
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            import math
         | 
| 9 | 
            +
            import os
         | 
| 10 | 
            +
            import queue
         | 
| 11 | 
            +
            import sys
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            sys.path.append("./")
         | 
| 14 | 
            +
            import threading
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            import cv2
         | 
| 17 | 
            +
            import numpy as np
         | 
| 18 | 
            +
            import torch
         | 
| 19 | 
            +
            from basicsr.utils.download_util import load_file_from_url
         | 
| 20 | 
            +
            from torch.nn import functional as F
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
         | 
| 23 | 
            +
            import pdb
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            import torch
         | 
| 26 | 
            +
            from basicsr.archs.rrdbnet_arch import RRDBNet
         | 
| 27 | 
            +
             | 
| 28 | 
            +
             | 
| 29 | 
            +
            def avaliable_device():
         | 
| 30 | 
            +
                if torch.cuda.is_available():
         | 
| 31 | 
            +
                    current_device_id = torch.cuda.current_device()
         | 
| 32 | 
            +
                    device = f"cuda:{current_device_id}"
         | 
| 33 | 
            +
                else:
         | 
| 34 | 
            +
                    device = "cpu"
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                return device
         | 
| 37 | 
            +
             | 
| 38 | 
            +
             | 
| 39 | 
            +
            class RealESRGANer:
         | 
| 40 | 
            +
                """A helper class for upsampling images with RealESRGAN.
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                Args:
         | 
| 43 | 
            +
                    scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
         | 
| 44 | 
            +
                    model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
         | 
| 45 | 
            +
                    model (nn.Module): The defined network. Default: None.
         | 
| 46 | 
            +
        tile (int): Since overly large images can run out of GPU memory, this tile option will first crop
         | 
| 47 | 
            +
                        input images into tiles, and then process each of them. Finally, they will be merged into one image.
         | 
| 48 | 
            +
        0 means no tiling. Default: 0.
         | 
| 49 | 
            +
                    tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
         | 
| 50 | 
            +
                    pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
         | 
| 51 | 
            +
        half (bool): Whether to use half precision during inference. Default: False.
         | 
| 52 | 
            +
                """
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                def __init__(
         | 
| 55 | 
            +
                    self,
         | 
| 56 | 
            +
                    scale,
         | 
| 57 | 
            +
                    model_path,
         | 
| 58 | 
            +
                    dni_weight=None,
         | 
| 59 | 
            +
                    model=None,
         | 
| 60 | 
            +
                    tile=0,
         | 
| 61 | 
            +
                    tile_pad=10,
         | 
| 62 | 
            +
                    pre_pad=10,
         | 
| 63 | 
            +
                    half=False,
         | 
| 64 | 
            +
                    device=None,
         | 
| 65 | 
            +
                    gpu_id=None,
         | 
| 66 | 
            +
                ):
         | 
| 67 | 
            +
                    self.scale = scale
         | 
| 68 | 
            +
                    self.tile_size = tile
         | 
| 69 | 
            +
                    self.tile_pad = tile_pad
         | 
| 70 | 
            +
                    self.pre_pad = pre_pad
         | 
| 71 | 
            +
                    self.mod_scale = None
         | 
| 72 | 
            +
                    self.half = half
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                    # initialize model
         | 
| 75 | 
            +
                    if gpu_id:
         | 
| 76 | 
            +
                        self.device = (
         | 
| 77 | 
            +
                            torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
         | 
| 78 | 
            +
                            if device is None
         | 
| 79 | 
            +
                            else device
         | 
| 80 | 
            +
                        )
         | 
| 81 | 
            +
                    else:
         | 
| 82 | 
            +
                        self.device = (
         | 
| 83 | 
            +
                            torch.device("cuda" if torch.cuda.is_available() else "cpu")
         | 
| 84 | 
            +
                            if device is None
         | 
| 85 | 
            +
                            else device
         | 
| 86 | 
            +
                        )
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                    if isinstance(model_path, list):
         | 
| 89 | 
            +
                        # dni
         | 
| 90 | 
            +
                        assert len(model_path) == len(
         | 
| 91 | 
            +
                            dni_weight
         | 
| 92 | 
            +
                        ), "model_path and dni_weight should have the save length."
         | 
| 93 | 
            +
                        loadnet = self.dni(model_path[0], model_path[1], dni_weight)
         | 
| 94 | 
            +
                    else:
         | 
| 95 | 
            +
                        # if the model_path starts with https, it will first download models to the folder: weights
         | 
| 96 | 
            +
                        if model_path.startswith("https://"):
         | 
| 97 | 
            +
                            model_path = load_file_from_url(
         | 
| 98 | 
            +
                                url=model_path,
         | 
| 99 | 
            +
                                model_dir=os.path.join(ROOT_DIR, "weights"),
         | 
| 100 | 
            +
                                progress=True,
         | 
| 101 | 
            +
                                file_name=None,
         | 
| 102 | 
            +
                            )
         | 
| 103 | 
            +
                        loadnet = torch.load(model_path, map_location=torch.device("cpu"))
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                    # prefer to use params_ema
         | 
| 106 | 
            +
                    if "params_ema" in loadnet:
         | 
| 107 | 
            +
                        keyname = "params_ema"
         | 
| 108 | 
            +
                    else:
         | 
| 109 | 
            +
                        keyname = "params"
         | 
| 110 | 
            +
                    model.load_state_dict(loadnet[keyname], strict=True)
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                    model.eval()
         | 
| 113 | 
            +
                    self.model = model.to(self.device)
         | 
| 114 | 
            +
                    if self.half:
         | 
| 115 | 
            +
                        self.model = self.model.half()
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                def dni(self, net_a, net_b, dni_weight, key="params", loc="cpu"):
         | 
| 118 | 
            +
                    """Deep network interpolation.
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                    ``Paper: Deep Network Interpolation for Continuous Imagery Effect Transition``
         | 
| 121 | 
            +
                    """
         | 
| 122 | 
            +
                    net_a = torch.load(net_a, map_location=torch.device(loc))
         | 
| 123 | 
            +
                    net_b = torch.load(net_b, map_location=torch.device(loc))
         | 
| 124 | 
            +
                    for k, v_a in net_a[key].items():
         | 
| 125 | 
            +
                        net_a[key][k] = dni_weight[0] * v_a + dni_weight[1] * net_b[key][k]
         | 
| 126 | 
            +
                    return net_a
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                def pre_process(self, img):
         | 
| 129 | 
            +
                    """Pre-process, such as pre-pad and mod pad, so that the images can be divisible"""
         | 
| 130 | 
            +
                    img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
         | 
| 131 | 
            +
                    self.img = img.unsqueeze(0).to(self.device)
         | 
| 132 | 
            +
                    if self.half:
         | 
| 133 | 
            +
                        self.img = self.img.half()
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                    # pre_pad
         | 
| 136 | 
            +
                    if self.pre_pad != 0:
         | 
| 137 | 
            +
                        self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), "reflect")
         | 
| 138 | 
            +
                    # mod pad for divisible borders
         | 
| 139 | 
            +
                    if self.scale == 2:
         | 
| 140 | 
            +
                        self.mod_scale = 2
         | 
| 141 | 
            +
                    elif self.scale == 1:
         | 
| 142 | 
            +
                        self.mod_scale = 4
         | 
| 143 | 
            +
                    if self.mod_scale is not None:
         | 
| 144 | 
            +
                        self.mod_pad_h, self.mod_pad_w = 0, 0
         | 
| 145 | 
            +
                        _, _, h, w = self.img.size()
         | 
| 146 | 
            +
                        if h % self.mod_scale != 0:
         | 
| 147 | 
            +
                            self.mod_pad_h = self.mod_scale - h % self.mod_scale
         | 
| 148 | 
            +
                        if w % self.mod_scale != 0:
         | 
| 149 | 
            +
                            self.mod_pad_w = self.mod_scale - w % self.mod_scale
         | 
| 150 | 
            +
                        self.img = F.pad(
         | 
| 151 | 
            +
                            self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), "reflect"
         | 
| 152 | 
            +
                        )
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                def process(self):
         | 
| 155 | 
            +
                    # model inference
         | 
| 156 | 
            +
                    self.output = self.model(self.img)
         | 
| 157 | 
            +
             | 
| 158 | 
            +
                def tile_process(self):
         | 
| 159 | 
            +
                    """It will first crop input images to tiles, and then process each tile.
         | 
| 160 | 
            +
        Finally, all the processed tiles are merged into one image.
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                    Modified from: https://github.com/ata4/esrgan-launcher
         | 
| 163 | 
            +
                    """
         | 
| 164 | 
            +
                    batch, channel, height, width = self.img.shape
         | 
| 165 | 
            +
                    output_height = height * self.scale
         | 
| 166 | 
            +
                    output_width = width * self.scale
         | 
| 167 | 
            +
                    output_shape = (batch, channel, output_height, output_width)
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                    # start with black image
         | 
| 170 | 
            +
                    self.output = self.img.new_zeros(output_shape)
         | 
| 171 | 
            +
                    tiles_x = math.ceil(width / self.tile_size)
         | 
| 172 | 
            +
                    tiles_y = math.ceil(height / self.tile_size)
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                    # loop over all tiles
         | 
| 175 | 
            +
                    for y in range(tiles_y):
         | 
| 176 | 
            +
                        for x in range(tiles_x):
         | 
| 177 | 
            +
                            # extract tile from input image
         | 
| 178 | 
            +
                            ofs_x = x * self.tile_size
         | 
| 179 | 
            +
                            ofs_y = y * self.tile_size
         | 
| 180 | 
            +
                            # input tile area on total image
         | 
| 181 | 
            +
                            input_start_x = ofs_x
         | 
| 182 | 
            +
                            input_end_x = min(ofs_x + self.tile_size, width)
         | 
| 183 | 
            +
                            input_start_y = ofs_y
         | 
| 184 | 
            +
                            input_end_y = min(ofs_y + self.tile_size, height)
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                            # input tile area on total image with padding
         | 
| 187 | 
            +
                            input_start_x_pad = max(input_start_x - self.tile_pad, 0)
         | 
| 188 | 
            +
                            input_end_x_pad = min(input_end_x + self.tile_pad, width)
         | 
| 189 | 
            +
                            input_start_y_pad = max(input_start_y - self.tile_pad, 0)
         | 
| 190 | 
            +
                            input_end_y_pad = min(input_end_y + self.tile_pad, height)
         | 
| 191 | 
            +
             | 
| 192 | 
            +
                            # input tile dimensions
         | 
| 193 | 
            +
                            input_tile_width = input_end_x - input_start_x
         | 
| 194 | 
            +
                            input_tile_height = input_end_y - input_start_y
         | 
| 195 | 
            +
                            tile_idx = y * tiles_x + x + 1
         | 
| 196 | 
            +
                            input_tile = self.img[
         | 
| 197 | 
            +
                                :,
         | 
| 198 | 
            +
                                :,
         | 
| 199 | 
            +
                                input_start_y_pad:input_end_y_pad,
         | 
| 200 | 
            +
                                input_start_x_pad:input_end_x_pad,
         | 
| 201 | 
            +
                            ]
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                            # upscale tile
         | 
| 204 | 
            +
                            try:
         | 
| 205 | 
            +
                                with torch.no_grad():
         | 
| 206 | 
            +
                                    output_tile = self.model(input_tile)
         | 
| 207 | 
            +
                            except RuntimeError as error:
         | 
| 208 | 
            +
                                print("Error", error)
         | 
| 209 | 
            +
                            print(f"\tTile {tile_idx}/{tiles_x * tiles_y}")
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                            # output tile area on total image
         | 
| 212 | 
            +
                            output_start_x = input_start_x * self.scale
         | 
| 213 | 
            +
                            output_end_x = input_end_x * self.scale
         | 
| 214 | 
            +
                            output_start_y = input_start_y * self.scale
         | 
| 215 | 
            +
                            output_end_y = input_end_y * self.scale
         | 
| 216 | 
            +
             | 
| 217 | 
            +
                            # output tile area without padding
         | 
| 218 | 
            +
                            output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
         | 
| 219 | 
            +
                            output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
         | 
| 220 | 
            +
                            output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
         | 
| 221 | 
            +
                            output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                            # put tile into output image
         | 
| 224 | 
            +
                            self.output[
         | 
| 225 | 
            +
                                :, :, output_start_y:output_end_y, output_start_x:output_end_x
         | 
| 226 | 
            +
                            ] = output_tile[
         | 
| 227 | 
            +
                                :,
         | 
| 228 | 
            +
                                :,
         | 
| 229 | 
            +
                                output_start_y_tile:output_end_y_tile,
         | 
| 230 | 
            +
                                output_start_x_tile:output_end_x_tile,
         | 
| 231 | 
            +
                            ]
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                def post_process(self):
         | 
| 234 | 
            +
                    # remove extra pad
         | 
| 235 | 
            +
                    if self.mod_scale is not None:
         | 
| 236 | 
            +
                        _, _, h, w = self.output.size()
         | 
| 237 | 
            +
                        self.output = self.output[
         | 
| 238 | 
            +
                            :,
         | 
| 239 | 
            +
                            :,
         | 
| 240 | 
            +
                            0 : h - self.mod_pad_h * self.scale,
         | 
| 241 | 
            +
                            0 : w - self.mod_pad_w * self.scale,
         | 
| 242 | 
            +
                        ]
         | 
| 243 | 
            +
                    # remove prepad
         | 
| 244 | 
            +
                    if self.pre_pad != 0:
         | 
| 245 | 
            +
                        _, _, h, w = self.output.size()
         | 
| 246 | 
            +
                        self.output = self.output[
         | 
| 247 | 
            +
                            :,
         | 
| 248 | 
            +
                            :,
         | 
| 249 | 
            +
                            0 : h - self.pre_pad * self.scale,
         | 
| 250 | 
            +
                            0 : w - self.pre_pad * self.scale,
         | 
| 251 | 
            +
                        ]
         | 
| 252 | 
            +
                    return self.output
         | 
| 253 | 
            +
             | 
| 254 | 
            +
                @torch.no_grad()
         | 
| 255 | 
            +
                def enhance(self, img, outscale=None, alpha_upsampler="realesrgan"):
         | 
| 256 | 
            +
                    h_input, w_input = img.shape[0:2]
         | 
| 257 | 
            +
                    # img: numpy
         | 
| 258 | 
            +
                    img = img.astype(np.float32)
         | 
| 259 | 
            +
                    if np.max(img) > 256:  # 16-bit image
         | 
| 260 | 
            +
                        max_range = 65535
         | 
| 261 | 
            +
                        print("\tInput is a 16-bit image")
         | 
| 262 | 
            +
                    else:
         | 
| 263 | 
            +
                        max_range = 255
         | 
| 264 | 
            +
                    img = img / max_range
         | 
| 265 | 
            +
                    if len(img.shape) == 2:  # gray image
         | 
| 266 | 
            +
                        img_mode = "L"
         | 
| 267 | 
            +
                        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
         | 
| 268 | 
            +
                    elif img.shape[2] == 4:  # RGBA image with alpha channel
         | 
| 269 | 
            +
                        img_mode = "RGBA"
         | 
| 270 | 
            +
                        alpha = img[:, :, 3]
         | 
| 271 | 
            +
                        img = img[:, :, 0:3]
         | 
| 272 | 
            +
                        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
         | 
| 273 | 
            +
                        if alpha_upsampler == "realesrgan":
         | 
| 274 | 
            +
                            alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
         | 
| 275 | 
            +
                    else:
         | 
| 276 | 
            +
                        img_mode = "RGB"
         | 
| 277 | 
            +
                        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
         | 
| 278 | 
            +
             | 
| 279 | 
            +
                    # ------------------- process image (without the alpha channel) ------------------- #
         | 
| 280 | 
            +
                    self.pre_process(img)
         | 
| 281 | 
            +
                    if self.tile_size > 0:
         | 
| 282 | 
            +
                        self.tile_process()
         | 
| 283 | 
            +
                    else:
         | 
| 284 | 
            +
                        self.process()
         | 
| 285 | 
            +
                    output_img = self.post_process()
         | 
| 286 | 
            +
                    output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
         | 
| 287 | 
            +
                    output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
         | 
| 288 | 
            +
                    if img_mode == "L":
         | 
| 289 | 
            +
                        output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
         | 
| 290 | 
            +
             | 
| 291 | 
            +
                    # ------------------- process the alpha channel if necessary ------------------- #
         | 
| 292 | 
            +
                    if img_mode == "RGBA":
         | 
| 293 | 
            +
                        if alpha_upsampler == "realesrgan":
         | 
| 294 | 
            +
                            self.pre_process(alpha)
         | 
| 295 | 
            +
                            if self.tile_size > 0:
         | 
| 296 | 
            +
                                self.tile_process()
         | 
| 297 | 
            +
                            else:
         | 
| 298 | 
            +
                                self.process()
         | 
| 299 | 
            +
                            output_alpha = self.post_process()
         | 
| 300 | 
            +
                            output_alpha = (
         | 
| 301 | 
            +
                                output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
         | 
| 302 | 
            +
                            )
         | 
| 303 | 
            +
                            output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
         | 
| 304 | 
            +
                            output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
         | 
| 305 | 
            +
                        else:  # use the cv2 resize for alpha channel
         | 
| 306 | 
            +
                            h, w = alpha.shape[0:2]
         | 
| 307 | 
            +
                            output_alpha = cv2.resize(
         | 
| 308 | 
            +
                                alpha,
         | 
| 309 | 
            +
                                (w * self.scale, h * self.scale),
         | 
| 310 | 
            +
                                interpolation=cv2.INTER_LINEAR,
         | 
| 311 | 
            +
                            )
         | 
| 312 | 
            +
             | 
| 313 | 
            +
                        # merge the alpha channel
         | 
| 314 | 
            +
                        output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
         | 
| 315 | 
            +
                        output_img[:, :, 3] = output_alpha
         | 
| 316 | 
            +
             | 
| 317 | 
            +
                    # ------------------------------ return ------------------------------ #
         | 
| 318 | 
            +
                    if max_range == 65535:  # 16-bit image
         | 
| 319 | 
            +
                        output = (output_img * 65535.0).round().astype(np.uint16)
         | 
| 320 | 
            +
                    else:
         | 
| 321 | 
            +
                        output = (output_img * 255.0).round().astype(np.uint8)
         | 
| 322 | 
            +
             | 
| 323 | 
            +
                    if outscale is not None and outscale != float(self.scale):
         | 
| 324 | 
            +
                        output = cv2.resize(
         | 
| 325 | 
            +
                            output,
         | 
| 326 | 
            +
                            (
         | 
| 327 | 
            +
                                int(w_input * outscale),
         | 
| 328 | 
            +
                                int(h_input * outscale),
         | 
| 329 | 
            +
                            ),
         | 
| 330 | 
            +
                            interpolation=cv2.INTER_LANCZOS4,
         | 
| 331 | 
            +
                        )
         | 
| 332 | 
            +
             | 
| 333 | 
            +
                    return output, img_mode
         | 
| 334 | 
            +
             | 
| 335 | 
            +
             | 
| 336 | 
            +
            class PrefetchReader(threading.Thread):
         | 
| 337 | 
            +
                """Prefetch images.
         | 
| 338 | 
            +
             | 
| 339 | 
            +
                Args:
         | 
| 340 | 
            +
        img_list (list[str]): A list of image paths to be read.
         | 
| 341 | 
            +
                    num_prefetch_queue (int): Number of prefetch queue.
         | 
| 342 | 
            +
                """
         | 
| 343 | 
            +
             | 
| 344 | 
            +
                def __init__(self, img_list, num_prefetch_queue):
         | 
| 345 | 
            +
                    super().__init__()
         | 
| 346 | 
            +
                    self.que = queue.Queue(num_prefetch_queue)
         | 
| 347 | 
            +
                    self.img_list = img_list
         | 
| 348 | 
            +
             | 
| 349 | 
            +
                def run(self):
         | 
| 350 | 
            +
                    for img_path in self.img_list:
         | 
| 351 | 
            +
                        img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
         | 
| 352 | 
            +
                        self.que.put(img)
         | 
| 353 | 
            +
             | 
| 354 | 
            +
                    self.que.put(None)
         | 
| 355 | 
            +
             | 
| 356 | 
            +
                def __next__(self):
         | 
| 357 | 
            +
                    next_item = self.que.get()
         | 
| 358 | 
            +
                    if next_item is None:
         | 
| 359 | 
            +
                        raise StopIteration
         | 
| 360 | 
            +
                    return next_item
         | 
| 361 | 
            +
             | 
| 362 | 
            +
                def __iter__(self):
         | 
| 363 | 
            +
                    return self
         | 
| 364 | 
            +
             | 
| 365 | 
            +
             | 
| 366 | 
            +
            class IOConsumer(threading.Thread):
         | 
| 367 | 
            +
             | 
| 368 | 
            +
                def __init__(self, opt, que, qid):
         | 
| 369 | 
            +
                    super().__init__()
         | 
| 370 | 
            +
                    self._queue = que
         | 
| 371 | 
            +
                    self.qid = qid
         | 
| 372 | 
            +
                    self.opt = opt
         | 
| 373 | 
            +
             | 
| 374 | 
            +
                def run(self):
         | 
| 375 | 
            +
                    while True:
         | 
| 376 | 
            +
                        msg = self._queue.get()
         | 
| 377 | 
            +
                        if isinstance(msg, str) and msg == "quit":
         | 
| 378 | 
            +
                            break
         | 
| 379 | 
            +
             | 
| 380 | 
            +
                        output = msg["output"]
         | 
| 381 | 
            +
                        save_path = msg["save_path"]
         | 
| 382 | 
            +
                        cv2.imwrite(save_path, output)
         | 
| 383 | 
            +
                    print(f"IO worker {self.qid} is done.")
         | 
| 384 | 
            +
             | 
| 385 | 
            +
             | 
| 386 | 
            +
            class ESRGANEasyModel:
         | 
| 387 | 
            +
                def __init__(
         | 
| 388 | 
            +
                    self, model_path="./pretrained_models/RealESRGAN_x4plus.pth", face_enhance=True
         | 
| 389 | 
            +
                ):
         | 
| 390 | 
            +
                    model = RRDBNet(
         | 
| 391 | 
            +
                        num_in_ch=3,
         | 
| 392 | 
            +
                        num_out_ch=3,
         | 
| 393 | 
            +
                        num_feat=64,
         | 
| 394 | 
            +
                        num_block=23,
         | 
| 395 | 
            +
                        num_grow_ch=32,
         | 
| 396 | 
            +
                        scale=4,
         | 
| 397 | 
            +
                    )
         | 
| 398 | 
            +
                    self.net_scale = 4
         | 
| 399 | 
            +
                    file_url = [
         | 
| 400 | 
            +
                        "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
         | 
| 401 | 
            +
                    ]
         | 
| 402 | 
            +
                    if model_path is None:
         | 
| 403 | 
            +
                        model_path = os.path.join("weights", args.model_name + ".pth")
         | 
| 404 | 
            +
                        if not os.path.isfile(model_path):
         | 
| 405 | 
            +
                            ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
         | 
| 406 | 
            +
                            for url in file_url:
         | 
| 407 | 
            +
                                # model_path will be updated
         | 
| 408 | 
            +
                                model_path = load_file_from_url(
         | 
| 409 | 
            +
                                    url=url,
         | 
| 410 | 
            +
                                    model_dir=os.path.join("./", "pretrained_models"),
         | 
| 411 | 
            +
                                    progress=True,
         | 
| 412 | 
            +
                                    file_name=None,
         | 
| 413 | 
            +
                                )
         | 
| 414 | 
            +
                    self.face_enhance = face_enhance
         | 
| 415 | 
            +
             | 
| 416 | 
            +
                    dni_weight = None
         | 
| 417 | 
            +
             | 
| 418 | 
            +
                    self.upsampler = RealESRGANer(
         | 
| 419 | 
            +
                        scale=self.net_scale,
         | 
| 420 | 
            +
                        model_path=model_path,
         | 
| 421 | 
            +
                        dni_weight=dni_weight,
         | 
| 422 | 
            +
                        model=model,
         | 
| 423 | 
            +
                        tile=0,
         | 
| 424 | 
            +
                        tile_pad=10,
         | 
| 425 | 
            +
                        pre_pad=0,
         | 
| 426 | 
            +
                        half=False,
         | 
| 427 | 
            +
                    )
         | 
| 428 | 
            +
             | 
| 429 | 
            +
                    self.upsampler.model.to(avaliable_device())
         | 
| 430 | 
            +
                    if face_enhance:  # Use GFPGAN for face enhancement
         | 
| 431 | 
            +
                        from gfpgan import GFPGANer
         | 
| 432 | 
            +
             | 
| 433 | 
            +
                        self.face_enhancer = GFPGANer(
         | 
| 434 | 
            +
                            model_path="https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
         | 
| 435 | 
            +
                            upscale=4,
         | 
| 436 | 
            +
                            arch="clean",
         | 
| 437 | 
            +
                            channel_multiplier=2,
         | 
| 438 | 
            +
                            bg_upsampler=self.upsampler,
         | 
| 439 | 
            +
                        )
         | 
| 440 | 
            +
                    else:
         | 
| 441 | 
            +
                        self.face_enhancer = None
         | 
| 442 | 
            +
             | 
| 443 | 
            +
                @torch.no_grad()
         | 
| 444 | 
            +
                def __call__(self, img):
         | 
| 445 | 
            +
                    if self.face_enhancer is not None:
         | 
| 446 | 
            +
                        _, _, output = self.face_enhancer.enhance(
         | 
| 447 | 
            +
                            img, has_aligned=False, only_center_face=False, paste_back=True
         | 
| 448 | 
            +
                        )
         | 
| 449 | 
            +
                    else:
         | 
| 450 | 
            +
                        output, _ = self.upsampler.enhance(img, outscale=4)
         | 
| 451 | 
            +
                    return output
         | 
| 452 | 
            +
             | 
| 453 | 
            +
                def __repr__(self):
         | 
| 454 | 
            +
                    return f"ESRGANEasyModel:\n {self.upsampler}"
         | 
| 455 | 
            +
             | 
| 456 | 
            +
             | 
| 457 | 
            +
            if __name__ == "__main__":
         | 
| 458 | 
            +
             | 
| 459 | 
            +
                import time
         | 
| 460 | 
            +
             | 
| 461 | 
            +
                model = ESRGANEasyModel(face_enhance=True)
         | 
| 462 | 
            +
                input_img = "./debug/face_debug/gt/head_gt_0.png"
         | 
| 463 | 
            +
             | 
| 464 | 
            +
                img_np = cv2.imread(input_img)
         | 
| 465 | 
            +
                set1 = [
         | 
| 466 | 
            +
                    "./debug/face_debug/gt/head_gt_0.png",
         | 
| 467 | 
            +
                    "./debug/face_debug/gt/head_gt_1.png",
         | 
| 468 | 
            +
                    "./debug/face_debug/gt/head_gt_2.png",
         | 
| 469 | 
            +
                    "./debug/face_debug/gt/head_gt_3.png",
         | 
| 470 | 
            +
                    "./debug/face_debug/gt/head_gt_4.png",
         | 
| 471 | 
            +
                    "./debug/face_debug/gt/head_gt_5.png",
         | 
| 472 | 
            +
                    "./debug/face_debug/gt/head_gt_6.png",
         | 
| 473 | 
            +
                    "./debug/face_debug/gt/head_gt_0.png",
         | 
| 474 | 
            +
                ]
         | 
| 475 | 
            +
                img_set1 = [cv2.imread(img_path) for img_path in set1]
         | 
| 476 | 
            +
             | 
| 477 | 
            +
                sr = model(img_set1[0])
         | 
| 478 | 
            +
             | 
| 479 | 
            +
                s0 = time.time()
         | 
| 480 | 
            +
                for img in img_set1:
         | 
| 481 | 
            +
             | 
| 482 | 
            +
                    sr = model(img)
         | 
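For context, a hedged sketch of how the wrapper above might be driven, assuming basicsr is installed and the Real-ESRGAN x4 weights already sit at the default ./pretrained_models/RealESRGAN_x4plus.pth path (passing model_path=None falls into a branch that references an undefined args variable in this version); the input file name is made up, and face_enhance=True additionally requires gfpgan:

import cv2
from LHM.models.ESRGANer_utils import ESRGANEasyModel  # path added in this commit

model = ESRGANEasyModel(face_enhance=False)  # plain 4x RealESRGAN, no GFPGAN

img = cv2.imread("face_crop.png")   # hypothetical BGR uint8 crop
sr = model(img)                     # 4x upscaled BGR uint8 array
cv2.imwrite("face_crop_x4.png", sr)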
    	
        LHM/models/__init__.py
    ADDED
    
@@ -0,0 +1,24 @@
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from .modeling_human_lrm import (
    ModelHumanLRM,
    ModelHumanLRMSapdinoBodyHeadSD3_5,
)

model_dict = {
    "human_lrm": ModelHumanLRM,
    "human_lrm_sapdino_bh_sd3_5": ModelHumanLRMSapdinoBodyHeadSD3_5,
}
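The registry above lets a config pick a model class by name. A minimal sketch of the lookup (the constructor arguments are assumptions; they are defined in modeling_human_lrm.py, which is outside this hunk):

from LHM.models import model_dict

model_name = "human_lrm_sapdino_bh_sd3_5"  # e.g. read from a training config
ModelClass = model_dict[model_name]
# model = ModelClass(**model_kwargs)       # kwargs come from modeling_human_lrm.py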
    	
        LHM/models/__pycache__/ESRGANer_utils.cpython-310.pyc
    ADDED
    
Binary file (11.9 kB)
    	
        LHM/models/__pycache__/__init__.cpython-310.pyc
    ADDED
    
Binary file (352 Bytes)
    	
        LHM/models/__pycache__/arcface_utils.cpython-310.pyc
    ADDED
    
Binary file (9.73 kB)
    	
        LHM/models/__pycache__/embedder.cpython-310.pyc
    ADDED
    
Binary file (1.03 kB).
    	
        LHM/models/__pycache__/modeling_human_lrm.cpython-310.pyc
    ADDED
    
Binary file (21.7 kB).
    	
        LHM/models/__pycache__/transformer.cpython-310.pyc
    ADDED
    
Binary file (6.89 kB).
    	
        LHM/models/__pycache__/transformer_dit.cpython-310.pyc
    ADDED
    
Binary file (15 kB).
    	
        LHM/models/__pycache__/utils.cpython-310.pyc
    ADDED
    
Binary file (1.44 kB).
    	
        LHM/models/arcface_utils.py
    ADDED
    
@@ -0,0 +1,360 @@
# -*- coding: utf-8 -*-
# @Organization  : Alibaba XR-Lab
# @Author        : Lingteng Qiu
# @Email         : [email protected]
# @Time          : 2025-03-10 17:38:29
# @Function      : Arc-Similarity Loss
import sys

sys.path.append(".")

import pdb
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DataParallel, DistributedDataParallel


def conv3x3(inplanes, outplanes, stride=1):
    """A simple wrapper for 3x3 convolution with padding.

    Args:
        inplanes (int): Channel number of inputs.
        outplanes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.
    """
    return nn.Conv2d(
        inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False
    )


class BasicBlock(nn.Module):
    """Basic residual block used in the ResNetArcFace architecture.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.
        downsample (nn.Module): The downsample module. Default: None.
    """

    expansion = 1  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class IRBlock(nn.Module):
    """Improved residual block (IR Block) used in the ResNetArcFace architecture.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.
        downsample (nn.Module): The downsample module. Default: None.
        use_se (bool): Whether to use the SEBlock (squeeze and excitation block). Default: True.
    """

    expansion = 1  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
        super(IRBlock, self).__init__()
        self.bn0 = nn.BatchNorm2d(inplanes)
        self.conv1 = conv3x3(inplanes, inplanes)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.prelu = nn.PReLU()
        self.conv2 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.use_se = use_se
        if self.use_se:
            self.se = SEBlock(planes)

    def forward(self, x):
        residual = x
        out = self.bn0(x)
        out = self.conv1(out)
        out = self.bn1(out)
        out = self.prelu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        if self.use_se:
            out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.prelu(out)

        return out


class Bottleneck(nn.Module):
    """Bottleneck block used in the ResNetArcFace architecture.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.
        downsample (nn.Module): The downsample module. Default: None.
    """

    expansion = 4  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False
        )
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class SEBlock(nn.Module):
    """The squeeze-and-excitation block (SEBlock) used in the IRBlock.

    Args:
        channel (int): Channel number of inputs.
        reduction (int): Channel reduction ratio. Default: 16.
    """

    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(
            1
        )  # pool to 1x1 without spatial information
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.PReLU(),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y


class ResNetArcFace(nn.Module):
    """ArcFace with ResNet architectures.

    Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.

    Args:
        block (str): Block used in the ArcFace architecture.
        layers (tuple(int)): Block numbers in each layer.
        use_se (bool): Whether to use the SEBlock (squeeze and excitation block). Default: False.
    """

    def __init__(
        self,
        block="IRBlock",
        layers=[2, 2, 2, 2],
        use_se=False,
        pretrain_model="./pretrained_models/arcface_resnet18.pth",
    ):
        if block == "IRBlock":
            block = IRBlock
        self.inplanes = 64
        self.use_se = use_se
        super(ResNetArcFace, self).__init__()

        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.prelu = nn.PReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.bn4 = nn.BatchNorm2d(512)
        self.dropout = nn.Dropout()
        self.fc5 = nn.Linear(512 * 8 * 8, 512)
        self.bn5 = nn.BatchNorm1d(512)

        # initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

        if pretrain_model is not None:
            self.load_network(self, pretrain_model, strict=True, param_key=None)
        else:
            raise ValueError("Please specify the pretrain model path.")

        self.freeze()

    @staticmethod
    def load_network(net, load_path, strict=True, param_key=None):

        def get_bare_model(net):
            if isinstance(net, (DataParallel, DistributedDataParallel)):
                net = net.module
            return net

        net = get_bare_model(net)
        load_net = torch.load(load_path, map_location=lambda storage, loc: storage)
        if param_key is not None:
            if param_key not in load_net and "params" in load_net:
                param_key = "params"
            load_net = load_net[param_key]
        # remove unnecessary 'module.'
        for k, v in deepcopy(load_net).items():
            if k.startswith("module."):
                load_net[k[7:]] = v
                load_net.pop(k)
        ret = net.load_state_dict(load_net, strict=strict)
        print(ret)

    def _make_layer(self, block, planes, num_blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, use_se=self.use_se)
        )
        self.inplanes = planes
        for _ in range(1, num_blocks):
            layers.append(block(self.inplanes, planes, use_se=self.use_se))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.bn4(x)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = self.fc5(x)
        x = self.bn5(x)

        return x

    def freeze(self):
        self.eval()
        for param in self.parameters():
            param.requires_grad = False


if __name__ == "__main__":
    model = ResNetArcFace()
    model.cuda()
    model.eval()
    # model.eval()

    set1 = [
        "./debug/face_debug/gt/head_gt_0.png",
        "./debug/face_debug/gt/head_gt_1.png",
        "./debug/face_debug/gt/head_gt_2.png",
        "./debug/face_debug/gt/head_gt_3.png",
        "./debug/face_debug/gt/head_gt_4.png",
        "./debug/face_debug/gt/head_gt_5.png",
        "./debug/face_debug/gt/head_gt_6.png",
    ]
    import cv2

    img_set1 = [cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) for img_path in set1]

    F1_list = []

    f1_scores = []
    for img in img_set1:
        img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0) / 255.0
        img = img.cuda()
        F1 = model(img)
        F1_list.append(F1)
    for i in range(len(F1_list)):
        for j in range(len(F1_list)):
            f1_scores.append(F.l1_loss(F1_list[i], F1_list[j]))

    print(len(f1_scores))

    f1_scores = torch.tensor(f1_scores)
    print(f1_scores)
    f1_scores = f1_scores.view(len(F1_list), len(F1_list))
    print(f1_scores)
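A minimal usage sketch for the frozen network as an identity-similarity term, assuming 128x128 single-channel head crops (the fc5 input size of 512*8*8 implies an 8x8 map after the four 2x reductions, i.e. 128x128 inputs). The self-test above compares features with an L1 distance; cosine distance, used here, is one common alternative, and the batch size and tensor names are illustrative:

import torch
import torch.nn.functional as F

# Requires ./pretrained_models/arcface_resnet18.pth to exist; the constructor loads and freezes it.
arcface = ResNetArcFace()
pred_gray = torch.rand(4, 1, 128, 128)   # rendered head crops, single channel in [0, 1]
gt_gray = torch.rand(4, 1, 128, 128)     # ground-truth head crops

feat_pred = arcface(pred_gray)           # [4, 512] identity embeddings
feat_gt = arcface(gt_gray)
id_loss = 1.0 - F.cosine_similarity(feat_pred, feat_gt, dim=-1).mean()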
    	
        LHM/models/block.py
    ADDED
    
@@ -0,0 +1,124 @@
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch.nn as nn

from .modulate import ModLN


class BasicBlock(nn.Module):
    """
    Transformer block that is in its simplest form.
    Designed for PF-LRM architecture.
    """
    # Block contains a self-attention layer and an MLP
    def __init__(self, inner_dim: int, num_heads: int, eps: float,
                 attn_drop: float = 0., attn_bias: bool = False,
                 mlp_ratio: float = 4., mlp_drop: float = 0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(inner_dim, eps=eps)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=inner_dim, num_heads=num_heads,
            dropout=attn_drop, bias=attn_bias, batch_first=True)
        self.norm2 = nn.LayerNorm(inner_dim, eps=eps)
        self.mlp = nn.Sequential(
            nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(mlp_drop),
            nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
            nn.Dropout(mlp_drop),
        )

    def forward(self, x):
        # x: [N, L, D]
        before_sa = self.norm1(x)
        x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
        x = x + self.mlp(self.norm2(x))
        return x


class ConditionBlock(nn.Module):
    """
    Transformer block that takes in a cross-attention condition.
    Designed for SparseLRM architecture.
    """
    # Block contains a cross-attention layer, a self-attention layer, and an MLP
    def __init__(self, inner_dim: int, cond_dim: int, num_heads: int, eps: float,
                 attn_drop: float = 0., attn_bias: bool = False,
                 mlp_ratio: float = 4., mlp_drop: float = 0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(inner_dim, eps=eps)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
            dropout=attn_drop, bias=attn_bias, batch_first=True)
        self.norm2 = nn.LayerNorm(inner_dim, eps=eps)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=inner_dim, num_heads=num_heads,
            dropout=attn_drop, bias=attn_bias, batch_first=True)
        self.norm3 = nn.LayerNorm(inner_dim, eps=eps)
        self.mlp = nn.Sequential(
            nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(mlp_drop),
            nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
            nn.Dropout(mlp_drop),
        )

    def forward(self, x, cond):
        # x: [N, L, D]
        # cond: [N, L_cond, D_cond]
        x = x + self.cross_attn(self.norm1(x), cond, cond, need_weights=False)[0]
        before_sa = self.norm2(x)
        x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
        x = x + self.mlp(self.norm3(x))
        return x


class ConditionModulationBlock(nn.Module):
    """
    Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks.
    Designed for raw LRM architecture.
    """
    # Block contains a cross-attention layer, a self-attention layer, and an MLP
    def __init__(self, inner_dim: int, cond_dim: int, mod_dim: int, num_heads: int, eps: float,
                 attn_drop: float = 0., attn_bias: bool = False,
                 mlp_ratio: float = 4., mlp_drop: float = 0.):
        super().__init__()
        self.norm1 = ModLN(inner_dim, mod_dim, eps)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
            dropout=attn_drop, bias=attn_bias, batch_first=True)
        self.norm2 = ModLN(inner_dim, mod_dim, eps)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=inner_dim, num_heads=num_heads,
            dropout=attn_drop, bias=attn_bias, batch_first=True)
        self.norm3 = ModLN(inner_dim, mod_dim, eps)
        self.mlp = nn.Sequential(
            nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(mlp_drop),
            nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
            nn.Dropout(mlp_drop),
        )

    def forward(self, x, cond, mod):
        # x: [N, L, D]
        # cond: [N, L_cond, D_cond]
        # mod: [N, D_mod]
        x = x + self.cross_attn(self.norm1(x, mod), cond, cond, need_weights=False)[0]
        before_sa = self.norm2(x, mod)
        x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
        x = x + self.mlp(self.norm3(x, mod))
        return x
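A minimal shape check for ConditionBlock with illustrative dimensions (ConditionModulationBlock additionally needs ModLN from the sibling modulate module, so it is left out of this sketch):

import torch

block = ConditionBlock(inner_dim=512, cond_dim=768, num_heads=8, eps=1e-6)
tokens = torch.randn(2, 1024, 512)   # [N, L, D] latent tokens
cond = torch.randn(2, 256, 768)      # [N, L_cond, D_cond] image-feature condition

out = block(tokens, cond)            # cross-attention, self-attention, MLP, each with a residual
assert out.shape == tokens.shape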
    	
        LHM/models/discriminator.py
    ADDED
    
@@ -0,0 +1,120 @@
"""
Ported from Paella
"""

import torch
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

import functools
# import torch.nn as nn
from taming.modules.util import ActNorm


# Discriminator model ported from Paella https://github.com/dome272/Paella/blob/main/src_distributed/vqgan.py
class Discriminator(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, in_channels=3, cond_channels=0, hidden_channels=512, depth=6):
        super().__init__()
        d = max(depth - 3, 3)
        layers = [
            nn.utils.spectral_norm(
                nn.Conv2d(in_channels, hidden_channels // (2**d), kernel_size=3, stride=2, padding=1)
            ),
            nn.LeakyReLU(0.2),
        ]
        for i in range(depth - 1):
            c_in = hidden_channels // (2 ** max((d - i), 0))
            c_out = hidden_channels // (2 ** max((d - 1 - i), 0))
            layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
            layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2))
        self.encoder = nn.Sequential(*layers)
        self.shuffle = nn.Conv2d(
            (hidden_channels + cond_channels) if cond_channels > 0 else hidden_channels, 1, kernel_size=1
        )
        # self.logits = nn.Sigmoid()

    def forward(self, x, cond=None):
        x = self.encoder(x)
        if cond is not None:
            cond = cond.view(
                cond.size(0),
                cond.size(1),
                1,
                1,
            ).expand(-1, -1, x.size(-2), x.size(-1))
            x = torch.cat([x, cond], dim=1)
        x = self.shuffle(x)
        # x = self.logits(x)
        return x


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator as in Pix2Pix
        --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """
    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        if not use_actnorm:
            # norm_layer = nn.BatchNorm2d
            norm_layer = nn.InstanceNorm2d
        else:
            norm_layer = ActNorm
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            # use_bias = norm_layer.func != nn.BatchNorm2d
            use_bias = norm_layer.func != nn.InstanceNorm2d
        else:
            # use_bias = norm_layer != nn.BatchNorm2d
            use_bias = norm_layer != nn.InstanceNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, False)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, False)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, False)
        ]

        sequence += [
            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.main = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.main(input)
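A quick sanity check of the PatchGAN head; the 256x256 input and batch size are illustrative, and the 30x30 logit map follows from the three stride-2 and two stride-1 4x4 convolutions built above:

import torch

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3)
disc.apply(weights_init)             # DCGAN-style init from the helper above

fake = torch.randn(2, 3, 256, 256)   # e.g. rendered images
logits = disc(fake)                  # per-patch real/fake logits
print(logits.shape)                  # torch.Size([2, 1, 30, 30])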
    	
        LHM/models/embedder.py
    ADDED
    
@@ -0,0 +1,37 @@
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import torch.nn as nn


class CameraEmbedder(nn.Module):
    """
    Embed camera features to a high-dimensional vector.

    Reference:
    DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L27
    """
    def __init__(self, raw_dim: int, embed_dim: int):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(raw_dim, embed_dim),
            nn.SiLU(),
            nn.Linear(embed_dim, embed_dim),
        )

    @torch.compile
    def forward(self, x):
        return self.mlp(x)
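A minimal sketch of the embedder on a flattened camera descriptor; the choice of raw_dim=16 (a flattened 3x4 extrinsic plus fx, fy, cx, cy) is an assumption for illustration, not a value read from the repository configs:

import torch

cam_embedder = CameraEmbedder(raw_dim=16, embed_dim=1024)
# Hypothetical camera descriptor: flattened 3x4 extrinsic (12 values) + fx, fy, cx, cy (4 values).
cams = torch.randn(2, 16)
cam_mod = cam_embedder(cams)   # [2, 1024] embedding, e.g. used as a modulation vector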
    	
        LHM/models/encoders/__init__.py
    ADDED
    
@@ -0,0 +1,15 @@
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Empty
    	
LHM/models/encoders/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (192 Bytes).

LHM/models/encoders/__pycache__/dinov2_fusion_wrapper.cpython-310.pyc
ADDED
Binary file (5 kB).

LHM/models/encoders/__pycache__/sapiens_warpper.cpython-310.pyc
ADDED
Binary file (9.24 kB).
LHM/models/encoders/dino_wrapper.py
ADDED
@@ -0,0 +1,68 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch.nn as nn
+from transformers import ViTImageProcessor, ViTModel
+from accelerate.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class DinoWrapper(nn.Module):
+    """
+    Dino v1 wrapper using huggingface transformer implementation.
+    """
+    def __init__(self, model_name: str, freeze: bool = True, encoder_feat_dim: int = 384):
+        super().__init__()
+        self.model, self.processor = self._build_dino(model_name)
+        if freeze:
+            self._freeze()
+
+    @torch.compile
+    def forward_model(self, inputs):
+        return self.model(**inputs, interpolate_pos_encoding=True)
+
+    def forward(self, image):
+        # image: [N, C, H, W], on cpu
+        # RGB image with [0,1] scale and properly sized
+        inputs = self.processor(images=image, return_tensors="pt", do_rescale=False, do_resize=False).to(self.model.device)
+        # This resampling of positional embedding uses bicubic interpolation
+        outputs = self.forward_model(inputs)
+        last_hidden_states = outputs.last_hidden_state
+        return last_hidden_states
+
+    def _freeze(self):
+        logger.warning(f"======== Freezing DinoWrapper ========")
+        self.model.eval()
+        for name, param in self.model.named_parameters():
+            param.requires_grad = False
+
+    @staticmethod
+    def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5):
+        import requests
+        try:
+            model = ViTModel.from_pretrained(model_name, add_pooling_layer=False)
+            processor = ViTImageProcessor.from_pretrained(model_name)
+            return model, processor
+        except requests.exceptions.ProxyError as err:
+            if proxy_error_retries > 0:
+                print(f"Huggingface ProxyError: Retrying ({proxy_error_retries}) in {proxy_error_cooldown} seconds...")
+                import time
+                time.sleep(proxy_error_cooldown)
+                return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown)
+            else:
+                raise err
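A minimal usage sketch (illustrative, not from the repository), assuming a DINO v1 checkpoint id such as "facebook/dino-vits16" and an RGB batch already scaled to [0, 1] and resized for the model:

    import torch
    from LHM.models.encoders.dino_wrapper import DinoWrapper

    encoder = DinoWrapper(model_name="facebook/dino-vits16", freeze=True)  # checkpoint id is an assumption
    images = torch.rand(2, 3, 224, 224)   # RGB in [0, 1], already sized as the processor expects
    tokens = encoder(images)              # last_hidden_state: [2, 1 + num_patches, hidden_dim]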
    	
LHM/models/encoders/dinov2/__init__.py
ADDED
@@ -0,0 +1,15 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Empty
LHM/models/encoders/dinov2/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (199 Bytes).

LHM/models/encoders/dinov2/hub/__init__.py
ADDED
@@ -0,0 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.

LHM/models/encoders/dinov2/hub/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (203 Bytes).

LHM/models/encoders/dinov2/hub/__pycache__/backbones.cpython-310.pyc
ADDED
Binary file (4.47 kB).

LHM/models/encoders/dinov2/hub/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (1.82 kB).
LHM/models/encoders/dinov2/hub/backbones.py
ADDED
@@ -0,0 +1,166 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+from typing import Union
+
+import torch
+
+from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
+
+
+class Weights(Enum):
+    LVD142M = "LVD142M"
+
+
+def _make_dinov2_model(
+    *,
+    arch_name: str = "vit_large",
+    img_size: int = 518,
+    patch_size: int = 14,
+    init_values: float = 1.0,
+    ffn_layer: str = "mlp",
+    block_chunks: int = 0,
+    num_register_tokens: int = 0,
+    interpolate_antialias: bool = False,
+    interpolate_offset: float = 0.1,
+    pretrained: bool = True,
+    weights: Union[Weights, str] = Weights.LVD142M,
+    **kwargs,
+):
+    from ..models import vision_transformer as vits
+
+    if isinstance(weights, str):
+        try:
+            weights = Weights[weights]
+        except KeyError:
+            raise AssertionError(f"Unsupported weights: {weights}")
+
+    model_base_name = _make_dinov2_model_name(arch_name, patch_size)
+    vit_kwargs = dict(
+        img_size=img_size,
+        patch_size=patch_size,
+        init_values=init_values,
+        ffn_layer=ffn_layer,
+        block_chunks=block_chunks,
+        num_register_tokens=num_register_tokens,
+        interpolate_antialias=interpolate_antialias,
+        interpolate_offset=interpolate_offset,
+    )
+    vit_kwargs.update(**kwargs)
+    model = vits.__dict__[arch_name](**vit_kwargs)
+
+    if pretrained:
+        model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
+        url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
+        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
+        # ********** Modified by Zexin He in 2023-2024 **********
+        state_dict = {k: v for k, v in state_dict.items() if 'mask_token' not in k}  # DDP concern
+        if vit_kwargs.get("modulation_dim") is not None:
+            state_dict = {
+                k.replace('norm1', 'norm1.norm').replace('norm2', 'norm2.norm'): v
+                for k, v in state_dict.items()
+            }
+            model.load_state_dict(state_dict, strict=False)
+        else:
+            model.load_state_dict(state_dict, strict=True)
+        # ********************************************************
+
+    return model
+
+
+def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+    """
+    DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
+    """
+    return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+    """
+    DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
+    """
+    return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+    """
+    DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
+    """
+    return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+    """
+    DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
+    """
+    return _make_dinov2_model(
+        arch_name="vit_giant2",
+        ffn_layer="swiglufused",
+        weights=weights,
+        pretrained=pretrained,
+        **kwargs,
+    )
+
+
+def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+    """
+    DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+    """
+    return _make_dinov2_model(
+        arch_name="vit_small",
+        pretrained=pretrained,
+        weights=weights,
+        num_register_tokens=4,
+        interpolate_antialias=True,
+        interpolate_offset=0.0,
+        **kwargs,
+    )
+
+
+def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+    """
+    DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+    """
+    return _make_dinov2_model(
+        arch_name="vit_base",
+        pretrained=pretrained,
+        weights=weights,
+        num_register_tokens=4,
+        interpolate_antialias=True,
+        interpolate_offset=0.0,
+        **kwargs,
+    )
+
+
+def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+    """
+    DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+    """
+    return _make_dinov2_model(
+        arch_name="vit_large",
+        pretrained=pretrained,
+        weights=weights,
+        num_register_tokens=4,
+        interpolate_antialias=True,
+        interpolate_offset=0.0,
+        **kwargs,
+    )
+
+
+def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+    """
+    DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+    """
+    return _make_dinov2_model(
+        arch_name="vit_giant2",
+        ffn_layer="swiglufused",
+        weights=weights,
+        pretrained=pretrained,
+        num_register_tokens=4,
+        interpolate_antialias=True,
+        interpolate_offset=0.0,
+        **kwargs,
+    )
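A minimal usage sketch (illustrative, not from the repository): these entrypoints mirror the torch.hub interface, so a backbone can be built directly; passing pretrained=False skips the LVD-142M weight download:

    from LHM.models.encoders.dinov2.hub.backbones import dinov2_vitl14

    backbone = dinov2_vitl14(pretrained=False)  # pretrained=True fetches the LVD-142M checkpoint from _DINOV2_BASE_URL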
    	
LHM/models/encoders/dinov2/hub/classifiers.py
ADDED
@@ -0,0 +1,268 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+from typing import Union
+
+import torch
+import torch.nn as nn
+
+from .backbones import _make_dinov2_model
+from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
+
+
+class Weights(Enum):
+    IMAGENET1K = "IMAGENET1K"
+
+
+def _make_dinov2_linear_classification_head(
+    *,
+    arch_name: str = "vit_large",
+    patch_size: int = 14,
+    embed_dim: int = 1024,
+    layers: int = 4,
+    pretrained: bool = True,
+    weights: Union[Weights, str] = Weights.IMAGENET1K,
+    num_register_tokens: int = 0,
+    **kwargs,
+):
+    if layers not in (1, 4):
+        raise AssertionError(f"Unsupported number of layers: {layers}")
+    if isinstance(weights, str):
+        try:
+            weights = Weights[weights]
+        except KeyError:
+            raise AssertionError(f"Unsupported weights: {weights}")
+
+    linear_head = nn.Linear((1 + layers) * embed_dim, 1_000)
+
+    if pretrained:
+        model_base_name = _make_dinov2_model_name(arch_name, patch_size)
+        model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
+        layers_str = str(layers) if layers == 4 else ""
+        url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_linear{layers_str}_head.pth"
+        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
+        linear_head.load_state_dict(state_dict, strict=True)
+
+    return linear_head
+
+
+class _LinearClassifierWrapper(nn.Module):
+    def __init__(self, *, backbone: nn.Module, linear_head: nn.Module, layers: int = 4):
+        super().__init__()
+        self.backbone = backbone
+        self.linear_head = linear_head
+        self.layers = layers
+
+    def forward(self, x):
+        if self.layers == 1:
+            x = self.backbone.forward_features(x)
+            cls_token = x["x_norm_clstoken"]
+            patch_tokens = x["x_norm_patchtokens"]
+            # fmt: off
+            linear_input = torch.cat([
+                cls_token,
+                patch_tokens.mean(dim=1),
+            ], dim=1)
+            # fmt: on
+        elif self.layers == 4:
+            x = self.backbone.get_intermediate_layers(x, n=4, return_class_token=True)
+            # fmt: off
+            linear_input = torch.cat([
+                x[0][1],
+                x[1][1],
+                x[2][1],
+                x[3][1],
+                x[3][0].mean(dim=1),
+            ], dim=1)
+            # fmt: on
+        else:
+            assert False, f"Unsupported number of layers: {self.layers}"
+        return self.linear_head(linear_input)
+
+
+def _make_dinov2_linear_classifier(
+    *,
+    arch_name: str = "vit_large",
+    layers: int = 4,
+    pretrained: bool = True,
+    weights: Union[Weights, str] = Weights.IMAGENET1K,
+    num_register_tokens: int = 0,
+    interpolate_antialias: bool = False,
+    interpolate_offset: float = 0.1,
+    **kwargs,
+):
+    backbone = _make_dinov2_model(
+        arch_name=arch_name,
+        pretrained=pretrained,
+        num_register_tokens=num_register_tokens,
+        interpolate_antialias=interpolate_antialias,
+        interpolate_offset=interpolate_offset,
+        **kwargs,
+    )
+
+    embed_dim = backbone.embed_dim
+    patch_size = backbone.patch_size
+    linear_head = _make_dinov2_linear_classification_head(
+        arch_name=arch_name,
+        patch_size=patch_size,
+        embed_dim=embed_dim,
+        layers=layers,
+        pretrained=pretrained,
+        weights=weights,
+        num_register_tokens=num_register_tokens,
+    )
+
+    return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head, layers=layers)
+
+
+def dinov2_vits14_lc(
+    *,
+    layers: int = 4,
+    pretrained: bool = True,
+    weights: Union[Weights, str] = Weights.IMAGENET1K,
+    **kwargs,
+):
+    """
+    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+    """
+    return _make_dinov2_linear_classifier(
+        arch_name="vit_small",
+        layers=layers,
+        pretrained=pretrained,
+        weights=weights,
+        **kwargs,
+    )
+
+
+def dinov2_vitb14_lc(
+    *,
+    layers: int = 4,
+    pretrained: bool = True,
+    weights: Union[Weights, str] = Weights.IMAGENET1K,
+    **kwargs,
+):
+    """
+    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+    """
+    return _make_dinov2_linear_classifier(
+        arch_name="vit_base",
+        layers=layers,
+        pretrained=pretrained,
+        weights=weights,
+        **kwargs,
+    )
+
+
+def dinov2_vitl14_lc(
+    *,
+    layers: int = 4,
+    pretrained: bool = True,
+    weights: Union[Weights, str] = Weights.IMAGENET1K,
+    **kwargs,
+):
+    """
+    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+    """
+    return _make_dinov2_linear_classifier(
+        arch_name="vit_large",
+        layers=layers,
+        pretrained=pretrained,
+        weights=weights,
+        **kwargs,
+    )
+
+
+def dinov2_vitg14_lc(
+    *,
+    layers: int = 4,
+    pretrained: bool = True,
+    weights: Union[Weights, str] = Weights.IMAGENET1K,
+    **kwargs,
+):
+    """
+    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+    """
+    return _make_dinov2_linear_classifier(
+        arch_name="vit_giant2",
+        layers=layers,
+        ffn_layer="swiglufused",
+        pretrained=pretrained,
+        weights=weights,
+        **kwargs,
+    )
+
+
+def dinov2_vits14_reg_lc(
+    *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
+):
+    """
+    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+    """
+    return _make_dinov2_linear_classifier(
+        arch_name="vit_small",
+        layers=layers,
+        pretrained=pretrained,
+        weights=weights,
+        num_register_tokens=4,
+        interpolate_antialias=True,
+        interpolate_offset=0.0,
+        **kwargs,
+    )
+
+
+def dinov2_vitb14_reg_lc(
+    *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
+):
+    """
+    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+    """
+    return _make_dinov2_linear_classifier(
+        arch_name="vit_base",
+        layers=layers,
+        pretrained=pretrained,
+        weights=weights,
+        num_register_tokens=4,
+        interpolate_antialias=True,
+        interpolate_offset=0.0,
+        **kwargs,
+    )
+
+
+def dinov2_vitl14_reg_lc(
+    *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
+):
+    """
+    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+    """
+    return _make_dinov2_linear_classifier(
+        arch_name="vit_large",
+        layers=layers,
+        pretrained=pretrained,
+        weights=weights,
+        num_register_tokens=4,
+        interpolate_antialias=True,
+        interpolate_offset=0.0,
+        **kwargs,
+    )
+
+
+def dinov2_vitg14_reg_lc(
+    *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
+):
+    """
+    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+    """
+    return _make_dinov2_linear_classifier(
+        arch_name="vit_giant2",
+        layers=layers,
+        ffn_layer="swiglufused",
+        pretrained=pretrained,
+        weights=weights,
+        num_register_tokens=4,
+        interpolate_antialias=True,
+        interpolate_offset=0.0,
+        **kwargs,
+    )
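A minimal usage sketch (illustrative, not from the repository) of the linear-classifier entrypoints, assuming input whose height and width are multiples of the 14-pixel patch size:

    import torch
    from LHM.models.encoders.dinov2.hub.classifiers import dinov2_vits14_lc

    classifier = dinov2_vits14_lc(layers=4, pretrained=False)  # pretrained=True downloads backbone and head weights
    images = torch.randn(1, 3, 224, 224)  # 224 = 16 * 14, divisible by patch_size
    logits = classifier(images)           # [1, 1000] ImageNet-1k logits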
    	
LHM/models/encoders/dinov2/hub/depth/__init__.py
ADDED
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .decode_heads import BNHead, DPTHead
+from .encoder_decoder import DepthEncoderDecoder
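For reference (illustrative, not from the repository), this package init re-exports the decode heads and the encoder-decoder, so downstream code can pull them in with a single import; their constructor arguments are defined in the depth modules added in this upload:

    from LHM.models.encoders.dinov2.hub.depth import BNHead, DPTHead, DepthEncoderDecoder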
    	
        LHM/models/encoders/dinov2/hub/depth/decode_heads.py
    ADDED
    
    | @@ -0,0 +1,747 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the Apache License, Version 2.0
         | 
| 4 | 
            +
            # found in the LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import copy
         | 
| 7 | 
            +
            from functools import partial
         | 
| 8 | 
            +
            import math
         | 
| 9 | 
            +
            import warnings
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import torch
         | 
| 12 | 
            +
            import torch.nn as nn
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from .ops import resize
         | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
            # XXX: (Untested) replacement for mmcv.imdenormalize()
         | 
| 18 | 
            +
            def _imdenormalize(img, mean, std, to_bgr=True):
         | 
| 19 | 
            +
                import numpy as np
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                mean = mean.reshape(1, -1).astype(np.float64)
         | 
| 22 | 
            +
                std = std.reshape(1, -1).astype(np.float64)
         | 
| 23 | 
            +
                img = (img * std) + mean
         | 
| 24 | 
            +
                if to_bgr:
         | 
| 25 | 
            +
                    img = img[::-1]
         | 
| 26 | 
            +
                return img
         | 

class DepthBaseDecodeHead(nn.Module):
    """Base class for BaseDecodeHead.

    Args:
        in_channels (List): Input channels.
        channels (int): Channels after modules, before conv_depth.
        conv_layer (nn.Module): Conv layers. Default: None.
        act_layer (nn.Module): Activation layers. Default: nn.ReLU.
        loss_decode (dict): Config of decode loss.
            Default: ().
        sampler (dict|None): The config of depth map sampler.
            Default: None.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        min_depth (float): Min depth in dataset setting.
            Default: 1e-3.
        max_depth (float): Max depth in dataset setting.
            Default: None.
        norm_layer (dict|None): Norm layers.
            Default: None.
        classify (bool): Whether to predict depth in a cls.-reg. manner.
            Default: False.
        n_bins (int): The number of bins used in cls. step.
            Default: 256.
        bins_strategy (str): The discrete strategy used in cls. step.
            Default: 'UD'.
        norm_strategy (str): The norm strategy on cls. probability
            distribution. Default: 'linear'.
        scale_up (bool): Whether to predict depth in a scale-up manner.
            Default: False.
    """

    def __init__(
        self,
        in_channels,
        conv_layer=None,
        act_layer=nn.ReLU,
        channels=96,
        loss_decode=(),
        sampler=None,
        align_corners=False,
        min_depth=1e-3,
        max_depth=None,
        norm_layer=None,
        classify=False,
        n_bins=256,
        bins_strategy="UD",
        norm_strategy="linear",
        scale_up=False,
    ):
        super(DepthBaseDecodeHead, self).__init__()

        self.in_channels = in_channels
        self.channels = channels
        self.conv_layer = conv_layer
        self.act_layer = act_layer
        self.loss_decode = loss_decode
        self.align_corners = align_corners
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.norm_layer = norm_layer
        self.classify = classify
        self.n_bins = n_bins
        self.scale_up = scale_up

        if self.classify:
            assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID"
            assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid"

            self.bins_strategy = bins_strategy
            self.norm_strategy = norm_strategy
            self.softmax = nn.Softmax(dim=1)
            self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1)
        else:
            self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs, img_metas):
        """Placeholder of forward function."""
        pass

    def forward_train(self, img, inputs, img_metas, depth_gt):
        """Forward function for training.
        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `depth/datasets/pipelines/formatting.py:Collect`.
            depth_gt (Tensor): GT depth

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        depth_pred = self.forward(inputs, img_metas)
        losses = self.losses(depth_pred, depth_gt)

        log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0])
        losses.update(**log_imgs)

        return losses

    def forward_test(self, inputs, img_metas):
        """Forward function for testing.
        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `depth/datasets/pipelines/formatting.py:Collect`.

        Returns:
            Tensor: Output depth map.
        """
        return self.forward(inputs, img_metas)

    def depth_pred(self, feat):
        """Predict depth for each pixel."""
        if self.classify:
            logit = self.conv_depth(feat)

            if self.bins_strategy == "UD":
                bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
            elif self.bins_strategy == "SID":
                bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)

            # following Adabins, default linear
            if self.norm_strategy == "linear":
                logit = torch.relu(logit)
                eps = 0.1
                logit = logit + eps
                logit = logit / logit.sum(dim=1, keepdim=True)
            elif self.norm_strategy == "softmax":
                logit = torch.softmax(logit, dim=1)
            elif self.norm_strategy == "sigmoid":
                logit = torch.sigmoid(logit)
                logit = logit / logit.sum(dim=1, keepdim=True)

            output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1)

        else:
            if self.scale_up:
                output = self.sigmoid(self.conv_depth(feat)) * self.max_depth
            else:
                output = self.relu(self.conv_depth(feat)) + self.min_depth
        return output

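The classification branch of depth_pred() above turns per-pixel bin probabilities into a single depth value by taking an expectation over the bin centers. A minimal standalone sketch of that step with small made-up tensors (the shapes and values are illustrative only):

import torch

n_bins, h, w = 4, 2, 2
bins = torch.linspace(0.001, 10.0, n_bins)          # "UD": uniformly spaced bin centers
logit = torch.rand(1, n_bins, h, w)                 # fake per-pixel bin scores

# "linear" normalization: relu + eps, then scale so each pixel's bins sum to 1
prob = torch.relu(logit) + 0.1
prob = prob / prob.sum(dim=1, keepdim=True)

# expected depth per pixel, exactly as in depth_pred()
depth = torch.einsum("ikmn,k->imn", [prob, bins]).unsqueeze(dim=1)
print(depth.shape)  # torch.Size([1, 1, 2, 2])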
| 181 | 
            +
                    """Compute depth loss."""
         | 
| 182 | 
            +
                    loss = dict()
         | 
| 183 | 
            +
                    depth_pred = resize(
         | 
| 184 | 
            +
                        input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False
         | 
| 185 | 
            +
                    )
         | 
| 186 | 
            +
                    if not isinstance(self.loss_decode, nn.ModuleList):
         | 
| 187 | 
            +
                        losses_decode = [self.loss_decode]
         | 
| 188 | 
            +
                    else:
         | 
| 189 | 
            +
                        losses_decode = self.loss_decode
         | 
| 190 | 
            +
                    for loss_decode in losses_decode:
         | 
| 191 | 
            +
                        if loss_decode.loss_name not in loss:
         | 
| 192 | 
            +
                            loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt)
         | 
| 193 | 
            +
                        else:
         | 
| 194 | 
            +
                            loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt)
         | 
| 195 | 
            +
                    return loss
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                def log_images(self, img_path, depth_pred, depth_gt, img_meta):
         | 
| 198 | 
            +
                    import numpy as np
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                    show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0))
         | 
| 201 | 
            +
                    show_img = show_img.numpy().astype(np.float32)
         | 
| 202 | 
            +
                    show_img = _imdenormalize(
         | 
| 203 | 
            +
                        show_img,
         | 
| 204 | 
            +
                        img_meta["img_norm_cfg"]["mean"],
         | 
| 205 | 
            +
                        img_meta["img_norm_cfg"]["std"],
         | 
| 206 | 
            +
                        img_meta["img_norm_cfg"]["to_rgb"],
         | 
| 207 | 
            +
                    )
         | 
| 208 | 
            +
                    show_img = np.clip(show_img, 0, 255)
         | 
| 209 | 
            +
                    show_img = show_img.astype(np.uint8)
         | 
| 210 | 
            +
                    show_img = show_img[:, :, ::-1]
         | 
| 211 | 
            +
                    show_img = show_img.transpose(0, 2, 1)
         | 
| 212 | 
            +
                    show_img = show_img.transpose(1, 0, 2)
         | 
| 213 | 
            +
             | 
| 214 | 
            +
                    depth_pred = depth_pred / torch.max(depth_pred)
         | 
| 215 | 
            +
                    depth_gt = depth_gt / torch.max(depth_gt)
         | 
| 216 | 
            +
             | 
| 217 | 
            +
                    depth_pred_color = copy.deepcopy(depth_pred.detach().cpu())
         | 
| 218 | 
            +
                    depth_gt_color = copy.deepcopy(depth_gt.detach().cpu())
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                    return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color}
         | 
| 221 | 
            +
             | 
| 222 | 
            +
             | 
class BNHead(DepthBaseDecodeHead):
    """Just a batchnorm."""

    def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs):
        super().__init__(**kwargs)
        self.input_transform = input_transform
        self.in_index = in_index
        self.upsample = upsample
        # self.bn = nn.SyncBatchNorm(self.in_channels)
        if self.classify:
            self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1)
        else:
            self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1)

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.
        Args:
            inputs (list[Tensor]): List of multi-level img features.
        Returns:
            Tensor: The transformed inputs
        """

        if "concat" in self.input_transform:
            inputs = [inputs[i] for i in self.in_index]
            if "resize" in self.input_transform:
                inputs = [
                    resize(
                        input=x,
                        size=[s * self.upsample for s in inputs[0].shape[2:]],
                        mode="bilinear",
                        align_corners=self.align_corners,
                    )
                    for x in inputs
                ]
            inputs = torch.cat(inputs, dim=1)
        elif self.input_transform == "multiple_select":
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    def _forward_feature(self, inputs, img_metas=None, **kwargs):
        """Forward function for feature maps before classifying each pixel with
        ``self.cls_seg`` fc.
        Args:
            inputs (list[Tensor]): List of multi-level img features.
        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W) which is feature map for last layer of decoder head.
        """
        # accept lists (for cls token)
        inputs = list(inputs)
        for i, x in enumerate(inputs):
            if len(x) == 2:
                x, cls_token = x[0], x[1]
                if len(x.shape) == 2:
                    x = x[:, :, None, None]
                cls_token = cls_token[:, :, None, None].expand_as(x)
                inputs[i] = torch.cat((x, cls_token), 1)
            else:
                x = x[0]
                if len(x.shape) == 2:
                    x = x[:, :, None, None]
                inputs[i] = x
        x = self._transform_inputs(inputs)
        # feats = self.bn(x)
        return x

    def forward(self, inputs, img_metas=None, **kwargs):
        """Forward function."""
        output = self._forward_feature(inputs, img_metas=img_metas, **kwargs)
        output = self.depth_pred(output)
        return output

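BNHead consumes a list of (feature map, cls token) pairs from a ViT backbone, concatenates the cls token onto each level, resizes and concatenates the levels, and regresses depth with a 1x1 conv. A rough usage sketch; the four 1024-channel levels and the depth range below are invented for illustration and are not this repository's configuration:

import torch

# Hypothetical: four ViT stages, each a (patch feature map, cls token) pair.
feats = [(torch.randn(1, 1024, 32, 32), torch.randn(1, 1024)) for _ in range(4)]

head = BNHead(
    in_channels=[1024] * 4,
    channels=8 * 1024,        # cls token doubles each level, then all 4 levels are concatenated
    input_transform="resize_concat",
    in_index=(0, 1, 2, 3),
    min_depth=1e-3,
    max_depth=10.0,
)
depth = head(feats)
print(depth.shape)            # torch.Size([1, 1, 32, 32])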
class ConvModule(nn.Module):
    """A conv block that bundles conv/norm/activation layers.

    This block simplifies the usage of convolution layers, which are commonly
    used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
    It is based upon three build methods: `build_conv_layer()`,
    `build_norm_layer()` and `build_activation_layer()`.

    Besides, we add some additional features in this module.
    1. Automatically set `bias` of the conv layer.
    2. Spectral norm is supported.
    3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only
    supports zero and circular padding, and we add "reflect" padding mode.

    Args:
        in_channels (int): Number of channels in the input feature map.
            Same as that in ``nn._ConvNd``.
        out_channels (int): Number of channels produced by the convolution.
            Same as that in ``nn._ConvNd``.
        kernel_size (int | tuple[int]): Size of the convolving kernel.
            Same as that in ``nn._ConvNd``.
        stride (int | tuple[int]): Stride of the convolution.
            Same as that in ``nn._ConvNd``.
        padding (int | tuple[int]): Zero-padding added to both sides of
            the input. Same as that in ``nn._ConvNd``.
        dilation (int | tuple[int]): Spacing between kernel elements.
            Same as that in ``nn._ConvNd``.
        groups (int): Number of blocked connections from input channels to
            output channels. Same as that in ``nn._ConvNd``.
        bias (bool | str): If specified as `auto`, it will be decided by the
            norm_layer. Bias will be set as True if `norm_layer` is None, otherwise
            False. Default: "auto".
        conv_layer (nn.Module): Convolution layer. Default: None,
            which means using conv2d.
        norm_layer (nn.Module): Normalization layer. Default: None.
        act_layer (nn.Module): Activation layer. Default: nn.ReLU.
        inplace (bool): Whether to use inplace mode for activation.
            Default: True.
        with_spectral_norm (bool): Whether to use spectral norm in the conv module.
            Default: False.
        padding_mode (str): If the `padding_mode` has not been supported by
            current `Conv2d` in PyTorch, we will use our own padding layer
            instead. Currently, we support ['zeros', 'circular'] with official
            implementation and ['reflect'] with our own implementation.
            Default: 'zeros'.
        order (tuple[str]): The order of conv/norm/activation layers. It is a
            sequence of "conv", "norm" and "act". Common examples are
            ("conv", "norm", "act") and ("act", "conv", "norm").
            Default: ('conv', 'norm', 'act').
    """

    _abbr_ = "conv_block"

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias="auto",
        conv_layer=nn.Conv2d,
        norm_layer=None,
        act_layer=nn.ReLU,
        inplace=True,
        with_spectral_norm=False,
        padding_mode="zeros",
        order=("conv", "norm", "act"),
    ):
        super(ConvModule, self).__init__()
        official_padding_mode = ["zeros", "circular"]
        self.conv_layer = conv_layer
        self.norm_layer = norm_layer
        self.act_layer = act_layer
        self.inplace = inplace
        self.with_spectral_norm = with_spectral_norm
        self.with_explicit_padding = padding_mode not in official_padding_mode
        self.order = order
        assert isinstance(self.order, tuple) and len(self.order) == 3
        assert set(order) == set(["conv", "norm", "act"])

        self.with_norm = norm_layer is not None
        self.with_activation = act_layer is not None
        # if the conv layer is before a norm layer, bias is unnecessary.
        if bias == "auto":
            bias = not self.with_norm
        self.with_bias = bias

        if self.with_explicit_padding:
            if padding_mode == "zeros":
                padding_layer = nn.ZeroPad2d
            else:
                raise AssertionError(f"Unsupported padding mode: {padding_mode}")
            self.pad = padding_layer(padding)

        # reset padding to 0 for conv module
        conv_padding = 0 if self.with_explicit_padding else padding
        # build convolution layer
        self.conv = self.conv_layer(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=conv_padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        # export the attributes of self.conv to a higher level for convenience
        self.in_channels = self.conv.in_channels
        self.out_channels = self.conv.out_channels
        self.kernel_size = self.conv.kernel_size
        self.stride = self.conv.stride
        self.padding = padding
        self.dilation = self.conv.dilation
        self.transposed = self.conv.transposed
        self.output_padding = self.conv.output_padding
        self.groups = self.conv.groups

        if self.with_spectral_norm:
            self.conv = nn.utils.spectral_norm(self.conv)

        # build normalization layers
        if self.with_norm:
            # norm layer is after conv layer
            if order.index("norm") > order.index("conv"):
                norm_channels = out_channels
            else:
                norm_channels = in_channels
            norm = partial(norm_layer, num_features=norm_channels)
            self.add_module("norm", norm)
            if self.with_bias:
                from torch.nn.modules.batchnorm import _BatchNorm
                from torch.nn.modules.instancenorm import _InstanceNorm

                if isinstance(norm, (_BatchNorm, _InstanceNorm)):
                    warnings.warn("Unnecessary conv bias before batch/instance norm")
        else:
            self.norm_name = None

        # build activation layer
        if self.with_activation:
            # nn.Tanh has no 'inplace' argument
            # (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.HSigmoid, nn.Swish, nn.GELU)
            if not isinstance(act_layer, (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.GELU)):
                act_layer = partial(act_layer, inplace=inplace)
            self.activate = act_layer()

        # Use msra init by default
        self.init_weights()

    @property
    def norm(self):
        if self.norm_name:
            return getattr(self, self.norm_name)
        else:
            return None

    def init_weights(self):
        # 1. It is mainly for customized conv layers with their own
        #    initialization manners by calling their own ``init_weights()``,
        #    and we do not want ConvModule to override the initialization.
        # 2. For customized conv layers without their own initialization
        #    manners (that is, they don't have their own ``init_weights()``)
        #    and PyTorch's conv layers, they will be initialized by
        #    this method with default ``kaiming_init``.
        # Note: For PyTorch's conv layers, they will be overwritten by our
        #    initialization implementation using default ``kaiming_init``.
        if not hasattr(self.conv, "init_weights"):
            if self.with_activation and isinstance(self.act_layer, nn.LeakyReLU):
                nonlinearity = "leaky_relu"
                a = 0.01  # XXX: default negative_slope
            else:
                nonlinearity = "relu"
                a = 0
            if hasattr(self.conv, "weight") and self.conv.weight is not None:
                nn.init.kaiming_normal_(self.conv.weight, a=a, mode="fan_out", nonlinearity=nonlinearity)
            if hasattr(self.conv, "bias") and self.conv.bias is not None:
                nn.init.constant_(self.conv.bias, 0)
        if self.with_norm:
            if hasattr(self.norm, "weight") and self.norm.weight is not None:
                nn.init.constant_(self.norm.weight, 1)
            if hasattr(self.norm, "bias") and self.norm.bias is not None:
                nn.init.constant_(self.norm.bias, 0)

    def forward(self, x, activate=True, norm=True):
        for layer in self.order:
            if layer == "conv":
                if self.with_explicit_padding:
                    x = self.pad(x)
                x = self.conv(x)
            elif layer == "norm" and norm and self.with_norm:
                x = self.norm(x)
            elif layer == "act" and activate and self.with_activation:
                x = self.activate(x)
        return x

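ConvModule is the small conv/norm/act bundle the rest of this file is built from; the order tuple controls whether it acts as a plain conv-act block or as a pre-activation block (as used by PreActResidualConvUnit further down). A brief sketch with arbitrary example shapes:

import torch

x = torch.randn(1, 16, 8, 8)

# Default order ("conv", "norm", "act"); norm_layer is None, so bias stays enabled.
block = ConvModule(16, 32, kernel_size=3, padding=1)
print(block(x).shape)    # torch.Size([1, 32, 8, 8])

# Pre-activation variant: activation runs before the convolution, no bias.
pre_act = ConvModule(16, 16, kernel_size=3, padding=1, bias=False, order=("act", "conv", "norm"))
print(pre_act(x).shape)  # torch.Size([1, 16, 8, 8])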
class Interpolate(nn.Module):
    def __init__(self, scale_factor, mode, align_corners=False):
        super(Interpolate, self).__init__()
        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
        return x


class HeadDepth(nn.Module):
    def __init__(self, features):
        super(HeadDepth, self).__init__()
        self.head = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, x):
        x = self.head(x)
        return x

class ReassembleBlocks(nn.Module):
    """ViTPostProcessBlock, process cls_token in ViT backbone output and
    rearrange the feature vector to feature map.
    Args:
        in_channels (int): ViT feature channels. Default: 768.
        out_channels (List): output channels of each stage.
            Default: [96, 192, 384, 768].
        readout_type (str): Type of readout operation. Default: 'ignore'.
        patch_size (int): The patch size. Default: 16.
    """

    def __init__(self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16):
        super(ReassembleBlocks, self).__init__()

        assert readout_type in ["ignore", "add", "project"]
        self.readout_type = readout_type
        self.patch_size = patch_size

        self.projects = nn.ModuleList(
            [
                ConvModule(
                    in_channels=in_channels,
                    out_channels=out_channel,
                    kernel_size=1,
                    act_layer=None,
                )
                for out_channel in out_channels
            ]
        )

        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(
                    in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
                ),
                nn.ConvTranspose2d(
                    in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
                ),
                nn.Identity(),
                nn.Conv2d(
                    in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
                ),
            ]
        )
        if self.readout_type == "project":
            self.readout_projects = nn.ModuleList()
            for _ in range(len(self.projects)):
                self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU()))

    def forward(self, inputs):
        assert isinstance(inputs, list)
        out = []
        for i, x in enumerate(inputs):
            assert len(x) == 2
            x, cls_token = x[0], x[1]
            feature_shape = x.shape
            if self.readout_type == "project":
                x = x.flatten(2).permute((0, 2, 1))
                readout = cls_token.unsqueeze(1).expand_as(x)
                x = self.readout_projects[i](torch.cat((x, readout), -1))
                x = x.permute(0, 2, 1).reshape(feature_shape)
            elif self.readout_type == "add":
                x = x.flatten(2) + cls_token.unsqueeze(-1)
                x = x.reshape(feature_shape)
            else:
                pass
            x = self.projects[i](x)
            x = self.resize_layers[i](x)
            out.append(x)
        return out

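ReassembleBlocks takes same-resolution ViT patch features and re-scales them into a pyramid (4x up, 2x up, identity, 2x down). A quick shape check with made-up inputs; the 32x32 token grid below is an invented example, not a configuration from this repository:

import torch

reassemble = ReassembleBlocks(in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore")

# Hypothetical: four ViT stages at 1/16 resolution of a 512x512 image (32x32 tokens).
feats = [(torch.randn(1, 768, 32, 32), torch.randn(1, 768)) for _ in range(4)]
for f in reassemble(feats):
    print(f.shape)
# torch.Size([1, 96, 128, 128])
# torch.Size([1, 192, 64, 64])
# torch.Size([1, 384, 32, 32])
# torch.Size([1, 768, 16, 16])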
class PreActResidualConvUnit(nn.Module):
    """ResidualConvUnit, pre-activate residual unit.
    Args:
        in_channels (int): number of channels in the input feature map.
        act_layer (nn.Module): activation layer.
        norm_layer (nn.Module): norm layer.
        stride (int): stride of the first block. Default: 1.
        dilation (int): dilation rate for convs layers. Default: 1.
    """

    def __init__(self, in_channels, act_layer, norm_layer, stride=1, dilation=1):
        super(PreActResidualConvUnit, self).__init__()

        self.conv1 = ConvModule(
            in_channels,
            in_channels,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            norm_layer=norm_layer,
            act_layer=act_layer,
            bias=False,
            order=("act", "conv", "norm"),
        )

        self.conv2 = ConvModule(
            in_channels,
            in_channels,
            3,
            padding=1,
            norm_layer=norm_layer,
            act_layer=act_layer,
            bias=False,
            order=("act", "conv", "norm"),
        )

    def forward(self, inputs):
        inputs_ = inputs.clone()
        x = self.conv1(inputs)
        x = self.conv2(x)
        return x + inputs_


class FeatureFusionBlock(nn.Module):
    """FeatureFusionBlock, merge feature map from different stages.
    Args:
        in_channels (int): Input channels.
        act_layer (nn.Module): activation layer for ResidualConvUnit.
        norm_layer (nn.Module): normalization layer.
        expand (bool): Whether to expand the channels in the post process block.
            Default: False.
        align_corners (bool): align_corner setting for bilinear upsample.
            Default: True.
    """

    def __init__(self, in_channels, act_layer, norm_layer, expand=False, align_corners=True):
        super(FeatureFusionBlock, self).__init__()

        self.in_channels = in_channels
        self.expand = expand
        self.align_corners = align_corners

        self.out_channels = in_channels
        if self.expand:
            self.out_channels = in_channels // 2

        self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_layer=None, bias=True)

        self.res_conv_unit1 = PreActResidualConvUnit(
            in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer
        )
        self.res_conv_unit2 = PreActResidualConvUnit(
            in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer
        )

    def forward(self, *inputs):
        x = inputs[0]
        if len(inputs) == 2:
            if x.shape != inputs[1].shape:
                res = resize(inputs[1], size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False)
            else:
                res = inputs[1]
            x = x + self.res_conv_unit1(res)
        x = self.res_conv_unit2(x)
        x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners)
        x = self.project(x)
        return x

| 690 | 
            +
            class DPTHead(DepthBaseDecodeHead):
         | 
| 691 | 
            +
                """Vision Transformers for Dense Prediction.
         | 
| 692 | 
            +
                This head is implemented of `DPT <https://arxiv.org/abs/2103.13413>`_.
         | 
| 693 | 
            +
                Args:
         | 
| 694 | 
            +
                    embed_dims (int): The embed dimension of the ViT backbone.
         | 
| 695 | 
            +
                        Default: 768.
         | 
| 696 | 
            +
                    post_process_channels (List): Out channels of post process conv
         | 
| 697 | 
            +
                        layers. Default: [96, 192, 384, 768].
         | 
| 698 | 
            +
                    readout_type (str): Type of readout operation. Default: 'ignore'.
         | 
| 699 | 
            +
                    patch_size (int): The patch size. Default: 16.
         | 
| 700 | 
            +
                    expand_channels (bool): Whether expand the channels in post process
         | 
| 701 | 
            +
                        block. Default: False.
         | 
| 702 | 
            +
                """
         | 
| 703 | 
            +
             | 
| 704 | 
            +
                def __init__(
         | 
| 705 | 
            +
                    self,
         | 
| 706 | 
            +
                    embed_dims=768,
         | 
| 707 | 
            +
                    post_process_channels=[96, 192, 384, 768],
         | 
| 708 | 
            +
                    readout_type="ignore",
         | 
| 709 | 
            +
                    patch_size=16,
         | 
| 710 | 
            +
                    expand_channels=False,
         | 
| 711 | 
            +
                    **kwargs,
         | 
| 712 | 
            +
                ):
         | 
| 713 | 
            +
                    super(DPTHead, self).__init__(**kwargs)
         | 
| 714 | 
            +
             | 
| 715 | 
            +
                    self.in_channels = self.in_channels
         | 
| 716 | 
            +
                    self.expand_channels = expand_channels
         | 
| 717 | 
            +
                    self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size)
         | 
| 718 | 
            +
             | 
| 719 | 
            +
                    self.post_process_channels = [
         | 
| 720 | 
            +
                        channel * math.pow(2, i) if expand_channels else channel for i, channel in enumerate(post_process_channels)
         | 
| 721 | 
            +
                    ]
         | 
| 722 | 
            +
                    self.convs = nn.ModuleList()
         | 
| 723 | 
            +
                    for channel in self.post_process_channels:
         | 
| 724 | 
            +
                        self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_layer=None, bias=False))
         | 
| 725 | 
            +
                    self.fusion_blocks = nn.ModuleList()
         | 
| 726 | 
            +
                    for _ in range(len(self.convs)):
         | 
| 727 | 
            +
                        self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_layer, self.norm_layer))
         | 
| 728 | 
            +
                    self.fusion_blocks[0].res_conv_unit1 = None
         | 
| 729 | 
            +
                    self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_layer=self.norm_layer)
         | 
| 730 | 
            +
                    self.num_fusion_blocks = len(self.fusion_blocks)
         | 
| 731 | 
            +
                    self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
         | 
| 732 | 
            +
                    self.num_post_process_channels = len(self.post_process_channels)
         | 
| 733 | 
            +
                    assert self.num_fusion_blocks == self.num_reassemble_blocks
         | 
| 734 | 
            +
                    assert self.num_reassemble_blocks == self.num_post_process_channels
         | 
| 735 | 
            +
                    self.conv_depth = HeadDepth(self.channels)
         | 
| 736 | 
            +
             | 
| 737 | 
            +
                def forward(self, inputs, img_metas):
         | 
| 738 | 
            +
                    assert len(inputs) == self.num_reassemble_blocks
         | 
| 739 | 
            +
                    x = [inp for inp in inputs]
         | 
| 740 | 
            +
                    x = self.reassemble_blocks(x)
         | 
| 741 | 
            +
                    x = [self.convs[i](feature) for i, feature in enumerate(x)]
         | 
| 742 | 
            +
                    out = self.fusion_blocks[0](x[-1])
         | 
| 743 | 
            +
                    for i in range(1, len(self.fusion_blocks)):
         | 
| 744 | 
            +
                        out = self.fusion_blocks[i](out, x[-(i + 1)])
         | 
| 745 | 
            +
                    out = self.project(out)
         | 
| 746 | 
            +
                    out = self.depth_pred(out)
         | 
| 747 | 
            +
                    return out
         | 
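As a rough usage sketch (not part of the diff itself): `DPTHead` consumes a list of (patch-feature, class-token) pairs such as those returned by `get_intermediate_layers(..., reshape=True, return_class_token=True)`. The constructor arguments below mirror `_make_dinov2_dpt_depth_head` from `depthers.py` later in this commit; the dummy tensor shapes assume a hypothetical ViT-L/14 backbone on a 224x224 crop (embed dim 1024, 16x16 patch grid).

import torch

from LHM.models.encoders.dinov2.hub.depth import DPTHead  # exposed alongside BNHead / DepthEncoderDecoder

# Four tapped transformer layers: (patch features, CLS token) per layer, shapes illustrative only.
feats = [(torch.randn(1, 1024, 16, 16), torch.randn(1, 1024)) for _ in range(4)]

head = DPTHead(
    in_channels=[1024] * 4,
    channels=256,
    embed_dims=1024,
    post_process_channels=[1024 // 2 ** (3 - i) for i in range(4)],  # [128, 256, 512, 1024]
    readout_type="project",
    min_depth=0.001,
    max_depth=10.0,
    loss_decode=(),
)

with torch.inference_mode():
    depth = head(feats, img_metas=None)  # img_metas is not used by DPTHead.forward
print(depth.shape)  # single-channel depth map of shape (N, 1, H', W')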
    	
        LHM/models/encoders/dinov2/hub/depth/encoder_decoder.py
    ADDED
    
@@ -0,0 +1,351 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from .ops import resize


def add_prefix(inputs, prefix):
    """Add prefix for dict.

    Args:
        inputs (dict): The input dict with str keys.
        prefix (str): The prefix to add.

    Returns:
        dict: The dict with keys updated with ``prefix``.
    """

    outputs = dict()
    for name, value in inputs.items():
        outputs[f"{prefix}.{name}"] = value

    return outputs


class DepthEncoderDecoder(nn.Module):
    """Encoder Decoder depther.

    EncoderDecoder typically consists of backbone and decode_head.
    """

    def __init__(self, backbone, decode_head):
        super(DepthEncoderDecoder, self).__init__()

        self.backbone = backbone
        self.decode_head = decode_head
        self.align_corners = self.decode_head.align_corners

    def extract_feat(self, img):
        """Extract features from images."""
        return self.backbone(img)

    def encode_decode(self, img, img_metas, rescale=True, size=None):
        """Encode images with backbone and decode into a depth estimation
        map of the same size as input."""
        x = self.extract_feat(img)
        out = self._decode_head_forward_test(x, img_metas)
        # crop the pred depth to the certain range.
        out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth)
        if rescale:
            if size is None:
                if img_metas is not None:
                    size = img_metas[0]["ori_shape"][:2]
                else:
                    size = img.shape[2:]
            out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners)
        return out

    def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs):
        """Run forward function and calculate loss for decode head in
        training."""
        losses = dict()
        loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, **kwargs)
        losses.update(add_prefix(loss_decode, "decode"))
        return losses

    def _decode_head_forward_test(self, x, img_metas):
        """Run forward function and calculate loss for decode head in
        inference."""
        depth_pred = self.decode_head.forward_test(x, img_metas)
        return depth_pred

    def forward_dummy(self, img):
        """Dummy forward function."""
        depth = self.encode_decode(img, None)

        return depth

    def forward_train(self, img, img_metas, depth_gt, **kwargs):
        """Forward function for training.

        Args:
            img (Tensor): Input images.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `depth/datasets/pipelines/formatting.py:Collect`.
            depth_gt (Tensor): Depth gt
                used if the architecture supports depth estimation task.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """

        x = self.extract_feat(img)

        losses = dict()

        # the last of x saves the info from neck
        loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs)

        losses.update(loss_decode)

        return losses

    def whole_inference(self, img, img_meta, rescale, size=None):
        """Inference with full image."""
        return self.encode_decode(img, img_meta, rescale, size=size)

    def slide_inference(self, img, img_meta, rescale, stride, crop_size):
        """Inference by sliding-window with overlap.

        If h_crop > h_img or w_crop > w_img, the small patch will be used to
        decode without padding.
        """

        h_stride, w_stride = stride
        h_crop, w_crop = crop_size
        batch_size, _, h_img, w_img = img.size()
        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
        preds = img.new_zeros((batch_size, 1, h_img, w_img))
        count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
        for h_idx in range(h_grids):
            for w_idx in range(w_grids):
                y1 = h_idx * h_stride
                x1 = w_idx * w_stride
                y2 = min(y1 + h_crop, h_img)
                x2 = min(x1 + w_crop, w_img)
                y1 = max(y2 - h_crop, 0)
                x1 = max(x2 - w_crop, 0)
                crop_img = img[:, :, y1:y2, x1:x2]
                depth_pred = self.encode_decode(crop_img, img_meta, rescale)
                preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2)))

                count_mat[:, :, y1:y2, x1:x2] += 1
        assert (count_mat == 0).sum() == 0
        if torch.onnx.is_in_onnx_export():
            # cast count_mat to constant while exporting to ONNX
            count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device)
        preds = preds / count_mat
        return preds

    def inference(self, img, img_meta, rescale, size=None, mode="whole"):
        """Inference with slide/whole style.

        Args:
            img (Tensor): The input image of shape (N, 3, H, W).
            img_meta (dict): Image info dict where each dict has: 'img_shape',
                'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `depth/datasets/pipelines/formatting.py:Collect`.
            rescale (bool): Whether rescale back to original shape.

        Returns:
            Tensor: The output depth map.
        """

        assert mode in ["slide", "whole"]
        ori_shape = img_meta[0]["ori_shape"]
        assert all(_["ori_shape"] == ori_shape for _ in img_meta)
        if mode == "slide":
            depth_pred = self.slide_inference(img, img_meta, rescale)
        else:
            depth_pred = self.whole_inference(img, img_meta, rescale, size=size)
        output = depth_pred
        flip = img_meta[0]["flip"]
        if flip:
            flip_direction = img_meta[0]["flip_direction"]
            assert flip_direction in ["horizontal", "vertical"]
            if flip_direction == "horizontal":
                output = output.flip(dims=(3,))
            elif flip_direction == "vertical":
                output = output.flip(dims=(2,))

        return output

    def simple_test(self, img, img_meta, rescale=True):
        """Simple test with single image."""
        depth_pred = self.inference(img, img_meta, rescale)
        if torch.onnx.is_in_onnx_export():
            # our inference backend only support 4D output
            depth_pred = depth_pred.unsqueeze(0)
            return depth_pred
        depth_pred = depth_pred.cpu().numpy()
        # unravel batch dim
        depth_pred = list(depth_pred)
        return depth_pred

    def aug_test(self, imgs, img_metas, rescale=True):
        """Test with augmentations.

        Only rescale=True is supported.
        """
        # aug_test rescale all imgs back to ori_shape for now
        assert rescale
        # to save memory, we get augmented depth logit inplace
        depth_pred = self.inference(imgs[0], img_metas[0], rescale)
        for i in range(1, len(imgs)):
            cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:])
            depth_pred += cur_depth_pred
        depth_pred /= len(imgs)
        depth_pred = depth_pred.cpu().numpy()
        # unravel batch dim
        depth_pred = list(depth_pred)
        return depth_pred

    def forward_test(self, imgs, img_metas, **kwargs):
        """
        Args:
            imgs (List[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains all images in the batch.
            img_metas (List[List[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch.
        """
        for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]:
            if not isinstance(var, list):
                raise TypeError(f"{name} must be a list, but got " f"{type(var)}")
        num_augs = len(imgs)
        if num_augs != len(img_metas):
            raise ValueError(f"num of augmentations ({len(imgs)}) != " f"num of image meta ({len(img_metas)})")
        # all images in the same aug batch all of the same ori_shape and pad
        # shape
        for img_meta in img_metas:
            ori_shapes = [_["ori_shape"] for _ in img_meta]
            assert all(shape == ori_shapes[0] for shape in ori_shapes)
            img_shapes = [_["img_shape"] for _ in img_meta]
            assert all(shape == img_shapes[0] for shape in img_shapes)
            pad_shapes = [_["pad_shape"] for _ in img_meta]
            assert all(shape == pad_shapes[0] for shape in pad_shapes)

        if num_augs == 1:
            return self.simple_test(imgs[0], img_metas[0], **kwargs)
        else:
            return self.aug_test(imgs, img_metas, **kwargs)

    def forward(self, img, img_metas, return_loss=True, **kwargs):
        """Calls either :func:`forward_train` or :func:`forward_test` depending
        on whether ``return_loss`` is ``True``.

        Note this setting will change the expected inputs. When
        ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
        and List[dict]), and when ``return_loss=False``, img and img_meta
        should be double nested (i.e. List[Tensor], List[List[dict]]), with
        the outer list indicating test time augmentations.
        """
        if return_loss:
            return self.forward_train(img, img_metas, **kwargs)
        else:
            return self.forward_test(img, img_metas, **kwargs)

    def train_step(self, data_batch, optimizer, **kwargs):
        """The iteration step during training.

        This method defines an iteration step during training, except for the
        back propagation and optimizer updating, which are done in an optimizer
        hook. Note that in some complicated cases or models, the whole process
        including back propagation and optimizer updating is also defined in
        this method, such as GAN.

        Args:
            data (dict): The output of dataloader.
            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
                runner is passed to ``train_step()``. This argument is unused
                and reserved.

        Returns:
            dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
                ``num_samples``.
                ``loss`` is a tensor for back propagation, which can be a
                weighted sum of multiple losses.
                ``log_vars`` contains all the variables to be sent to the
                logger.
                ``num_samples`` indicates the batch size (when the model is
                DDP, it means the batch size on each GPU), which is used for
                averaging the logs.
        """
        losses = self(**data_batch)

        # split losses and images
        real_losses = {}
        log_imgs = {}
        for k, v in losses.items():
            if "img" in k:
                log_imgs[k] = v
            else:
                real_losses[k] = v

        loss, log_vars = self._parse_losses(real_losses)

        outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]), log_imgs=log_imgs)

        return outputs

    def val_step(self, data_batch, **kwargs):
        """The iteration step during validation.

        This method shares the same signature as :func:`train_step`, but used
        during val epochs. Note that the evaluation after training epochs is
        not implemented with this method, but an evaluation hook.
        """
        output = self(**data_batch, **kwargs)
        return output

    @staticmethod
    def _parse_losses(losses):
        import torch.distributed as dist

        """Parse the raw outputs (losses) of the network.

        Args:
            losses (dict): Raw output of the network, which usually contain
                losses and other necessary information.

        Returns:
            tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
                which may be a weighted sum of all losses, log_vars contains
                all the variables to be sent to the logger.
        """
        log_vars = OrderedDict()
        for loss_name, loss_value in losses.items():
            if isinstance(loss_value, torch.Tensor):
                log_vars[loss_name] = loss_value.mean()
            elif isinstance(loss_value, list):
                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
            else:
                raise TypeError(f"{loss_name} is not a tensor or list of tensors")

        loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key)

        log_vars["loss"] = loss
        for loss_name, loss_value in log_vars.items():
            # reduce loss when distributed training
            if dist.is_available() and dist.is_initialized():
                loss_value = loss_value.data.clone()
                dist.all_reduce(loss_value.div_(dist.get_world_size()))
            log_vars[loss_name] = loss_value.item()

        return loss, log_vars
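The test-time entry point of this class expects double-nested inputs (the outer list level enumerates test-time augmentations). A minimal helper sketch, not part of the upload, assuming a single un-flipped view and hypothetical metadata values:

def predict_depth(depther, img, ori_shape):
    """Hypothetical convenience wrapper around DepthEncoderDecoder.forward_test.

    `img` is an (N, 3, H, W) tensor; `ori_shape` is the original (H, W) to rescale to.
    forward_test expects List[Tensor] / List[List[dict]], outer level = test-time augs.
    """
    meta = {
        "img_shape": tuple(img.shape[2:]) + (3,),
        "ori_shape": tuple(ori_shape) + (3,),
        "pad_shape": tuple(img.shape[2:]) + (3,),
        "scale_factor": 1.0,
        "flip": False,
    }
    # one augmentation (the identity view); every image in the batch shares the same meta dict
    return depther.forward_test([img], [[meta] * img.shape[0]])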
    	
        LHM/models/encoders/dinov2/hub/depth/ops.py
    ADDED
    
@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import warnings

import torch.nn.functional as F


def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False):
    if warning:
        if size is not None and align_corners:
            input_h, input_w = tuple(int(x) for x in input.shape[2:])
            output_h, output_w = tuple(int(x) for x in size)
            if output_h > input_h or output_w > input_w:
                if (
                    (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1)
                    and (output_h - 1) % (input_h - 1)
                    and (output_w - 1) % (input_w - 1)
                ):
                    warnings.warn(
                        f"When align_corners={align_corners}, "
                        "the output would be more aligned if "
                        f"input size {(input_h, input_w)} is `x+1` and "
                        f"out size {(output_h, output_w)} is `nx+1`"
                    )
    return F.interpolate(input, size, scale_factor, mode, align_corners)
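`resize` is a thin wrapper over `F.interpolate` that only adds the optional `x+1` / `nx+1` alignment warning. A quick usage sketch (not part of the upload; the import path is assumed from this commit's layout):

import torch

from LHM.models.encoders.dinov2.hub.depth.ops import resize

x = torch.randn(1, 1, 17, 17)
# 17 -> 33 keeps (out - 1) divisible by (in - 1), so no warning is emitted even with warning=True
y = resize(x, size=(33, 33), mode="bilinear", align_corners=True, warning=True)
print(y.shape)  # torch.Size([1, 1, 33, 33])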
    	
        LHM/models/encoders/dinov2/hub/depthers.py
    ADDED
    
@@ -0,0 +1,246 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from enum import Enum
from functools import partial
from typing import Optional, Tuple, Union

import torch

from .backbones import _make_dinov2_model
from .depth import BNHead, DepthEncoderDecoder, DPTHead
from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, CenterPadding


class Weights(Enum):
    NYU = "NYU"
    KITTI = "KITTI"


def _get_depth_range(pretrained: bool, weights: Weights = Weights.NYU) -> Tuple[float, float]:
    if not pretrained:  # Default
        return (0.001, 10.0)

    # Pretrained, set according to the training dataset for the provided weights
    if weights == Weights.KITTI:
        return (0.001, 80.0)

    if weights == Weights.NYU:
        return (0.001, 10.0)

    return (0.001, 10.0)


def _make_dinov2_linear_depth_head(
    *,
    embed_dim: int,
    layers: int,
    min_depth: float,
    max_depth: float,
    **kwargs,
):
    if layers not in (1, 4):
        raise AssertionError(f"Unsupported number of layers: {layers}")

    if layers == 1:
        in_index = [0]
    else:
        assert layers == 4
        in_index = [0, 1, 2, 3]

    return BNHead(
        classify=True,
        n_bins=256,
        bins_strategy="UD",
        norm_strategy="linear",
        upsample=4,
        in_channels=[embed_dim] * len(in_index),
        in_index=in_index,
        input_transform="resize_concat",
        channels=embed_dim * len(in_index) * 2,
        align_corners=False,
        min_depth=0.001,
        max_depth=80,
        loss_decode=(),
    )


def _make_dinov2_linear_depther(
    *,
    arch_name: str = "vit_large",
    layers: int = 4,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.NYU,
    depth_range: Optional[Tuple[float, float]] = None,
    **kwargs,
):
    if layers not in (1, 4):
        raise AssertionError(f"Unsupported number of layers: {layers}")
    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    if depth_range is None:
        depth_range = _get_depth_range(pretrained, weights)
    min_depth, max_depth = depth_range

    backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)

    embed_dim = backbone.embed_dim
    patch_size = backbone.patch_size
    model_name = _make_dinov2_model_name(arch_name, patch_size)
    linear_depth_head = _make_dinov2_linear_depth_head(
        embed_dim=embed_dim,
        layers=layers,
        min_depth=min_depth,
        max_depth=max_depth,
    )

    layer_count = {
        "vit_small": 12,
        "vit_base": 12,
        "vit_large": 24,
        "vit_giant2": 40,
    }[arch_name]

    if layers == 4:
        out_index = {
            "vit_small": [2, 5, 8, 11],
            "vit_base": [2, 5, 8, 11],
            "vit_large": [4, 11, 17, 23],
            "vit_giant2": [9, 19, 29, 39],
        }[arch_name]
    else:
        assert layers == 1
        out_index = [layer_count - 1]

    model = DepthEncoderDecoder(backbone=backbone, decode_head=linear_depth_head)
    model.backbone.forward = partial(
        backbone.get_intermediate_layers,
        n=out_index,
        reshape=True,
        return_class_token=True,
        norm=False,
    )
    model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(patch_size)(x[0]))

    if pretrained:
        layers_str = str(layers) if layers == 4 else ""
        weights_str = weights.value.lower()
        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_linear{layers_str}_head.pth"
        checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        if "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        model.load_state_dict(state_dict, strict=False)

    return model


def dinov2_vits14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    return _make_dinov2_linear_depther(
        arch_name="vit_small", layers=layers, pretrained=pretrained, weights=weights, **kwargs
    )


def dinov2_vitb14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    return _make_dinov2_linear_depther(
        arch_name="vit_base", layers=layers, pretrained=pretrained, weights=weights, **kwargs
    )


def dinov2_vitl14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    return _make_dinov2_linear_depther(
        arch_name="vit_large", layers=layers, pretrained=pretrained, weights=weights, **kwargs
    )


def dinov2_vitg14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    return _make_dinov2_linear_depther(
        arch_name="vit_giant2", layers=layers, ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs
    )


def _make_dinov2_dpt_depth_head(*, embed_dim: int, min_depth: float, max_depth: float):
    return DPTHead(
        in_channels=[embed_dim] * 4,
        channels=256,
        embed_dims=embed_dim,
        post_process_channels=[embed_dim // 2 ** (3 - i) for i in range(4)],
        readout_type="project",
        min_depth=min_depth,
        max_depth=max_depth,
        loss_decode=(),
    )


def _make_dinov2_dpt_depther(
    *,
    arch_name: str = "vit_large",
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.NYU,
    depth_range: Optional[Tuple[float, float]] = None,
    **kwargs,
):
    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    if depth_range is None:
        depth_range = _get_depth_range(pretrained, weights)
    min_depth, max_depth = depth_range

    backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)

    model_name = _make_dinov2_model_name(arch_name, backbone.patch_size)
    dpt_depth_head = _make_dinov2_dpt_depth_head(embed_dim=backbone.embed_dim, min_depth=min_depth, max_depth=max_depth)

    out_index = {
        "vit_small": [2, 5, 8, 11],
        "vit_base": [2, 5, 8, 11],
        "vit_large": [4, 11, 17, 23],
        "vit_giant2": [9, 19, 29, 39],
    }[arch_name]

    model = DepthEncoderDecoder(backbone=backbone, decode_head=dpt_depth_head)
    model.backbone.forward = partial(
        backbone.get_intermediate_layers,
        n=out_index,
        reshape=True,
        return_class_token=True,
        norm=False,
    )
    model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(backbone.patch_size)(x[0]))

    if pretrained:
        weights_str = weights.value.lower()
        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_dpt_head.pth"
        checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        if "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        model.load_state_dict(state_dict, strict=False)

    return model


def dinov2_vits14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    return _make_dinov2_dpt_depther(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitb14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    return _make_dinov2_dpt_depther(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitl14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    return _make_dinov2_dpt_depther(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitg14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    return _make_dinov2_dpt_depther(
        arch_name="vit_giant2", ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs
    )
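Putting the pieces of this file together, a minimal end-to-end sketch (not part of the upload; the import path follows this commit's layout, and `pretrained=False` is chosen only to avoid downloading checkpoints):

import torch

from LHM.models.encoders.dinov2.hub.depthers import dinov2_vits14_ld

# Linear depth head on a randomly initialized ViT-S/14; depth range defaults to (0.001, 10.0).
depther = dinov2_vits14_ld(layers=4, pretrained=False)
depther.eval()

img = torch.randn(1, 3, 224, 224)  # 224 is already a multiple of the 14-pixel patch size
with torch.inference_mode():
    # whole-image path: img_meta=None makes encode_decode rescale to the input resolution
    depth = depther.whole_inference(img, img_meta=None, rescale=True)
print(depth.shape)  # expected: torch.Size([1, 1, 224, 224])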