Spaces:
Running
Running
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # | |
| # This source code is licensed under the Apache License, Version 2.0 | |
| # found in the LICENSE file in the root directory of this source tree. | |
| from enum import Enum | |
| from functools import partial | |
| from typing import Optional, Tuple, Union | |
| import torch | |
| from .backbones import _make_dinov2_model | |
| from .depth import BNHead, DepthEncoderDecoder, DPTHead | |
| from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, CenterPadding | |
class Weights(Enum):
    """Names of the pretrained depth-head weight sets that can be requested.

    Each member's value is identical to its name, so a member can be looked up
    either by name (``Weights["NYU"]``) or by value (``Weights("NYU")``).
    """

    # Declaration order is preserved because it defines the iteration order.
    NYU = "NYU"
    KITTI = "KITTI"
def _get_depth_range(pretrained: bool, weights: Weights = Weights.NYU) -> Tuple[float, float]:
    """Return the ``(min_depth, max_depth)`` pair to configure a depth head with.

    Without pretrained weights the default range ``(0.001, 10.0)`` is returned.
    With pretrained weights the range follows the selected weights:
    ``(0.001, 80.0)`` for ``Weights.KITTI`` and ``(0.001, 10.0)`` for
    ``Weights.NYU`` or any other value.
    """
    # Only pretrained KITTI weights use the wider range; every other
    # combination shares the default.
    wants_kitti_range = pretrained and weights == Weights.KITTI
    return (0.001, 80.0) if wants_kitti_range else (0.001, 10.0)
| def _make_dinov2_linear_depth_head( | |
| *, | |
| embed_dim: int, | |
| layers: int, | |
| min_depth: float, | |
| max_depth: float, | |
| **kwargs, | |
| ): | |
| if layers not in (1, 4): | |
| raise AssertionError(f"Unsupported number of layers: {layers}") | |
| if layers == 1: | |
| in_index = [0] | |
| else: | |
| assert layers == 4 | |
| in_index = [0, 1, 2, 3] | |
| return BNHead( | |
| classify=True, | |
| n_bins=256, | |
| bins_strategy="UD", | |
| norm_strategy="linear", | |
| upsample=4, | |
| in_channels=[embed_dim] * len(in_index), | |
| in_index=in_index, | |
| input_transform="resize_concat", | |
| channels=embed_dim * len(in_index) * 2, | |
| align_corners=False, | |
| min_depth=0.001, | |
| max_depth=80, | |
| loss_decode=(), | |
| ) | |
def _make_dinov2_linear_depther(
    *,
    arch_name: str = "vit_large",
    layers: int = 4,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.NYU,
    depth_range: Optional[Tuple[float, float]] = None,
    **kwargs,
):
    """Assemble a ``DepthEncoderDecoder`` from a DINOv2 backbone and a linear head.

    Args:
        arch_name: Backbone architecture key; one of ``"vit_small"``,
            ``"vit_base"``, ``"vit_large"`` or ``"vit_giant2"``.
        layers: Number of backbone layers fed to the head (1 or 4).
        pretrained: Passed to the backbone factory; when true, a head
            checkpoint is also downloaded and loaded into the model.
        weights: A ``Weights`` member or its name as a string.
        depth_range: Explicit ``(min_depth, max_depth)``; when ``None`` the
            range is derived from ``pretrained`` and ``weights``.
        **kwargs: Forwarded to ``_make_dinov2_model``.

    Returns:
        The assembled ``DepthEncoderDecoder`` model.

    Raises:
        AssertionError: If ``layers`` is not 1 or 4, or ``weights`` is a
            string that does not name a ``Weights`` member.
        KeyError: If ``arch_name`` is not one of the supported keys.
    """
    if layers not in (1, 4):
        raise AssertionError(f"Unsupported number of layers: {layers}")
    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            # The KeyError adds nothing beyond the message below.
            raise AssertionError(f"Unsupported weights: {weights}") from None

    if depth_range is None:
        depth_range = _get_depth_range(pretrained, weights)
    min_depth, max_depth = depth_range

    backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)

    embed_dim = backbone.embed_dim
    patch_size = backbone.patch_size
    model_name = _make_dinov2_model_name(arch_name, patch_size)
    linear_depth_head = _make_dinov2_linear_depth_head(
        embed_dim=embed_dim,
        layers=layers,
        min_depth=min_depth,
        max_depth=max_depth,
    )

    layer_count = {
        "vit_small": 12,
        "vit_base": 12,
        "vit_large": 24,
        "vit_giant2": 40,
    }[arch_name]

    if layers == 4:
        # Four backbone layer indices per architecture.
        out_index = {
            "vit_small": [2, 5, 8, 11],
            "vit_base": [2, 5, 8, 11],
            "vit_large": [4, 11, 17, 23],
            "vit_giant2": [9, 19, 29, 39],
        }[arch_name]
    else:
        # layers == 1 (validated above): use only the last backbone layer.
        out_index = [layer_count - 1]

    model = DepthEncoderDecoder(backbone=backbone, decode_head=linear_depth_head)
    # Route the backbone's forward through get_intermediate_layers so the
    # head receives the selected layers.
    model.backbone.forward = partial(
        backbone.get_intermediate_layers,
        n=out_index,
        reshape=True,
        return_class_token=True,
        norm=False,
    )
    # Apply CenterPadding(patch_size) to the first positional input before
    # every backbone forward call.
    model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(patch_size)(x[0]))

    if pretrained:
        layers_str = str(layers) if layers == 4 else ""
        weights_str = weights.value.lower()
        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_linear{layers_str}_head.pth"
        checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        # Use the nested "state_dict" entry when present; otherwise treat the
        # loaded mapping itself as the state dict. Previously the variable
        # was left unbound in the latter case.
        state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
        model.load_state_dict(state_dict, strict=False)

    return model
def dinov2_vits14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """Create a linear-head depther on the ``vit_small`` backbone architecture.

    ``layers``, ``pretrained``, ``weights`` and any extra keyword arguments are
    forwarded unchanged to ``_make_dinov2_linear_depther``.
    """
    return _make_dinov2_linear_depther(
        arch_name="vit_small",
        layers=layers,
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
def dinov2_vitb14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """Create a linear-head depther on the ``vit_base`` backbone architecture.

    ``layers``, ``pretrained``, ``weights`` and any extra keyword arguments are
    forwarded unchanged to ``_make_dinov2_linear_depther``.
    """
    return _make_dinov2_linear_depther(
        arch_name="vit_base",
        layers=layers,
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
def dinov2_vitl14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """Create a linear-head depther on the ``vit_large`` backbone architecture.

    ``layers``, ``pretrained``, ``weights`` and any extra keyword arguments are
    forwarded unchanged to ``_make_dinov2_linear_depther``.
    """
    return _make_dinov2_linear_depther(
        arch_name="vit_large",
        layers=layers,
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
def dinov2_vitg14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """Create a linear-head depther on the ``vit_giant2`` backbone architecture.

    Always passes ``ffn_layer="swiglufused"``; ``layers``, ``pretrained``,
    ``weights`` and any extra keyword arguments are forwarded unchanged to
    ``_make_dinov2_linear_depther``.
    """
    return _make_dinov2_linear_depther(
        arch_name="vit_giant2",
        layers=layers,
        ffn_layer="swiglufused",
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
def _make_dinov2_dpt_depth_head(*, embed_dim: int, min_depth: float, max_depth: float):
    """Build the ``DPTHead`` depth decoder for a backbone of width ``embed_dim``.

    The head takes four inputs of ``embed_dim`` channels each and is given the
    supplied depth range.
    """
    # Channel pyramid: embed_dim // 8, // 4, // 2, // 1.
    pyramid_channels = [embed_dim // divisor for divisor in (8, 4, 2, 1)]
    input_channels = [embed_dim] * 4
    return DPTHead(
        in_channels=input_channels,
        channels=256,
        embed_dims=embed_dim,
        post_process_channels=pyramid_channels,
        readout_type="project",
        min_depth=min_depth,
        max_depth=max_depth,
        loss_decode=(),
    )
def _make_dinov2_dpt_depther(
    *,
    arch_name: str = "vit_large",
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.NYU,
    depth_range: Optional[Tuple[float, float]] = None,
    **kwargs,
):
    """Assemble a ``DepthEncoderDecoder`` from a DINOv2 backbone and a DPT head.

    Args:
        arch_name: Backbone architecture key; one of ``"vit_small"``,
            ``"vit_base"``, ``"vit_large"`` or ``"vit_giant2"``.
        pretrained: Passed to the backbone factory; when true, a head
            checkpoint is also downloaded and loaded into the model.
        weights: A ``Weights`` member or its name as a string.
        depth_range: Explicit ``(min_depth, max_depth)``; when ``None`` the
            range is derived from ``pretrained`` and ``weights``.
        **kwargs: Forwarded to ``_make_dinov2_model``.

    Returns:
        The assembled ``DepthEncoderDecoder`` model.

    Raises:
        AssertionError: If ``weights`` is a string that does not name a
            ``Weights`` member.
        KeyError: If ``arch_name`` is not one of the supported keys.
    """
    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            # The KeyError adds nothing beyond the message below.
            raise AssertionError(f"Unsupported weights: {weights}") from None

    if depth_range is None:
        depth_range = _get_depth_range(pretrained, weights)
    min_depth, max_depth = depth_range

    backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)

    patch_size = backbone.patch_size
    model_name = _make_dinov2_model_name(arch_name, patch_size)
    dpt_depth_head = _make_dinov2_dpt_depth_head(embed_dim=backbone.embed_dim, min_depth=min_depth, max_depth=max_depth)

    # Four backbone layer indices per architecture.
    out_index = {
        "vit_small": [2, 5, 8, 11],
        "vit_base": [2, 5, 8, 11],
        "vit_large": [4, 11, 17, 23],
        "vit_giant2": [9, 19, 29, 39],
    }[arch_name]

    model = DepthEncoderDecoder(backbone=backbone, decode_head=dpt_depth_head)
    # Route the backbone's forward through get_intermediate_layers so the
    # head receives the selected layers.
    model.backbone.forward = partial(
        backbone.get_intermediate_layers,
        n=out_index,
        reshape=True,
        return_class_token=True,
        norm=False,
    )
    # Apply CenterPadding(patch_size) to the first positional input before
    # every backbone forward call.
    model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(patch_size)(x[0]))

    if pretrained:
        weights_str = weights.value.lower()
        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_dpt_head.pth"
        checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        # Use the nested "state_dict" entry when present; otherwise treat the
        # loaded mapping itself as the state dict. Previously the variable
        # was left unbound in the latter case.
        state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
        model.load_state_dict(state_dict, strict=False)

    return model
def dinov2_vits14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """Create a DPT-head depther on the ``vit_small`` backbone architecture.

    ``pretrained``, ``weights`` and any extra keyword arguments are forwarded
    unchanged to ``_make_dinov2_dpt_depther``.
    """
    return _make_dinov2_dpt_depther(
        arch_name="vit_small",
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
def dinov2_vitb14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """Create a DPT-head depther on the ``vit_base`` backbone architecture.

    ``pretrained``, ``weights`` and any extra keyword arguments are forwarded
    unchanged to ``_make_dinov2_dpt_depther``.
    """
    return _make_dinov2_dpt_depther(
        arch_name="vit_base",
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
def dinov2_vitl14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """Create a DPT-head depther on the ``vit_large`` backbone architecture.

    ``pretrained``, ``weights`` and any extra keyword arguments are forwarded
    unchanged to ``_make_dinov2_dpt_depther``.
    """
    return _make_dinov2_dpt_depther(
        arch_name="vit_large",
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
def dinov2_vitg14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """Create a DPT-head depther on the ``vit_giant2`` backbone architecture.

    Always passes ``ffn_layer="swiglufused"``; ``pretrained``, ``weights`` and
    any extra keyword arguments are forwarded unchanged to
    ``_make_dinov2_dpt_depther``.
    """
    return _make_dinov2_dpt_depther(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )