Upload 61 files
This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- hf_space/third_party/matanyone/__init__.py +2 -2
- hf_space/third_party/matanyone/__pycache__/__init__.cpython-313.pyc +0 -0
- hf_space/third_party/matanyone/config/__init__.py +0 -0
- hf_space/third_party/matanyone/config/eval_matanyone_config.yaml +47 -0
- hf_space/third_party/matanyone/config/hydra/job_logging/custom-no-rank.yaml +22 -0
- hf_space/third_party/matanyone/config/hydra/job_logging/custom.yaml +22 -0
- hf_space/third_party/matanyone/config/model/base.yaml +58 -0
- hf_space/third_party/matanyone/inference/__init__.py +0 -0
- hf_space/third_party/matanyone/inference/__pycache__/__init__.cpython-313.pyc +0 -0
- hf_space/third_party/matanyone/inference/__pycache__/image_feature_store.cpython-313.pyc +0 -0
- hf_space/third_party/matanyone/inference/__pycache__/inference_core.cpython-313.pyc +0 -0
- hf_space/third_party/matanyone/inference/__pycache__/kv_memory_store.cpython-313.pyc +0 -0
- hf_space/third_party/matanyone/inference/__pycache__/memory_manager.cpython-313.pyc +0 -0
- hf_space/third_party/matanyone/inference/__pycache__/object_info.cpython-313.pyc +0 -0
- hf_space/third_party/matanyone/inference/__pycache__/object_manager.cpython-313.pyc +0 -0
- hf_space/third_party/matanyone/inference/image_feature_store.py +56 -0
- hf_space/third_party/matanyone/inference/inference_core.py +545 -0
- hf_space/third_party/matanyone/inference/kv_memory_store.py +348 -0
- hf_space/third_party/matanyone/inference/memory_manager.py +453 -0
- hf_space/third_party/matanyone/inference/object_info.py +24 -0
- hf_space/third_party/matanyone/inference/object_manager.py +149 -0
- hf_space/third_party/matanyone/inference/utils/__init__.py +0 -0
- hf_space/third_party/matanyone/inference/utils/args_utils.py +30 -0
- hf_space/third_party/matanyone/model/__init__.py +0 -0
- hf_space/third_party/matanyone/model/__pycache__/__init__.cpython-313.pyc +0 -0
- hf_space/third_party/matanyone/model/__pycache__/aux_modules.cpython-313.pyc +0 -0
- hf_space/third_party/matanyone/model/__pycache__/big_modules.cpython-313.pyc +0 -0
- hf_space/third_party/matanyone/model/__pycache__/channel_attn.cpython-313.pyc +0 -0
- hf_space/third_party/matanyone/model/__pycache__/group_modules.cpython-313.pyc +0 -0
- hf_space/third_party/matanyone/model/__pycache__/matanyone.cpython-313.pyc +0 -0
- hf_space/third_party/matanyone/model/__pycache__/modules.cpython-313.pyc +0 -0
- hf_space/third_party/matanyone/model/aux_modules.py +93 -0
- hf_space/third_party/matanyone/model/big_modules.py +365 -0
- hf_space/third_party/matanyone/model/channel_attn.py +39 -0
- hf_space/third_party/matanyone/model/group_modules.py +126 -0
- hf_space/third_party/matanyone/model/matanyone.py +333 -0
- hf_space/third_party/matanyone/model/modules.py +149 -0
- hf_space/third_party/matanyone/model/transformer/__init__.py +0 -0
- hf_space/third_party/matanyone/model/transformer/__pycache__/__init__.cpython-313.pyc +0 -0
- hf_space/third_party/matanyone/model/transformer/__pycache__/object_summarizer.cpython-313.pyc +0 -0
- hf_space/third_party/matanyone/model/transformer/__pycache__/object_transformer.cpython-313.pyc +0 -0
- hf_space/third_party/matanyone/model/transformer/__pycache__/positional_encoding.cpython-313.pyc +0 -0
- hf_space/third_party/matanyone/model/transformer/__pycache__/transformer_layers.cpython-313.pyc +0 -0
- hf_space/third_party/matanyone/model/transformer/object_summarizer.py +89 -0
- hf_space/third_party/matanyone/model/transformer/object_transformer.py +206 -0
- hf_space/third_party/matanyone/model/transformer/positional_encoding.py +108 -0
- hf_space/third_party/matanyone/model/transformer/transformer_layers.py +161 -0
- hf_space/third_party/matanyone/model/utils/__init__.py +0 -0
- hf_space/third_party/matanyone/model/utils/__pycache__/__init__.cpython-313.pyc +0 -0
- hf_space/third_party/matanyone/model/utils/__pycache__/memory_utils.cpython-313.pyc +0 -0
hf_space/third_party/matanyone/__init__.py
CHANGED
@@ -1,2 +1,2 @@
-
-
+from matanyone.inference.inference_core import InferenceCore
+from matanyone.model.matanyone import MatAnyone
hf_space/third_party/matanyone/__pycache__/__init__.cpython-313.pyc
ADDED
Binary file (324 Bytes)
hf_space/third_party/matanyone/config/__init__.py
ADDED
File without changes
hf_space/third_party/matanyone/config/eval_matanyone_config.yaml
ADDED
@@ -0,0 +1,47 @@
defaults:
  - _self_
  - model: base
  - override hydra/job_logging: custom-no-rank.yaml

hydra:
  run:
    dir: ../output/${exp_id}/${dataset}
  output_subdir: ${now:%Y-%m-%d_%H-%M-%S}-hydra

amp: False
weights: pretrained_models/matanyone.pth # default (can be modified from outside)
output_dir: null # defaults to run_dir; specify this to override
flip_aug: False


# maximum shortest side of the input; -1 means no resizing
# With eval_vos.py, we usually just use the dataset's size (resizing is done in the dataloader)
# this parameter is added for the sole purpose of the GUI in the current codebase
# InferenceCore will downsize the input and restore the output to the original size if needed
# if you are using this code for some other project, you can also utilize this parameter
max_internal_size: -1

# these parameters, when set, override the dataset's default; useful for debugging
save_all: True
use_all_masks: False
use_long_term: False
mem_every: 5

# only relevant when long_term is not enabled
max_mem_frames: 5

# only relevant when long_term is enabled
long_term:
  count_usage: True
  max_mem_frames: 10
  min_mem_frames: 5
  num_prototypes: 128
  max_num_tokens: 10000
  buffer_tokens: 2000

top_k: 30
stagger_updates: 5
chunk_size: -1 # number of objects to process in parallel; -1 means unlimited
save_scores: False
save_aux: False
visualize: False
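For reference, a minimal sketch (not part of this commit) of reading this config outside of Hydra's app machinery: OmegaConf loads the YAML directly, so the defaults list is left uncomposed and the ${now:...}/${exp_id} interpolations are not resolved. The path and override below are assumptions.

# Hedged sketch: load eval_matanyone_config.yaml directly with OmegaConf.
from omegaconf import OmegaConf

cfg = OmegaConf.load("hf_space/third_party/matanyone/config/eval_matanyone_config.yaml")
cfg = OmegaConf.merge(cfg, OmegaConf.from_dotlist(["max_internal_size=480"]))  # "modified from outside"
print(cfg.mem_every, cfg.long_term.max_mem_frames)  # 5 10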
hf_space/third_party/matanyone/config/hydra/job_logging/custom-no-rank.yaml
ADDED
@@ -0,0 +1,22 @@
# python logging configuration for tasks
version: 1
formatters:
  simple:
    format: '[%(asctime)s][%(levelname)s] - %(message)s'
    datefmt: '%Y-%m-%d %H:%M:%S'
handlers:
  console:
    class: logging.StreamHandler
    formatter: simple
    stream: ext://sys.stdout
  file:
    class: logging.FileHandler
    formatter: simple
    # absolute file path
    filename: ${hydra.runtime.output_dir}/${now:%Y-%m-%d_%H-%M-%S}-eval.log
    mode: w
root:
  level: INFO
  handlers: [console, file]

disable_existing_loggers: false
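For context, a hedged sketch of how a dict of this shape is consumed by the standard library once Hydra has resolved the interpolations; the file handler is omitted here so the snippet runs without a Hydra output directory.

# Hedged sketch: the resolved YAML above is ultimately fed to logging.config.dictConfig.
import logging
import logging.config

logging.config.dictConfig({
    "version": 1,
    "formatters": {"simple": {
        "format": "[%(asctime)s][%(levelname)s] - %(message)s",
        "datefmt": "%Y-%m-%d %H:%M:%S",
    }},
    "handlers": {"console": {
        "class": "logging.StreamHandler",
        "formatter": "simple",
        "stream": "ext://sys.stdout",
    }},
    "root": {"level": "INFO", "handlers": ["console"]},
    "disable_existing_loggers": False,
})
logging.getLogger(__name__).info("logging configured")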
hf_space/third_party/matanyone/config/hydra/job_logging/custom.yaml
ADDED
@@ -0,0 +1,22 @@
# python logging configuration for tasks
version: 1
formatters:
  simple:
    format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s'
    datefmt: '%Y-%m-%d %H:%M:%S'
handlers:
  console:
    class: logging.StreamHandler
    formatter: simple
    stream: ext://sys.stdout
  file:
    class: logging.FileHandler
    formatter: simple
    # absolute file path
    filename: ${hydra.runtime.output_dir}/${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log
    mode: w
root:
  level: INFO
  handlers: [console, file]

disable_existing_loggers: false
hf_space/third_party/matanyone/config/model/base.yaml
ADDED
@@ -0,0 +1,58 @@
pixel_mean: [0.485, 0.456, 0.406]
pixel_std: [0.229, 0.224, 0.225]

pixel_dim: 256
key_dim: 64
value_dim: 256
sensory_dim: 256
embed_dim: 256

pixel_encoder:
  type: resnet50
  ms_dims: [1024, 512, 256, 64, 3] # f16, f8, f4, f2, f1

mask_encoder:
  type: resnet18
  final_dim: 256

pixel_pe_scale: 32
pixel_pe_temperature: 128

object_transformer:
  embed_dim: ${model.embed_dim}
  ff_dim: 2048
  num_heads: 8
  num_blocks: 3
  num_queries: 16
  read_from_pixel:
    input_norm: False
    input_add_pe: False
    add_pe_to_qkv: [True, True, False]
  read_from_past:
    add_pe_to_qkv: [True, True, False]
  read_from_memory:
    add_pe_to_qkv: [True, True, False]
  read_from_query:
    add_pe_to_qkv: [True, True, False]
    output_norm: False
  query_self_attention:
    add_pe_to_qkv: [True, True, False]
  pixel_self_attention:
    add_pe_to_qkv: [True, True, False]

object_summarizer:
  embed_dim: ${model.object_transformer.embed_dim}
  num_summaries: ${model.object_transformer.num_queries}
  add_pe: True

aux_loss:
  sensory:
    enabled: True
    weight: 0.01
  query:
    enabled: True
    weight: 0.01

mask_decoder:
  # first value must equal embed_dim
  up_dims: [256, 128, 128, 64, 16]
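A small sketch (assumption: this file is mounted under the model key via the defaults list in eval_matanyone_config.yaml) showing how the ${model.*} interpolations above resolve with OmegaConf:

# Hedged sketch of the absolute interpolations used in base.yaml.
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "model": {
        "embed_dim": 256,
        "object_transformer": {"embed_dim": "${model.embed_dim}", "num_queries": 16},
        "object_summarizer": {"num_summaries": "${model.object_transformer.num_queries}"},
    }
})
print(cfg.model.object_transformer.embed_dim)     # 256 (follows model.embed_dim)
print(cfg.model.object_summarizer.num_summaries)  # 16 (follows num_queries)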
hf_space/third_party/matanyone/inference/__init__.py
ADDED
File without changes
hf_space/third_party/matanyone/inference/__pycache__/__init__.cpython-313.pyc
ADDED
Binary file (199 Bytes)

hf_space/third_party/matanyone/inference/__pycache__/image_feature_store.cpython-313.pyc
ADDED
Binary file (4.71 kB)

hf_space/third_party/matanyone/inference/__pycache__/inference_core.cpython-313.pyc
ADDED
Binary file (26.7 kB)

hf_space/third_party/matanyone/inference/__pycache__/kv_memory_store.cpython-313.pyc
ADDED
Binary file (18.8 kB)

hf_space/third_party/matanyone/inference/__pycache__/memory_manager.cpython-313.pyc
ADDED
Binary file (22.5 kB)

hf_space/third_party/matanyone/inference/__pycache__/object_info.cpython-313.pyc
ADDED
Binary file (1.68 kB)

hf_space/third_party/matanyone/inference/__pycache__/object_manager.cpython-313.pyc
ADDED
Binary file (8.07 kB)
hf_space/third_party/matanyone/inference/image_feature_store.py
ADDED
@@ -0,0 +1,56 @@
import warnings
from typing import Iterable
import torch
from matanyone.model.matanyone import MatAnyone


class ImageFeatureStore:
    """
    A cache for image features.
    These features might be reused at different parts of the inference pipeline.
    This class provides an interface for reusing these features.
    It is the user's responsibility to delete redundant features.

    Features of a frame should be associated with a unique index -- typically the frame id.
    """
    def __init__(self, network: MatAnyone, no_warning: bool = False):
        self.network = network
        self._store = {}
        self.no_warning = no_warning

    def _encode_feature(self, index: int, image: torch.Tensor, last_feats=None) -> None:
        ms_features, pix_feat = self.network.encode_image(image, last_feats=last_feats)
        key, shrinkage, selection = self.network.transform_key(ms_features[0])
        self._store[index] = (ms_features, pix_feat, key, shrinkage, selection)

    def get_all_features(self, images: torch.Tensor) -> (Iterable[torch.Tensor], torch.Tensor):
        seq_length = images.shape[0]
        ms_features, pix_feat = self.network.encode_image(images, seq_length)
        key, shrinkage, selection = self.network.transform_key(ms_features[0])
        for index in range(seq_length):
            self._store[index] = ([f[index].unsqueeze(0) for f in ms_features], pix_feat[index].unsqueeze(0), key[index].unsqueeze(0), shrinkage[index].unsqueeze(0), selection[index].unsqueeze(0))

    def get_features(self, index: int,
                     image: torch.Tensor, last_feats=None) -> (Iterable[torch.Tensor], torch.Tensor):
        if index not in self._store:
            self._encode_feature(index, image, last_feats)

        return self._store[index][:2]

    def get_key(self, index: int,
                image: torch.Tensor, last_feats=None) -> (torch.Tensor, torch.Tensor, torch.Tensor):
        if index not in self._store:
            self._encode_feature(index, image, last_feats)

        return self._store[index][2:]

    def delete(self, index: int) -> None:
        if index in self._store:
            del self._store[index]

    def __len__(self):
        return len(self._store)

    def __del__(self):
        if len(self._store) > 0 and not self.no_warning:
            warnings.warn(f'Leaking {self._store.keys()} in the image feature store')
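A hedged usage sketch of the cache above (not part of the commit); network is assumed to be a loaded MatAnyone model and frame a 1*3*H*W tensor with sides divisible by 16:

# Hedged sketch: encode once per frame index, reuse across the pipeline.
store = ImageFeatureStore(network)                   # 'network' is an assumed MatAnyone instance
ms_feat, pix_feat = store.get_features(0, frame)     # encodes frame 0 and caches the result
key, shrinkage, selection = store.get_key(0, frame)  # served from the cache, no re-encoding
store.delete(0)   # entries must be freed by the caller, or __del__ warns about leaks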
hf_space/third_party/matanyone/inference/inference_core.py
ADDED
@@ -0,0 +1,545 @@
import logging
from omegaconf import DictConfig
from typing import List, Optional, Iterable, Union, Tuple

import os
import cv2
import torch
import imageio
import tempfile
import numpy as np
from tqdm import tqdm
from PIL import Image
import torch.nn.functional as F

from matanyone.inference.memory_manager import MemoryManager
from matanyone.inference.object_manager import ObjectManager
from matanyone.inference.image_feature_store import ImageFeatureStore
from matanyone.model.matanyone import MatAnyone
from matanyone.utils.tensor_utils import pad_divide_by, unpad, aggregate
from matanyone.utils.inference_utils import gen_dilate, gen_erosion, read_frame_from_videos

log = logging.getLogger()


class InferenceCore:

    def __init__(self,
                 network: Union[MatAnyone, str],
                 cfg: DictConfig = None,
                 *,
                 image_feature_store: ImageFeatureStore = None,
                 device: Union[str, torch.device] = torch.device("cuda" if torch.cuda.is_available() else "cpu")):
        if isinstance(network, str):
            network = MatAnyone.from_pretrained(network)
        network.to(device)
        network.eval()
        self.network = network
        cfg = cfg if cfg is not None else network.cfg
        self.cfg = cfg
        self.mem_every = cfg.mem_every
        stagger_updates = cfg.stagger_updates
        self.chunk_size = cfg.chunk_size
        self.save_aux = cfg.save_aux
        self.max_internal_size = cfg.max_internal_size
        self.flip_aug = cfg.flip_aug

        self.curr_ti = -1
        self.last_mem_ti = 0
        # at which time indices should we update the sensory memory
        if stagger_updates >= self.mem_every:
            self.stagger_ti = set(range(1, self.mem_every + 1))
        else:
            self.stagger_ti = set(
                np.round(np.linspace(1, self.mem_every, stagger_updates)).astype(int))
        self.object_manager = ObjectManager()
        self.memory = MemoryManager(cfg=cfg, object_manager=self.object_manager)

        if image_feature_store is None:
            self.image_feature_store = ImageFeatureStore(self.network)
        else:
            self.image_feature_store = image_feature_store

        self.last_mask = None
        self.last_pix_feat = None
        self.last_msk_value = None

    def clear_memory(self):
        self.curr_ti = -1
        self.last_mem_ti = 0
        self.memory = MemoryManager(cfg=self.cfg, object_manager=self.object_manager)

    def clear_non_permanent_memory(self):
        self.curr_ti = -1
        self.last_mem_ti = 0
        self.memory.clear_non_permanent_memory()

    def clear_sensory_memory(self):
        self.curr_ti = -1
        self.last_mem_ti = 0
        self.memory.clear_sensory_memory()

    def update_config(self, cfg):
        self.mem_every = cfg['mem_every']
        self.memory.update_config(cfg)

    def clear_temp_mem(self):
        self.memory.clear_work_mem()
        # self.object_manager = ObjectManager()
        self.memory.clear_obj_mem()
        # self.memory.clear_sensory_memory()

    def _add_memory(self,
                    image: torch.Tensor,
                    pix_feat: torch.Tensor,
                    prob: torch.Tensor,
                    key: torch.Tensor,
                    shrinkage: torch.Tensor,
                    selection: torch.Tensor,
                    *,
                    is_deep_update: bool = True,
                    force_permanent: bool = False) -> None:
        """
        Memorize the given segmentation in all memory stores.

        The batch dimension is 1 if flip augmentation is not used.
        image: RGB image, (1/2)*3*H*W
        pix_feat: from the key encoder, (1/2)*_*H*W
        prob: (1/2)*num_objects*H*W, in [0, 1]
        key/shrinkage/selection: for anisotropic l2, (1/2)*_*H*W
        selection can be None if not using long-term memory
        is_deep_update: whether to use deep update (e.g. with the mask encoder)
        force_permanent: whether to force the memory to be permanent
        """
        if prob.shape[1] == 0:
            # nothing to add
            log.warn('Trying to add an empty object mask to memory!')
            return

        if force_permanent:
            as_permanent = 'all'
        else:
            as_permanent = 'first'

        self.memory.initialize_sensory_if_needed(key, self.object_manager.all_obj_ids)
        msk_value, sensory, obj_value, _ = self.network.encode_mask(
            image,
            pix_feat,
            self.memory.get_sensory(self.object_manager.all_obj_ids),
            prob,
            deep_update=is_deep_update,
            chunk_size=self.chunk_size,
            need_weights=self.save_aux)
        self.memory.add_memory(key,
                               shrinkage,
                               msk_value,
                               obj_value,
                               self.object_manager.all_obj_ids,
                               selection=selection,
                               as_permanent=as_permanent)
        self.last_mem_ti = self.curr_ti
        if is_deep_update:
            self.memory.update_sensory(sensory, self.object_manager.all_obj_ids)
        self.last_msk_value = msk_value

    def _segment(self,
                 key: torch.Tensor,
                 selection: torch.Tensor,
                 pix_feat: torch.Tensor,
                 ms_features: Iterable[torch.Tensor],
                 update_sensory: bool = True) -> torch.Tensor:
        """
        Produce a segmentation using the given features and the memory

        The batch dimension is 1 if flip augmentation is not used.
        key/selection: for anisotropic l2: (1/2) * _ * H * W
        pix_feat: from the key encoder, (1/2) * _ * H * W
        ms_features: an iterable of multiscale features from the encoder, each is (1/2)*_*H*W
                     with strides 16, 8, and 4 respectively
        update_sensory: whether to update the sensory memory

        Returns: (num_objects+1)*H*W normalized probability; the first channel is the background
        """
        bs = key.shape[0]
        if self.flip_aug:
            assert bs == 2
        else:
            assert bs == 1

        if not self.memory.engaged:
            log.warn('Trying to segment without any memory!')
            return torch.zeros((1, key.shape[-2] * 16, key.shape[-1] * 16),
                               device=key.device,
                               dtype=key.dtype)

        uncert_output = None

        if self.curr_ti == 0:  # ONLY for the first frame for prediction
            memory_readout = self.memory.read_first_frame(self.last_msk_value, pix_feat, self.last_mask, self.network, uncert_output=uncert_output)
        else:
            memory_readout = self.memory.read(pix_feat, key, selection, self.last_mask, self.network, uncert_output=uncert_output, last_msk_value=self.last_msk_value, ti=self.curr_ti,
                                              last_pix_feat=self.last_pix_feat, last_pred_mask=self.last_mask)
        memory_readout = self.object_manager.realize_dict(memory_readout)

        sensory, _, pred_prob_with_bg = self.network.segment(ms_features,
                                                             memory_readout,
                                                             self.memory.get_sensory(
                                                                 self.object_manager.all_obj_ids),
                                                             chunk_size=self.chunk_size,
                                                             update_sensory=update_sensory)
        # remove batch dim
        if self.flip_aug:
            # average predictions of the non-flipped and flipped version
            pred_prob_with_bg = (pred_prob_with_bg[0] +
                                 torch.flip(pred_prob_with_bg[1], dims=[-1])) / 2
        else:
            pred_prob_with_bg = pred_prob_with_bg[0]
        if update_sensory:
            self.memory.update_sensory(sensory, self.object_manager.all_obj_ids)
        return pred_prob_with_bg

    def pred_all_flow(self, images):
        self.total_len = images.shape[0]
        images, self.pad = pad_divide_by(images, 16)
        images = images.unsqueeze(0)  # add the batch dimension: (1,t,c,h,w)

        self.flows_forward, self.flows_backward = self.network.pred_forward_backward_flow(images)

    def encode_all_images(self, images):
        images, self.pad = pad_divide_by(images, 16)
        self.image_feature_store.get_all_features(images)  # t c h w
        return images

    def step(self,
             image: torch.Tensor,
             mask: Optional[torch.Tensor] = None,
             objects: Optional[List[int]] = None,
             *,
             idx_mask: bool = False,
             end: bool = False,
             delete_buffer: bool = True,
             force_permanent: bool = False,
             matting: bool = True,
             first_frame_pred: bool = False) -> torch.Tensor:
        """
        Take a step with a new incoming image.
        If there is an incoming mask with new objects, we will memorize them.
        If there is no incoming mask, we will segment the image using the memory.
        In both cases, we will update the memory and return a segmentation.

        image: 3*H*W
        mask: H*W (if idx mask) or len(objects)*H*W or None
        objects: list of object ids that are valid in the mask Tensor.
                 The ids themselves do not need to be consecutive/in order, but they need to be
                 in the same position in the list as the corresponding mask
                 in the tensor in non-idx-mask mode.
                 objects is ignored if the mask is None.
                 If idx_mask is False and objects is None, we sequentially infer the object ids.
        idx_mask: if True, mask is expected to contain an object id at every pixel.
                  If False, mask should have multiple channels with each channel representing one object.
        end: if we are at the end of the sequence, we do not need to update memory
             if unsure just set it to False
        delete_buffer: whether to delete the image feature buffer after this step
        force_permanent: the memory recorded this frame will be added to the permanent memory
        """
        if objects is None and mask is not None:
            assert not idx_mask
            objects = list(range(1, mask.shape[0] + 1))

        # resize input if needed -- currently only used for the GUI
        resize_needed = False
        if self.max_internal_size > 0:
            h, w = image.shape[-2:]
            min_side = min(h, w)
            if min_side > self.max_internal_size:
                resize_needed = True
                new_h = int(h / min_side * self.max_internal_size)
                new_w = int(w / min_side * self.max_internal_size)
                image = F.interpolate(image.unsqueeze(0),
                                      size=(new_h, new_w),
                                      mode='bilinear',
                                      align_corners=False)[0]
                if mask is not None:
                    if idx_mask:
                        mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).float(),
                                             size=(new_h, new_w),
                                             mode='nearest-exact',
                                             align_corners=False)[0, 0].round().long()
                    else:
                        mask = F.interpolate(mask.unsqueeze(0),
                                             size=(new_h, new_w),
                                             mode='bilinear',
                                             align_corners=False)[0]

        self.curr_ti += 1

        image, self.pad = pad_divide_by(image, 16)  # done already for 3DCNN
        image = image.unsqueeze(0)  # add the batch dimension
        if self.flip_aug:
            image = torch.cat([image, torch.flip(image, dims=[-1])], dim=0)

        # whether to update the working memory
        is_mem_frame = ((self.curr_ti - self.last_mem_ti >= self.mem_every) or
                        (mask is not None)) and (not end)
        # segment when there is no input mask or when the input mask is incomplete
        need_segment = (mask is None) or (self.object_manager.num_obj > 0
                                          and not self.object_manager.has_all(objects))
        update_sensory = ((self.curr_ti - self.last_mem_ti) in self.stagger_ti) and (not end)

        # reinit if it is the first frame for prediction
        if first_frame_pred:
            self.curr_ti = 0
            self.last_mem_ti = 0
            is_mem_frame = True
            need_segment = True
            update_sensory = True

        # encoding the image
        ms_feat, pix_feat = self.image_feature_store.get_features(self.curr_ti, image)
        key, shrinkage, selection = self.image_feature_store.get_key(self.curr_ti, image)

        # segmentation from memory if needed
        if need_segment:
            pred_prob_with_bg = self._segment(key,
                                              selection,
                                              pix_feat,
                                              ms_feat,
                                              update_sensory=update_sensory)

        # use the input mask if provided
        if mask is not None:
            # inform the manager of the new objects, and get a list of temporary id
            # temporary ids -- indicates the position of objects in the tensor
            # (starts with 1 due to the background channel)
            corresponding_tmp_ids, _ = self.object_manager.add_new_objects(objects)

            mask, _ = pad_divide_by(mask, 16)
            if need_segment:
                # merge predicted mask with the incomplete input mask
                pred_prob_no_bg = pred_prob_with_bg[1:]
                # use the mutual exclusivity of segmentation
                if idx_mask:
                    pred_prob_no_bg[:, mask > 0] = 0
                else:
                    pred_prob_no_bg[:, mask.max(0) > 0.5] = 0

                new_masks = []
                for mask_id, tmp_id in enumerate(corresponding_tmp_ids):
                    if idx_mask:
                        this_mask = (mask == objects[mask_id]).type_as(pred_prob_no_bg)
                    else:
                        this_mask = mask[tmp_id]
                    if tmp_id > pred_prob_no_bg.shape[0]:
                        new_masks.append(this_mask.unsqueeze(0))
                    else:
                        # +1 for padding the background channel
                        pred_prob_no_bg[tmp_id - 1] = this_mask
                # new_masks are always in the order of tmp_id
                mask = torch.cat([pred_prob_no_bg, *new_masks], dim=0)
            elif idx_mask:
                # simply convert cls to one-hot representation
                if len(objects) == 0:
                    if delete_buffer:
                        self.image_feature_store.delete(self.curr_ti)
                    log.warn('Trying to insert an empty mask as memory!')
                    return torch.zeros((1, key.shape[-2] * 16, key.shape[-1] * 16),
                                       device=key.device,
                                       dtype=key.dtype)
                mask = torch.stack(
                    [mask == objects[mask_id] for mask_id, _ in enumerate(corresponding_tmp_ids)],
                    dim=0)
            if matting:
                mask = mask.unsqueeze(0).float() / 255.
                pred_prob_with_bg = torch.cat([1 - mask, mask], 0)
            else:
                pred_prob_with_bg = aggregate(mask, dim=0)
                pred_prob_with_bg = torch.softmax(pred_prob_with_bg, dim=0)

        self.last_mask = pred_prob_with_bg[1:].unsqueeze(0)
        if self.flip_aug:
            self.last_mask = torch.cat(
                [self.last_mask, torch.flip(self.last_mask, dims=[-1])], dim=0)
        self.last_pix_feat = pix_feat

        # save as memory if needed
        if is_mem_frame or force_permanent:
            # clear the memory for given mask and add the first predicted mask
            if first_frame_pred:
                self.clear_temp_mem()
            self._add_memory(image,
                             pix_feat,
                             self.last_mask,
                             key,
                             shrinkage,
                             selection,
                             force_permanent=force_permanent,
                             is_deep_update=True)
        else:  # compute self.last_msk_value for non-memory frame
            msk_value, _, _, _ = self.network.encode_mask(
                image,
                pix_feat,
                self.memory.get_sensory(self.object_manager.all_obj_ids),
                self.last_mask,
                deep_update=False,
                chunk_size=self.chunk_size,
                need_weights=self.save_aux)
            self.last_msk_value = msk_value

        if delete_buffer:
            self.image_feature_store.delete(self.curr_ti)

        output_prob = unpad(pred_prob_with_bg, self.pad)
        if resize_needed:
            # restore output to the original size
            output_prob = F.interpolate(output_prob.unsqueeze(0),
                                        size=(h, w),
                                        mode='bilinear',
                                        align_corners=False)[0]

        return output_prob

    def delete_objects(self, objects: List[int]) -> None:
        """
        Delete the given objects from the memory.
        """
        self.object_manager.delete_objects(objects)
        self.memory.purge_except(self.object_manager.all_obj_ids)

    def output_prob_to_mask(self, output_prob: torch.Tensor, matting: bool = True) -> torch.Tensor:
        if matting:
            new_mask = output_prob[1:].squeeze(0)
        else:
            mask = torch.argmax(output_prob, dim=0)

            # index in tensor != object id -- remap the ids here
            new_mask = torch.zeros_like(mask)
            for tmp_id, obj in self.object_manager.tmp_id_to_obj.items():
                new_mask[mask == tmp_id] = obj.id

        return new_mask

    @torch.inference_mode()
    @torch.amp.autocast("cuda")
    def process_video(
        self,
        input_path: str,
        mask_path: str,
        output_path: str = None,
        n_warmup: int = 10,
        r_erode: int = 10,
        r_dilate: int = 10,
        suffix: str = "",
        save_image: bool = False,
        max_size: int = -1,
    ) -> Tuple:
        """
        Process a video for object segmentation and matting.

        This method processes a video file by performing object segmentation and matting on
        each frame. It supports warmup frames, mask erosion/dilation, and various output options.

        Args:
            input_path (str): Path to the input video file
            mask_path (str): Path to the mask image file used for initial segmentation
            output_path (str, optional): Directory path where output files will be saved. Defaults to a temporary directory
            n_warmup (int, optional): Number of warmup frames to use. Defaults to 10
            r_erode (int, optional): Erosion radius for mask processing. Defaults to 10
            r_dilate (int, optional): Dilation radius for mask processing. Defaults to 10
            suffix (str, optional): Suffix to append to output filename. Defaults to ""
            save_image (bool, optional): Whether to save individual frames. Defaults to False
            max_size (int, optional): Maximum size for frame dimension. Use -1 for no limit. Defaults to -1

        Returns:
            Tuple[str, str]: A tuple containing:
                - Path to the output foreground video file (str)
                - Path to the output alpha matte video file (str)

        Output:
            - Saves processed video files with foreground (_fgr) and alpha matte (_pha)
            - If save_image=True, saves individual frames in separate directories
        """
        output_path = output_path if output_path is not None else tempfile.TemporaryDirectory().name
        r_erode = int(r_erode)
        r_dilate = int(r_dilate)
        n_warmup = int(n_warmup)
        max_size = int(max_size)

        vframes, fps, length, video_name = read_frame_from_videos(input_path)
        repeated_frames = vframes[0].unsqueeze(0).repeat(n_warmup, 1, 1, 1)
        vframes = torch.cat([repeated_frames, vframes], dim=0).float()
        length += n_warmup

        new_h, new_w = vframes.shape[-2:]
        if max_size > 0:
            h, w = new_h, new_w
            min_side = min(h, w)
            if min_side > max_size:
                new_h = int(h / min_side * max_size)
                new_w = int(w / min_side * max_size)
                vframes = F.interpolate(vframes, size=(new_h, new_w), mode="area")

        os.makedirs(output_path, exist_ok=True)
        if suffix:
            video_name = f"{video_name}_{suffix}"
        if save_image:
            os.makedirs(f"{output_path}/{video_name}", exist_ok=True)
            os.makedirs(f"{output_path}/{video_name}/pha", exist_ok=True)
            os.makedirs(f"{output_path}/{video_name}/fgr", exist_ok=True)

        mask = np.array(Image.open(mask_path).convert("L"))
        if r_dilate > 0:
            mask = gen_dilate(mask, r_dilate, r_dilate)
        if r_erode > 0:
            mask = gen_erosion(mask, r_erode, r_erode)

        mask = torch.from_numpy(mask).cuda()
        if max_size > 0:
            mask = F.interpolate(
                mask.unsqueeze(0).unsqueeze(0), size=(new_h, new_w), mode="nearest"
            )[0, 0]

        bgr = (np.array([120, 255, 155], dtype=np.float32) / 255).reshape((1, 1, 3))
        objects = [1]

        phas = []
        fgrs = []
        for ti in tqdm(range(length)):
            image = vframes[ti]
            image_np = np.array(image.permute(1, 2, 0))
            image = (image / 255.0).cuda().float()

            if ti == 0:
                output_prob = self.step(image, mask, objects=objects)
                output_prob = self.step(image, first_frame_pred=True)
            else:
                if ti <= n_warmup:
                    output_prob = self.step(image, first_frame_pred=True)
                else:
                    output_prob = self.step(image)

            mask = self.output_prob_to_mask(output_prob)
            pha = mask.unsqueeze(2).cpu().numpy()
            com_np = image_np / 255.0 * pha + bgr * (1 - pha)

            if ti > (n_warmup - 1):
                com_np = (com_np * 255).astype(np.uint8)
                pha = (pha * 255).astype(np.uint8)
                fgrs.append(com_np)
                phas.append(pha)
                if save_image:
                    cv2.imwrite(
                        f"{output_path}/{video_name}/pha/{str(ti - n_warmup).zfill(5)}.png",
                        pha,
                    )
                    cv2.imwrite(
                        f"{output_path}/{video_name}/fgr/{str(ti - n_warmup).zfill(5)}.png",
                        com_np[..., [2, 1, 0]],
                    )

        fgrs = np.array(fgrs)
        phas = np.array(phas)

        fgr_filename = f"{output_path}/{video_name}_fgr.mp4"
        alpha_filename = f"{output_path}/{video_name}_pha.mp4"

        imageio.mimwrite(fgr_filename, fgrs, fps=fps, quality=7)
        imageio.mimwrite(alpha_filename, phas, fps=fps, quality=7)

        return (fgr_filename, alpha_filename)
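A hedged end-to-end sketch of the class above; the checkpoint identifier and file paths are placeholders, not values taken from this commit:

# Hedged sketch: a string network triggers MatAnyone.from_pretrained internally.
core = InferenceCore("path-or-hub-id-of-a-matanyone-checkpoint")
fgr_path, pha_path = core.process_video(
    input_path="inputs/clip.mp4",              # placeholder video
    mask_path="inputs/first_frame_mask.png",   # single-channel mask of the target object
    output_path="results",
    n_warmup=10,  # the first frame is repeated 10x to warm up the memory
)
print(fgr_path, pha_path)  # results/clip_fgr.mp4 results/clip_pha.mp4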
hf_space/third_party/matanyone/inference/kv_memory_store.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List, Optional, Literal
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def _add_last_dim(dictionary, key, new_value, prepend=False):
|
| 7 |
+
# append/prepend a new value to the last dimension of a tensor in a dictionary
|
| 8 |
+
# if the key does not exist, put the new value in
|
| 9 |
+
# append by default
|
| 10 |
+
if key in dictionary:
|
| 11 |
+
dictionary[key] = torch.cat([dictionary[key], new_value], -1)
|
| 12 |
+
else:
|
| 13 |
+
dictionary[key] = new_value
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class KeyValueMemoryStore:
|
| 17 |
+
"""
|
| 18 |
+
Works for key/value pairs type storage
|
| 19 |
+
e.g., working and long-term memory
|
| 20 |
+
"""
|
| 21 |
+
def __init__(self, save_selection: bool = False, save_usage: bool = False):
|
| 22 |
+
"""
|
| 23 |
+
We store keys and values of objects that first appear in the same frame in a bucket.
|
| 24 |
+
Each bucket contains a set of object ids.
|
| 25 |
+
Each bucket is associated with a single key tensor
|
| 26 |
+
and a dictionary of value tensors indexed by object id.
|
| 27 |
+
|
| 28 |
+
The keys and values are stored as the concatenation of a permanent part and a temporary part.
|
| 29 |
+
"""
|
| 30 |
+
self.save_selection = save_selection
|
| 31 |
+
self.save_usage = save_usage
|
| 32 |
+
|
| 33 |
+
self.global_bucket_id = 0 # does not reduce even if buckets are removed
|
| 34 |
+
self.buckets: Dict[int, List[int]] = {} # indexed by bucket id
|
| 35 |
+
self.k: Dict[int, torch.Tensor] = {} # indexed by bucket id
|
| 36 |
+
self.v: Dict[int, torch.Tensor] = {} # indexed by object id
|
| 37 |
+
|
| 38 |
+
# indexed by bucket id; the end point of permanent memory
|
| 39 |
+
self.perm_end_pt: Dict[int, int] = defaultdict(int)
|
| 40 |
+
|
| 41 |
+
# shrinkage and selection are just like the keys
|
| 42 |
+
self.s = {}
|
| 43 |
+
if self.save_selection:
|
| 44 |
+
self.e = {} # does not contain the permanent memory part
|
| 45 |
+
|
| 46 |
+
# usage
|
| 47 |
+
if self.save_usage:
|
| 48 |
+
self.use_cnt = {} # indexed by bucket id, does not contain the permanent memory part
|
| 49 |
+
self.life_cnt = {} # indexed by bucket id, does not contain the permanent memory part
|
| 50 |
+
|
| 51 |
+
def add(self,
|
| 52 |
+
key: torch.Tensor,
|
| 53 |
+
values: Dict[int, torch.Tensor],
|
| 54 |
+
shrinkage: torch.Tensor,
|
| 55 |
+
selection: torch.Tensor,
|
| 56 |
+
supposed_bucket_id: int = -1,
|
| 57 |
+
as_permanent: Literal['no', 'first', 'all'] = 'no') -> None:
|
| 58 |
+
"""
|
| 59 |
+
key: (1/2)*C*N
|
| 60 |
+
values: dict of values ((1/2)*C*N), object ids are used as keys
|
| 61 |
+
shrinkage: (1/2)*1*N
|
| 62 |
+
selection: (1/2)*C*N
|
| 63 |
+
|
| 64 |
+
supposed_bucket_id: used to sync the bucket id between working and long-term memory
|
| 65 |
+
if provided, the input should all be in a single bucket indexed by this id
|
| 66 |
+
as_permanent: whether to store the input as permanent memory
|
| 67 |
+
'no': don't
|
| 68 |
+
'first': only store it as permanent memory if the bucket is empty
|
| 69 |
+
'all': always store it as permanent memory
|
| 70 |
+
"""
|
| 71 |
+
bs = key.shape[0]
|
| 72 |
+
ne = key.shape[-1]
|
| 73 |
+
assert len(key.shape) == 3
|
| 74 |
+
assert len(shrinkage.shape) == 3
|
| 75 |
+
assert not self.save_selection or len(selection.shape) == 3
|
| 76 |
+
assert as_permanent in ['no', 'first', 'all']
|
| 77 |
+
|
| 78 |
+
# add the value and create new buckets if necessary
|
| 79 |
+
if supposed_bucket_id >= 0:
|
| 80 |
+
enabled_buckets = [supposed_bucket_id]
|
| 81 |
+
bucket_exist = supposed_bucket_id in self.buckets
|
| 82 |
+
for obj, value in values.items():
|
| 83 |
+
if bucket_exist:
|
| 84 |
+
assert obj in self.v
|
| 85 |
+
assert obj in self.buckets[supposed_bucket_id]
|
| 86 |
+
_add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all'))
|
| 87 |
+
else:
|
| 88 |
+
assert obj not in self.v
|
| 89 |
+
self.v[obj] = value
|
| 90 |
+
self.buckets[supposed_bucket_id] = list(values.keys())
|
| 91 |
+
else:
|
| 92 |
+
new_bucket_id = None
|
| 93 |
+
enabled_buckets = set()
|
| 94 |
+
for obj, value in values.items():
|
| 95 |
+
assert len(value.shape) == 3
|
| 96 |
+
if obj in self.v:
|
| 97 |
+
_add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all'))
|
| 98 |
+
bucket_used = [
|
| 99 |
+
bucket_id for bucket_id, object_ids in self.buckets.items()
|
| 100 |
+
if obj in object_ids
|
| 101 |
+
]
|
| 102 |
+
assert len(bucket_used) == 1 # each object should only be in one bucket
|
| 103 |
+
enabled_buckets.add(bucket_used[0])
|
| 104 |
+
else:
|
| 105 |
+
self.v[obj] = value
|
| 106 |
+
if new_bucket_id is None:
|
| 107 |
+
# create new bucket
|
| 108 |
+
new_bucket_id = self.global_bucket_id
|
| 109 |
+
self.global_bucket_id += 1
|
| 110 |
+
self.buckets[new_bucket_id] = []
|
| 111 |
+
# put the new object into the corresponding bucket
|
| 112 |
+
self.buckets[new_bucket_id].append(obj)
|
| 113 |
+
enabled_buckets.add(new_bucket_id)
|
| 114 |
+
|
| 115 |
+
# increment the permanent size if necessary
|
| 116 |
+
add_as_permanent = {} # indexed by bucket id
|
| 117 |
+
for bucket_id in enabled_buckets:
|
| 118 |
+
add_as_permanent[bucket_id] = False
|
| 119 |
+
if as_permanent == 'all':
|
| 120 |
+
self.perm_end_pt[bucket_id] += ne
|
| 121 |
+
add_as_permanent[bucket_id] = True
|
| 122 |
+
elif as_permanent == 'first':
|
| 123 |
+
if self.perm_end_pt[bucket_id] == 0:
|
| 124 |
+
self.perm_end_pt[bucket_id] = ne
|
| 125 |
+
add_as_permanent[bucket_id] = True
|
| 126 |
+
|
| 127 |
+
# create new counters for usage if necessary
|
| 128 |
+
if self.save_usage and as_permanent != 'all':
|
| 129 |
+
new_count = torch.zeros((bs, ne), device=key.device, dtype=torch.float32)
|
| 130 |
+
new_life = torch.zeros((bs, ne), device=key.device, dtype=torch.float32) + 1e-7
|
| 131 |
+
|
| 132 |
+
# add the key to every bucket
|
| 133 |
+
for bucket_id in self.buckets:
|
| 134 |
+
if bucket_id not in enabled_buckets:
|
| 135 |
+
# if we are not adding new values to a bucket, we should skip it
|
| 136 |
+
continue
|
| 137 |
+
|
| 138 |
+
_add_last_dim(self.k, bucket_id, key, prepend=add_as_permanent[bucket_id])
|
| 139 |
+
_add_last_dim(self.s, bucket_id, shrinkage, prepend=add_as_permanent[bucket_id])
|
| 140 |
+
if not add_as_permanent[bucket_id]:
|
| 141 |
+
if self.save_selection:
|
| 142 |
+
_add_last_dim(self.e, bucket_id, selection)
|
| 143 |
+
if self.save_usage:
|
| 144 |
+
_add_last_dim(self.use_cnt, bucket_id, new_count)
|
| 145 |
+
_add_last_dim(self.life_cnt, bucket_id, new_life)
|
| 146 |
+
|
| 147 |
+
def update_bucket_usage(self, bucket_id: int, usage: torch.Tensor) -> None:
|
| 148 |
+
# increase all life count by 1
|
| 149 |
+
# increase use of indexed elements
|
| 150 |
+
if not self.save_usage:
|
| 151 |
+
return
|
| 152 |
+
|
| 153 |
+
usage = usage[:, self.perm_end_pt[bucket_id]:]
|
| 154 |
+
if usage.shape[-1] == 0:
|
| 155 |
+
# if there is no temporary memory, we don't need to update
|
| 156 |
+
return
|
| 157 |
+
self.use_cnt[bucket_id] += usage.view_as(self.use_cnt[bucket_id])
|
| 158 |
+
self.life_cnt[bucket_id] += 1
|
| 159 |
+
|
| 160 |
+
def sieve_by_range(self, bucket_id: int, start: int, end: int, min_size: int) -> None:
|
| 161 |
+
# keep only the temporary elements *outside* of this range (with some boundary conditions)
|
| 162 |
+
# the permanent elements are ignored in this computation
|
| 163 |
+
# i.e., concat (a[:start], a[end:])
|
| 164 |
+
# bucket with size <= min_size are not modified
|
| 165 |
+
|
| 166 |
+
assert start >= 0
|
| 167 |
+
assert end <= 0
|
| 168 |
+
|
| 169 |
+
object_ids = self.buckets[bucket_id]
|
| 170 |
+
bucket_num_elements = self.k[bucket_id].shape[-1] - self.perm_end_pt[bucket_id]
|
| 171 |
+
if bucket_num_elements <= min_size:
|
| 172 |
+
return
|
| 173 |
+
|
| 174 |
+
if end == 0:
|
| 175 |
+
# negative 0 would not work as the end index!
|
| 176 |
+
# effectively make the second part an empty slice
|
| 177 |
+
end = self.k[bucket_id].shape[-1] + 1
|
| 178 |
+
|
| 179 |
+
p_size = self.perm_end_pt[bucket_id]
|
| 180 |
+
start = start + p_size
|
| 181 |
+
|
| 182 |
+
k = self.k[bucket_id]
|
| 183 |
+
s = self.s[bucket_id]
|
| 184 |
+
if self.save_selection:
|
| 185 |
+
e = self.e[bucket_id]
|
| 186 |
+
if self.save_usage:
|
| 187 |
+
use_cnt = self.use_cnt[bucket_id]
|
| 188 |
+
life_cnt = self.life_cnt[bucket_id]
|
| 189 |
+
|
| 190 |
+
self.k[bucket_id] = torch.cat([k[:, :, :start], k[:, :, end:]], -1)
|
| 191 |
+
self.s[bucket_id] = torch.cat([s[:, :, :start], s[:, :, end:]], -1)
|
| 192 |
+
if self.save_selection:
|
| 193 |
+
self.e[bucket_id] = torch.cat([e[:, :, :start - p_size], e[:, :, end:]], -1)
|
| 194 |
+
if self.save_usage:
|
| 195 |
+
self.use_cnt[bucket_id] = torch.cat([use_cnt[:, :start - p_size], use_cnt[:, end:]], -1)
|
| 196 |
+
self.life_cnt[bucket_id] = torch.cat([life_cnt[:, :start - p_size], life_cnt[:, end:]],
|
| 197 |
+
-1)
|
| 198 |
+
for obj_id in object_ids:
|
| 199 |
+
v = self.v[obj_id]
|
| 200 |
+
self.v[obj_id] = torch.cat([v[:, :, :start], v[:, :, end:]], -1)
|
| 201 |
+
|
| 202 |
+
def remove_old_memory(self, bucket_id: int, max_len: int) -> None:
|
| 203 |
+
        self.sieve_by_range(bucket_id, 0, -max_len, max_len)

    def remove_obsolete_features(self, bucket_id: int, max_size: int) -> None:
        # for long-term memory only
        object_ids = self.buckets[bucket_id]

        assert self.perm_end_pt[bucket_id] == 0  # permanent memory should be empty in LT memory

        # normalize with life duration
        usage = self.get_usage(bucket_id)
        bs = usage.shape[0]

        survivals = []

        for bi in range(bs):
            _, survived = torch.topk(usage[bi], k=max_size)
            survivals.append(survived.flatten())
            assert survived.shape[-1] == survivals[0].shape[-1]

        self.k[bucket_id] = torch.stack(
            [self.k[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
        self.s[bucket_id] = torch.stack(
            [self.s[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)

        if self.save_selection:
            # Long-term memory does not store selection, so this should not be needed
            self.e[bucket_id] = torch.stack(
                [self.e[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
        for obj_id in object_ids:
            self.v[obj_id] = torch.stack(
                [self.v[obj_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)

        self.use_cnt[bucket_id] = torch.stack(
            [self.use_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0)
        self.life_cnt[bucket_id] = torch.stack(
            [self.life_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0)

    def get_usage(self, bucket_id: int) -> torch.Tensor:
        # return normalized usage
        if not self.save_usage:
            raise RuntimeError('I did not count usage!')
        else:
            usage = self.use_cnt[bucket_id] / self.life_cnt[bucket_id]
            return usage

    def get_all_sliced(
            self, bucket_id: int, start: int, end: int
    ) -> (torch.Tensor, torch.Tensor, torch.Tensor, Dict[int, torch.Tensor], torch.Tensor):
        # return k, sk, ek, value, normalized usage in order, sliced by start and end
        # this only queries the temporary memory

        assert start >= 0
        assert end <= 0

        p_size = self.perm_end_pt[bucket_id]
        start = start + p_size

        if end == 0:
            # negative 0 would not work as the end index!
            k = self.k[bucket_id][:, :, start:]
            sk = self.s[bucket_id][:, :, start:]
            ek = self.e[bucket_id][:, :, start - p_size:] if self.save_selection else None
            value = {obj_id: self.v[obj_id][:, :, start:] for obj_id in self.buckets[bucket_id]}
            usage = self.get_usage(bucket_id)[:, start - p_size:] if self.save_usage else None
        else:
            k = self.k[bucket_id][:, :, start:end]
            sk = self.s[bucket_id][:, :, start:end]
            ek = self.e[bucket_id][:, :, start - p_size:end] if self.save_selection else None
            value = {obj_id: self.v[obj_id][:, :, start:end] for obj_id in self.buckets[bucket_id]}
            usage = self.get_usage(bucket_id)[:, start - p_size:end] if self.save_usage else None

        return k, sk, ek, value, usage

    def purge_except(self, obj_keep_idx: List[int]):
        # purge all objects from the memory except the ones listed
        obj_keep_idx = set(obj_keep_idx)

        # remove objects that are not in the keep list from the buckets
        buckets_to_remove = []
        for bucket_id, object_ids in self.buckets.items():
            self.buckets[bucket_id] = [obj_id for obj_id in object_ids if obj_id in obj_keep_idx]
            if len(self.buckets[bucket_id]) == 0:
                buckets_to_remove.append(bucket_id)

        # remove object values that are not in the keep list
        self.v = {k: v for k, v in self.v.items() if k in obj_keep_idx}

        # remove buckets that are empty
        for bucket_id in buckets_to_remove:
            del self.buckets[bucket_id]
            del self.k[bucket_id]
            del self.s[bucket_id]
            if self.save_selection:
                del self.e[bucket_id]
            if self.save_usage:
                del self.use_cnt[bucket_id]
                del self.life_cnt[bucket_id]

    def clear_non_permanent_memory(self):
        # clear all non-permanent memory
        for bucket_id in self.buckets:
            self.sieve_by_range(bucket_id, 0, 0, 0)

    def get_v_size(self, obj_id: int) -> int:
        return self.v[obj_id].shape[-1]

    def size(self, bucket_id: int) -> int:
        if bucket_id not in self.k:
            return 0
        else:
            return self.k[bucket_id].shape[-1]

    def perm_size(self, bucket_id: int) -> int:
        return self.perm_end_pt[bucket_id]

    def non_perm_size(self, bucket_id: int) -> int:
        return self.size(bucket_id) - self.perm_size(bucket_id)

    def engaged(self, bucket_id: Optional[int] = None) -> bool:
        if bucket_id is None:
            return len(self.buckets) > 0
        else:
            return bucket_id in self.buckets

    @property
    def num_objects(self) -> int:
        return len(self.v)

    @property
    def key(self) -> Dict[int, torch.Tensor]:
        return self.k

    @property
    def value(self) -> Dict[int, torch.Tensor]:
        return self.v

    @property
    def shrinkage(self) -> Dict[int, torch.Tensor]:
        return self.s

    @property
    def selection(self) -> Dict[int, torch.Tensor]:
        return self.e

    def __contains__(self, key):
        return key in self.v
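A note on the slicing convention in get_all_sliced above: start must be non-negative, end non-positive, and end == 0 needs its own branch because a Python slice cannot express "up to the end" with a literal 0. A minimal standalone sketch (plain PyTorch, toy shapes; not part of the uploaded files):

    import torch

    t = torch.arange(10).view(1, 1, 10)  # stands in for a key tensor with N=10 memory tokens
    start, end = 2, -3

    print(t[:, :, start:end])        # tokens 2..6, what the else-branch slices
    print(t[:, :, start:0].numel())  # 0: a literal end of 0 always yields an empty slice
    print(t[:, :, start:])           # tokens 2..9, what the end == 0 branch does instead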
hf_space/third_party/matanyone/inference/memory_manager.py
ADDED
@@ -0,0 +1,453 @@
import logging
from omegaconf import DictConfig
from typing import List, Dict
import torch

from matanyone.inference.object_manager import ObjectManager
from matanyone.inference.kv_memory_store import KeyValueMemoryStore
from matanyone.model.matanyone import MatAnyone
from matanyone.model.utils.memory_utils import get_similarity, do_softmax

log = logging.getLogger()


class MemoryManager:
    """
    Manages all three memory stores and the transition between working/long-term memory
    """
    def __init__(self, cfg: DictConfig, object_manager: ObjectManager):
        self.object_manager = object_manager
        self.sensory_dim = cfg.model.sensory_dim
        self.top_k = cfg.top_k
        self.chunk_size = cfg.chunk_size

        self.save_aux = cfg.save_aux

        self.use_long_term = cfg.use_long_term
        self.count_long_term_usage = cfg.long_term.count_usage
        # subtract 1 because the first frame is now counted as "permanent memory"
        # and is not counted towards max_mem_frames,
        # but we want to keep the hyperparameters consistent with before for the same behavior
        if self.use_long_term:
            self.max_mem_frames = cfg.long_term.max_mem_frames - 1
            self.min_mem_frames = cfg.long_term.min_mem_frames - 1
            self.num_prototypes = cfg.long_term.num_prototypes
            self.max_long_tokens = cfg.long_term.max_num_tokens
            self.buffer_tokens = cfg.long_term.buffer_tokens
        else:
            self.max_mem_frames = cfg.max_mem_frames - 1

        # dimensions will be inferred from input later
        self.CK = self.CV = None
        self.H = self.W = None

        # The sensory memory is stored as a dictionary indexed by object ids,
        # each of shape bs * C^h * H * W
        self.sensory = {}

        # a dictionary indexed by object ids, each of shape bs * T * Q * C
        self.obj_v = {}

        self.work_mem = KeyValueMemoryStore(save_selection=self.use_long_term,
                                            save_usage=self.use_long_term)
        if self.use_long_term:
            self.long_mem = KeyValueMemoryStore(save_usage=self.count_long_term_usage)

        self.config_stale = True
        self.engaged = False

    def update_config(self, cfg: DictConfig) -> None:
        self.config_stale = True
        self.top_k = cfg['top_k']

        assert self.use_long_term == cfg.use_long_term, 'cannot update this'
        assert self.count_long_term_usage == cfg.long_term.count_usage, 'cannot update this'

        self.use_long_term = cfg.use_long_term
        self.count_long_term_usage = cfg.long_term.count_usage
        if self.use_long_term:
            self.max_mem_frames = cfg.long_term.max_mem_frames - 1
            self.min_mem_frames = cfg.long_term.min_mem_frames - 1
            self.num_prototypes = cfg.long_term.num_prototypes
            self.max_long_tokens = cfg.long_term.max_num_tokens
            self.buffer_tokens = cfg.long_term.buffer_tokens
        else:
            self.max_mem_frames = cfg.max_mem_frames - 1

    def _readout(self, affinity, v, uncert_mask=None) -> torch.Tensor:
        # affinity: bs*N*HW
        # v: bs*C*N or bs*num_objects*C*N
        # returns bs*C*HW or bs*num_objects*C*HW
        if len(v.shape) == 3:
            # single object
            if uncert_mask is not None:
                return v @ affinity * uncert_mask
            else:
                return v @ affinity
        else:
            bs, num_objects, C, N = v.shape
            v = v.view(bs, num_objects * C, N)
            out = v @ affinity
            if uncert_mask is not None:
                uncert_mask = uncert_mask.flatten(start_dim=2).expand(-1, C, -1)
                out = out * uncert_mask
            return out.view(bs, num_objects, C, -1)

    def _get_mask_by_ids(self, mask: torch.Tensor, obj_ids: List[int]) -> torch.Tensor:
        # -1 because the mask does not contain the background channel
        return mask[:, [self.object_manager.find_tmp_by_id(obj) - 1 for obj in obj_ids]]

    def _get_sensory_by_ids(self, obj_ids: List[int]) -> torch.Tensor:
        return torch.stack([self.sensory[obj] for obj in obj_ids], dim=1)

    def _get_object_mem_by_ids(self, obj_ids: List[int]) -> torch.Tensor:
        return torch.stack([self.obj_v[obj] for obj in obj_ids], dim=1)

    def _get_visual_values_by_ids(self, obj_ids: List[int]) -> torch.Tensor:
        # All the values that the object ids refer to should have the same shape
        value = torch.stack([self.work_mem.value[obj] for obj in obj_ids], dim=1)
        if self.use_long_term and obj_ids[0] in self.long_mem.value:
            lt_value = torch.stack([self.long_mem.value[obj] for obj in obj_ids], dim=1)
            value = torch.cat([lt_value, value], dim=-1)

        return value

    def read_first_frame(self, last_msk_value, pix_feat: torch.Tensor,
                         last_mask: torch.Tensor, network: MatAnyone,
                         uncert_output=None) -> Dict[int, torch.Tensor]:
        """
        Read from all memory stores and return a single memory readout tensor for each object

        pix_feat: (1/2) x C x H x W
        query_key: (1/2) x C^k x H x W
        selection: (1/2) x C^k x H x W
        last_mask: (1/2) x num_objects x H x W (at stride 16)
        return a dict of memory readouts, indexed by object indices. Each readout is C*H*W
        """
        h, w = pix_feat.shape[-2:]
        bs = pix_feat.shape[0]
        assert last_mask.shape[0] == bs

        """
        Compute affinity and perform readout
        """
        all_readout_mem = {}
        buckets = self.work_mem.buckets
        for bucket_id, bucket in buckets.items():

            if self.chunk_size < 1:
                object_chunks = [bucket]
            else:
                object_chunks = [
                    bucket[i:i + self.chunk_size] for i in range(0, len(bucket), self.chunk_size)
                ]

            for objects in object_chunks:
                this_sensory = self._get_sensory_by_ids(objects)
                this_last_mask = self._get_mask_by_ids(last_mask, objects)
                this_msk_value = self._get_visual_values_by_ids(objects)  # (1/2)*num_objects*C*N
                pixel_readout = network.pixel_fusion(pix_feat, last_msk_value, this_sensory,
                                                     this_last_mask)
                this_obj_mem = self._get_object_mem_by_ids(objects).unsqueeze(2)
                readout_memory, aux_features = network.readout_query(pixel_readout, this_obj_mem)
                for i, obj in enumerate(objects):
                    all_readout_mem[obj] = readout_memory[:, i]

                if self.save_aux:
                    aux_output = {
                        # 'sensory': this_sensory,
                        # 'pixel_readout': pixel_readout,
                        'q_logits': aux_features['logits'] if aux_features else None,
                        # 'q_weights': aux_features['q_weights'] if aux_features else None,
                        # 'p_weights': aux_features['p_weights'] if aux_features else None,
                        # 'attn_mask': aux_features['attn_mask'].float() if aux_features else None,
                    }
                    self.aux = aux_output

        return all_readout_mem

    def read(self, pix_feat: torch.Tensor, query_key: torch.Tensor, selection: torch.Tensor,
             last_mask: torch.Tensor, network: MatAnyone, uncert_output=None, last_msk_value=None,
             ti=None, last_pix_feat=None, last_pred_mask=None) -> Dict[int, torch.Tensor]:
        """
        Read from all memory stores and return a single memory readout tensor for each object

        pix_feat: (1/2) x C x H x W
        query_key: (1/2) x C^k x H x W
        selection: (1/2) x C^k x H x W
        last_mask: (1/2) x num_objects x H x W (at stride 16)
        return a dict of memory readouts, indexed by object indices. Each readout is C*H*W
        """
        h, w = pix_feat.shape[-2:]
        bs = pix_feat.shape[0]
        assert query_key.shape[0] == bs
        assert selection.shape[0] == bs
        assert last_mask.shape[0] == bs

        uncert_mask = uncert_output["mask"] if uncert_output is not None else None

        query_key = query_key.flatten(start_dim=2)  # bs*C^k*HW
        selection = selection.flatten(start_dim=2)  # bs*C^k*HW
        """
        Compute affinity and perform readout
        """
        all_readout_mem = {}
        buckets = self.work_mem.buckets
        for bucket_id, bucket in buckets.items():
            if self.use_long_term and self.long_mem.engaged(bucket_id):
                # Use long-term memory
                long_mem_size = self.long_mem.size(bucket_id)
                memory_key = torch.cat([self.long_mem.key[bucket_id], self.work_mem.key[bucket_id]],
                                       -1)
                shrinkage = torch.cat(
                    [self.long_mem.shrinkage[bucket_id], self.work_mem.shrinkage[bucket_id]], -1)

                similarity = get_similarity(memory_key, shrinkage, query_key, selection)
                affinity, usage = do_softmax(similarity,
                                             top_k=self.top_k,
                                             inplace=True,
                                             return_usage=True)
                """
                Record memory usage for working and long-term memory
                """
                # ignore the index return for long-term memory
                work_usage = usage[:, long_mem_size:]
                self.work_mem.update_bucket_usage(bucket_id, work_usage)

                if self.count_long_term_usage:
                    # ignore the index return for working memory
                    long_usage = usage[:, :long_mem_size]
                    self.long_mem.update_bucket_usage(bucket_id, long_usage)
            else:
                # no long-term memory
                memory_key = self.work_mem.key[bucket_id]
                shrinkage = self.work_mem.shrinkage[bucket_id]
                similarity = get_similarity(memory_key, shrinkage, query_key, selection,
                                            uncert_mask=uncert_mask)

                if self.use_long_term:
                    affinity, usage = do_softmax(similarity,
                                                 top_k=self.top_k,
                                                 inplace=True,
                                                 return_usage=True)
                    self.work_mem.update_bucket_usage(bucket_id, usage)
                else:
                    affinity = do_softmax(similarity, top_k=self.top_k, inplace=True)

            if self.chunk_size < 1:
                object_chunks = [bucket]
            else:
                object_chunks = [
                    bucket[i:i + self.chunk_size] for i in range(0, len(bucket), self.chunk_size)
                ]

            for objects in object_chunks:
                this_sensory = self._get_sensory_by_ids(objects)
                this_last_mask = self._get_mask_by_ids(last_mask, objects)
                this_msk_value = self._get_visual_values_by_ids(objects)  # (1/2)*num_objects*C*N
                visual_readout = self._readout(affinity, this_msk_value,
                                               uncert_mask).view(bs, len(objects), self.CV, h, w)

                uncert_output = network.pred_uncertainty(
                    last_pix_feat, pix_feat, last_pred_mask,
                    visual_readout[:, 0] - last_msk_value[:, 0])

                if uncert_output is not None:
                    uncert_prob = uncert_output["prob"].unsqueeze(1)  # b n 1 h w
                    visual_readout = visual_readout * uncert_prob + last_msk_value * (1 - uncert_prob)

                pixel_readout = network.pixel_fusion(pix_feat, visual_readout, this_sensory,
                                                     this_last_mask)
                this_obj_mem = self._get_object_mem_by_ids(objects).unsqueeze(2)
                readout_memory, aux_features = network.readout_query(pixel_readout, this_obj_mem)
                for i, obj in enumerate(objects):
                    all_readout_mem[obj] = readout_memory[:, i]

                if self.save_aux:
                    aux_output = {
                        # 'sensory': this_sensory,
                        # 'pixel_readout': pixel_readout,
                        'q_logits': aux_features['logits'] if aux_features else None,
                        # 'q_weights': aux_features['q_weights'] if aux_features else None,
                        # 'p_weights': aux_features['p_weights'] if aux_features else None,
                        # 'attn_mask': aux_features['attn_mask'].float() if aux_features else None,
                    }
                    self.aux = aux_output

        return all_readout_mem

    def add_memory(self,
                   key: torch.Tensor,
                   shrinkage: torch.Tensor,
                   msk_value: torch.Tensor,
                   obj_value: torch.Tensor,
                   objects: List[int],
                   selection: torch.Tensor = None,
                   *,
                   as_permanent: bool = False) -> None:
        # key: (1/2)*C*H*W
        # msk_value: (1/2)*num_objects*C*H*W
        # obj_value: (1/2)*num_objects*Q*C
        # objects contains a list of object ids corresponding to the objects in msk_value/obj_value
        bs = key.shape[0]
        assert shrinkage.shape[0] == bs
        assert msk_value.shape[0] == bs
        assert obj_value.shape[0] == bs

        self.engaged = True
        if self.H is None or self.config_stale:
            self.config_stale = False
            self.H, self.W = msk_value.shape[-2:]
            self.HW = self.H * self.W
            # convert from num. frames to num. tokens
            self.max_work_tokens = self.max_mem_frames * self.HW
            if self.use_long_term:
                self.min_work_tokens = self.min_mem_frames * self.HW

        # key: bs*C*N
        # value: bs*num_objects*C*N
        key = key.flatten(start_dim=2)
        shrinkage = shrinkage.flatten(start_dim=2)
        self.CK = key.shape[1]

        msk_value = msk_value.flatten(start_dim=3)
        self.CV = msk_value.shape[2]

        if selection is not None:
            # not used in non-long-term mode
            selection = selection.flatten(start_dim=2)

        # insert object values into object memory
        for obj_id, obj in enumerate(objects):
            if obj in self.obj_v:
                """streaming average
                each self.obj_v[obj] is (1/2)*num_summaries*(embed_dim+1)
                first embed_dim keeps track of the sum of embeddings
                the last dim keeps the total count
                averaging is done inside the object transformer

                incoming obj_value is (1/2)*num_objects*num_summaries*(embed_dim+1)
                self.obj_v[obj] = torch.cat([self.obj_v[obj], obj_value[:, obj_id]], dim=0)
                """
                last_acc = self.obj_v[obj][:, :, -1]
                new_acc = last_acc + obj_value[:, obj_id, :, -1]

                self.obj_v[obj][:, :, :-1] = (self.obj_v[obj][:, :, :-1] +
                                              obj_value[:, obj_id, :, :-1])
                self.obj_v[obj][:, :, -1] = new_acc
            else:
                self.obj_v[obj] = obj_value[:, obj_id]

        # convert mask value tensor into a dict for insertion
        msk_values = {obj: msk_value[:, obj_id] for obj_id, obj in enumerate(objects)}
        self.work_mem.add(key,
                          msk_values,
                          shrinkage,
                          selection=selection,
                          as_permanent=as_permanent)

        for bucket_id in self.work_mem.buckets.keys():
            # long-term memory cleanup
            if self.use_long_term:
                # Compress memory if needed
                if self.work_mem.non_perm_size(bucket_id) >= self.max_work_tokens:
                    # Remove obsolete features if needed
                    if self.long_mem.non_perm_size(bucket_id) >= (self.max_long_tokens -
                                                                  self.num_prototypes):
                        self.long_mem.remove_obsolete_features(
                            bucket_id,
                            self.max_long_tokens - self.num_prototypes - self.buffer_tokens)

                    self.compress_features(bucket_id)
            else:
                # FIFO
                self.work_mem.remove_old_memory(bucket_id, self.max_work_tokens)

    def purge_except(self, obj_keep_idx: List[int]) -> None:
        # purge all objects from the memory except the ones listed
        self.work_mem.purge_except(obj_keep_idx)
        if self.use_long_term and self.long_mem.engaged():
            self.long_mem.purge_except(obj_keep_idx)
        self.sensory = {k: v for k, v in self.sensory.items() if k in obj_keep_idx}

        if not self.work_mem.engaged():
            # everything is removed!
            self.engaged = False

    def compress_features(self, bucket_id: int) -> None:

        # perform memory consolidation
        prototype_key, prototype_value, prototype_shrinkage = self.consolidation(
            *self.work_mem.get_all_sliced(bucket_id, 0, -self.min_work_tokens))

        # remove consolidated working memory
        self.work_mem.sieve_by_range(bucket_id,
                                     0,
                                     -self.min_work_tokens,
                                     min_size=self.min_work_tokens)

        # add to long-term memory
        self.long_mem.add(prototype_key,
                          prototype_value,
                          prototype_shrinkage,
                          selection=None,
                          supposed_bucket_id=bucket_id)

    def consolidation(self, candidate_key: torch.Tensor, candidate_shrinkage: torch.Tensor,
                      candidate_selection: torch.Tensor, candidate_value: Dict[int, torch.Tensor],
                      usage: torch.Tensor) -> (torch.Tensor, Dict[int, torch.Tensor], torch.Tensor):
        # find the indices with max usage
        bs = candidate_key.shape[0]
        assert bs in [1, 2]

        prototype_key = []
        prototype_selection = []
        for bi in range(bs):
            _, max_usage_indices = torch.topk(usage[bi], k=self.num_prototypes, dim=-1, sorted=True)
            prototype_indices = max_usage_indices.flatten()
            prototype_key.append(candidate_key[bi, :, prototype_indices])
            prototype_selection.append(candidate_selection[bi, :, prototype_indices])
        prototype_key = torch.stack(prototype_key, dim=0)
        prototype_selection = torch.stack(prototype_selection, dim=0)
        """
        Potentiation step
        """
        similarity = get_similarity(candidate_key, candidate_shrinkage, prototype_key,
                                    prototype_selection)
        affinity = do_softmax(similarity)

        # readout the values
        prototype_value = {k: self._readout(affinity, v) for k, v in candidate_value.items()}

        # readout the shrinkage term
        prototype_shrinkage = self._readout(affinity, candidate_shrinkage)

        return prototype_key, prototype_value, prototype_shrinkage

    def initialize_sensory_if_needed(self, sample_key: torch.Tensor, ids: List[int]):
        for obj in ids:
            if obj not in self.sensory:
                # also initializes the sensory memory
                bs, _, h, w = sample_key.shape
                self.sensory[obj] = torch.zeros((bs, self.sensory_dim, h, w),
                                                device=sample_key.device)

    def update_sensory(self, sensory: torch.Tensor, ids: List[int]):
        # sensory: 1*num_objects*C*H*W
        for obj_id, obj in enumerate(ids):
            self.sensory[obj] = sensory[:, obj_id]

    def get_sensory(self, ids: List[int]):
        # returns (1/2)*num_objects*C*H*W
        return self._get_sensory_by_ids(ids)

    def clear_non_permanent_memory(self):
        self.work_mem.clear_non_permanent_memory()
        if self.use_long_term:
            self.long_mem.clear_non_permanent_memory()

    def clear_sensory_memory(self):
        self.sensory = {}

    def clear_work_mem(self):
        self.work_mem = KeyValueMemoryStore(save_selection=self.use_long_term,
                                            save_usage=self.use_long_term)

    def clear_obj_mem(self):
        self.obj_v = {}
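The "streaming average" bookkeeping in add_memory is worth a toy illustration. A standalone sketch (toy dimensions, not taken from the config; not part of the uploaded files) of keeping summed embeddings plus a count in the last channel, so a consumer can recover the mean later:

    import torch

    embed_dim, num_summaries = 4, 2
    store = torch.zeros(1, num_summaries, embed_dim + 1)  # [summed embeddings..., count]

    for _ in range(3):  # three incoming frames
        incoming = torch.ones(1, num_summaries, embed_dim + 1)  # unit sums, count of 1
        store[:, :, :-1] += incoming[:, :, :-1]  # accumulate the embedding sums
        store[:, :, -1] += incoming[:, :, -1]    # accumulate the counts

    mean = store[:, :, :-1] / store[:, :, -1:]   # the averaging done downstream
    print(mean)  # all ones: summed value 3 divided by count 3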
hf_space/third_party/matanyone/inference/object_info.py
ADDED
@@ -0,0 +1,24 @@
class ObjectInfo:
    """
    Store meta information for an object
    """
    def __init__(self, id: int):
        self.id = id
        self.poke_count = 0  # count number of detections missed

    def poke(self) -> None:
        self.poke_count += 1

    def unpoke(self) -> None:
        self.poke_count = 0

    def __hash__(self):
        return hash(self.id)

    def __eq__(self, other):
        if type(other) == int:
            return self.id == other
        return self.id == other.id

    def __repr__(self):
        return f'(ID: {self.id})'
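Because __hash__ delegates to the integer id and __eq__ accepts plain ints, dictionaries keyed by ObjectInfo can be queried with raw ids, which is what ObjectManager relies on. A short usage sketch (assuming the module above is importable; not part of the uploaded files):

    from matanyone.inference.object_info import ObjectInfo

    info = ObjectInfo(id=7)
    table = {info: 'tmp_id 1'}
    print(info == 7)   # True: __eq__ accepts ints
    print(7 in table)  # True: hash(7) == hash(info), then __eq__ confirms the match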
hf_space/third_party/matanyone/inference/object_manager.py
ADDED
@@ -0,0 +1,149 @@
from typing import Union, List, Dict

import torch
from matanyone.inference.object_info import ObjectInfo


class ObjectManager:
    """
    Object IDs are immutable. The same ID always represents the same object.
    Temporary IDs are the positions of each object in the tensor; they change as objects get removed.
    Temporary IDs start from 1.
    """

    def __init__(self):
        self.obj_to_tmp_id: Dict[ObjectInfo, int] = {}
        self.tmp_id_to_obj: Dict[int, ObjectInfo] = {}
        self.obj_id_to_obj: Dict[int, ObjectInfo] = {}

        self.all_historical_object_ids: List[int] = []

    def _recompute_obj_id_to_obj_mapping(self) -> None:
        self.obj_id_to_obj = {obj.id: obj for obj in self.obj_to_tmp_id}

    def add_new_objects(
            self, objects: Union[List[ObjectInfo], ObjectInfo,
                                 List[int]]) -> (List[int], List[int]):
        if not isinstance(objects, list):
            objects = [objects]

        corresponding_tmp_ids = []
        corresponding_obj_ids = []
        for obj in objects:
            if isinstance(obj, int):
                obj = ObjectInfo(id=obj)

            if obj in self.obj_to_tmp_id:
                # old object
                corresponding_tmp_ids.append(self.obj_to_tmp_id[obj])
                corresponding_obj_ids.append(obj.id)
            else:
                # new object
                new_obj = ObjectInfo(id=obj.id)

                new_tmp_id = len(self.obj_to_tmp_id) + 1
                self.obj_to_tmp_id[new_obj] = new_tmp_id
                self.tmp_id_to_obj[new_tmp_id] = new_obj
                self.all_historical_object_ids.append(new_obj.id)
                corresponding_tmp_ids.append(new_tmp_id)
                corresponding_obj_ids.append(new_obj.id)

        self._recompute_obj_id_to_obj_mapping()
        assert corresponding_tmp_ids == sorted(corresponding_tmp_ids)
        return corresponding_tmp_ids, corresponding_obj_ids

    def delete_objects(self, obj_ids_to_remove: Union[int, List[int]]) -> None:
        # delete an object or a list of objects
        # re-sort the tmp ids
        if isinstance(obj_ids_to_remove, int):
            obj_ids_to_remove = [obj_ids_to_remove]

        new_tmp_id = 1
        total_num_id = len(self.obj_to_tmp_id)

        local_obj_to_tmp_id = {}
        local_tmp_to_obj_id = {}

        for tmp_iter in range(1, total_num_id + 1):
            obj = self.tmp_id_to_obj[tmp_iter]
            if obj.id not in obj_ids_to_remove:
                local_obj_to_tmp_id[obj] = new_tmp_id
                local_tmp_to_obj_id[new_tmp_id] = obj
                new_tmp_id += 1

        self.obj_to_tmp_id = local_obj_to_tmp_id
        self.tmp_id_to_obj = local_tmp_to_obj_id
        self._recompute_obj_id_to_obj_mapping()

    def purge_inactive_objects(self,
                               max_missed_detection_count: int) -> (bool, List[int], List[int]):
        # remove tmp ids of objects that are removed
        obj_id_to_be_deleted = []
        tmp_id_to_be_deleted = []
        tmp_id_to_keep = []
        obj_id_to_keep = []

        for obj in self.obj_to_tmp_id:
            if obj.poke_count > max_missed_detection_count:
                obj_id_to_be_deleted.append(obj.id)
                tmp_id_to_be_deleted.append(self.obj_to_tmp_id[obj])
            else:
                tmp_id_to_keep.append(self.obj_to_tmp_id[obj])
                obj_id_to_keep.append(obj.id)

        purge_activated = len(obj_id_to_be_deleted) > 0
        if purge_activated:
            self.delete_objects(obj_id_to_be_deleted)
        return purge_activated, tmp_id_to_keep, obj_id_to_keep

    def tmp_to_obj_cls(self, mask) -> torch.Tensor:
        # remap tmp id cls representation to the true object id representation
        new_mask = torch.zeros_like(mask)
        for tmp_id, obj in self.tmp_id_to_obj.items():
            new_mask[mask == tmp_id] = obj.id
        return new_mask

    def get_tmp_to_obj_mapping(self) -> Dict[int, int]:
        # returns the mapping in a dict format for saving it with pickle
        return {obj.id: tmp_id for tmp_id, obj in self.tmp_id_to_obj.items()}

    def realize_dict(self, obj_dict, dim=1) -> torch.Tensor:
        # turns a dict indexed by obj id into a tensor, ordered by tmp IDs
        output = []
        for _, obj in self.tmp_id_to_obj.items():
            if obj.id not in obj_dict:
                raise NotImplementedError
            output.append(obj_dict[obj.id])
        output = torch.stack(output, dim=dim)
        return output

    def make_one_hot(self, cls_mask) -> torch.Tensor:
        output = []
        for _, obj in self.tmp_id_to_obj.items():
            output.append(cls_mask == obj.id)
        if len(output) == 0:
            output = torch.zeros((0, *cls_mask.shape), dtype=torch.bool, device=cls_mask.device)
        else:
            output = torch.stack(output, dim=0)
        return output

    @property
    def all_obj_ids(self) -> List[int]:
        return [k.id for k in self.obj_to_tmp_id]

    @property
    def num_obj(self) -> int:
        return len(self.obj_to_tmp_id)

    def has_all(self, objects: List[int]) -> bool:
        for obj in objects:
            if obj not in self.obj_to_tmp_id:
                return False
        return True

    def find_object_by_id(self, obj_id) -> ObjectInfo:
        return self.obj_id_to_obj[obj_id]

    def find_tmp_by_id(self, obj_id) -> int:
        return self.obj_to_tmp_id[self.obj_id_to_obj[obj_id]]
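A usage sketch of the temporary-id compaction described in the docstring (assuming the package above is importable; not part of the uploaded files): temporary ids stay densely packed from 1, so deleting an object shifts the later ones down.

    from matanyone.inference.object_manager import ObjectManager

    manager = ObjectManager()
    tmp_ids, obj_ids = manager.add_new_objects([10, 20, 30])
    print(tmp_ids)                     # [1, 2, 3]
    manager.delete_objects(20)         # removes object 20 and re-sorts tmp ids
    print(manager.all_obj_ids)         # [10, 30]
    print(manager.find_tmp_by_id(30))  # 2, shifted down from 3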
hf_space/third_party/matanyone/inference/utils/__init__.py
ADDED
File without changes
hf_space/third_party/matanyone/inference/utils/args_utils.py
ADDED
@@ -0,0 +1,30 @@
import logging
from omegaconf import DictConfig

log = logging.getLogger()


def get_dataset_cfg(cfg: DictConfig):
    dataset_name = cfg.dataset
    data_cfg = cfg.datasets[dataset_name]

    potential_overrides = [
        'image_directory',
        'mask_directory',
        'json_directory',
        'size',
        'save_all',
        'use_all_masks',
        'use_long_term',
        'mem_every',
    ]

    for override in potential_overrides:
        if cfg[override] is not None:
            log.info(f'Overriding config {override} from {data_cfg[override]} to {cfg[override]}')
            data_cfg[override] = cfg[override]
        # escalate all potential overrides to the top-level config
        if override in data_cfg:
            cfg[override] = data_cfg[override]

    return data_cfg
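The override flow in get_dataset_cfg can be shown with a toy config (hypothetical keys and values; not part of the uploaded files): a non-None top-level value wins over the per-dataset one, and the resolved value is escalated back to the top level.

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        'dataset': 'd17',
        'size': 480,        # top-level override, takes precedence
        'mem_every': None,  # None means: keep the dataset's own value
        'datasets': {'d17': {'size': 384, 'mem_every': 5}},
    })

    data_cfg = cfg.datasets[cfg.dataset]
    for key in ('size', 'mem_every'):
        if cfg[key] is not None:
            data_cfg[key] = cfg[key]
        if key in data_cfg:
            cfg[key] = data_cfg[key]

    print(data_cfg.size, cfg.mem_every)  # 480 5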
hf_space/third_party/matanyone/model/__init__.py
ADDED
File without changes
hf_space/third_party/matanyone/model/__pycache__/__init__.cpython-313.pyc
ADDED
Binary file (195 Bytes).
hf_space/third_party/matanyone/model/__pycache__/aux_modules.cpython-313.pyc
ADDED
Binary file (5.64 kB).
hf_space/third_party/matanyone/model/__pycache__/big_modules.cpython-313.pyc
ADDED
Binary file (19.8 kB).
hf_space/third_party/matanyone/model/__pycache__/channel_attn.cpython-313.pyc
ADDED
Binary file (2.85 kB).
hf_space/third_party/matanyone/model/__pycache__/group_modules.cpython-313.pyc
ADDED
Binary file (7.43 kB).
hf_space/third_party/matanyone/model/__pycache__/matanyone.cpython-313.pyc
ADDED
Binary file (17.5 kB).
hf_space/third_party/matanyone/model/__pycache__/modules.cpython-313.pyc
ADDED
Binary file (11.8 kB).
hf_space/third_party/matanyone/model/aux_modules.py
ADDED
@@ -0,0 +1,93 @@
"""
For computing auxiliary outputs for auxiliary losses
"""
from typing import Dict
from omegaconf import DictConfig
import torch
import torch.nn as nn

from matanyone.model.group_modules import GConv2d
from matanyone.utils.tensor_utils import aggregate


class LinearPredictor(nn.Module):
    def __init__(self, x_dim: int, pix_dim: int):
        super().__init__()
        self.projection = GConv2d(x_dim, pix_dim + 1, kernel_size=1)

    def forward(self, pix_feat: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # pixel_feat: B*pix_dim*H*W
        # x: B*num_objects*x_dim*H*W
        num_objects = x.shape[1]
        x = self.projection(x)

        pix_feat = pix_feat.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)
        logits = (pix_feat * x[:, :, :-1]).sum(dim=2) + x[:, :, -1]
        return logits


class DirectPredictor(nn.Module):
    def __init__(self, x_dim: int):
        super().__init__()
        self.projection = GConv2d(x_dim, 1, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: B*num_objects*x_dim*H*W
        logits = self.projection(x).squeeze(2)
        return logits


class AuxComputer(nn.Module):
    def __init__(self, cfg: DictConfig):
        super().__init__()

        use_sensory_aux = cfg.model.aux_loss.sensory.enabled
        self.use_query_aux = cfg.model.aux_loss.query.enabled
        self.use_sensory_aux = use_sensory_aux

        sensory_dim = cfg.model.sensory_dim
        embed_dim = cfg.model.embed_dim

        if use_sensory_aux:
            self.sensory_aux = LinearPredictor(sensory_dim, embed_dim)

    def _aggregate_with_selector(self, logits: torch.Tensor,
                                 selector: torch.Tensor) -> torch.Tensor:
        prob = torch.sigmoid(logits)
        if selector is not None:
            prob = prob * selector
        logits = aggregate(prob, dim=1)
        return logits

    def forward(self, pix_feat: torch.Tensor, aux_input: Dict[str, torch.Tensor],
                selector: torch.Tensor, seg_pass=False) -> Dict[str, torch.Tensor]:
        sensory = aux_input['sensory']
        q_logits = aux_input['q_logits']

        aux_output = {}
        aux_output['attn_mask'] = aux_input['attn_mask']

        if self.use_sensory_aux:
            # B*num_objects*H*W
            logits = self.sensory_aux(pix_feat, sensory)
            aux_output['sensory_logits'] = self._aggregate_with_selector(logits, selector)
        if self.use_query_aux:
            # B*num_objects*num_levels*H*W
            aux_output['q_logits'] = self._aggregate_with_selector(
                torch.stack(q_logits, dim=2),
                selector.unsqueeze(2) if selector is not None else None)

        return aux_output

    def compute_mask(self, aux_input: Dict[str, torch.Tensor],
                     selector: torch.Tensor) -> Dict[str, torch.Tensor]:
        # sensory = aux_input['sensory']
        q_logits = aux_input['q_logits']

        aux_output = {}

        # B*num_objects*num_levels*H*W
        aux_output['q_logits'] = self._aggregate_with_selector(
            torch.stack(q_logits, dim=2),
            selector.unsqueeze(2) if selector is not None else None)

        return aux_output
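A shape sketch for LinearPredictor (assuming the package is importable and that GConv2d applies its convolution per object, as elsewhere in this codebase; toy sizes, not part of the uploaded files): each object's features become per-pixel weights plus a bias channel, dotted with the shared pixel features to give one logit map per object.

    import torch
    from matanyone.model.aux_modules import LinearPredictor

    pred = LinearPredictor(x_dim=16, pix_dim=32)
    pix_feat = torch.randn(2, 32, 24, 24)  # B*pix_dim*H*W, shared across objects
    x = torch.randn(2, 3, 16, 24, 24)      # B*num_objects*x_dim*H*W
    print(pred(pix_feat, x).shape)         # torch.Size([2, 3, 24, 24])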
hf_space/third_party/matanyone/model/big_modules.py
ADDED
@@ -0,0 +1,365 @@
| 1 |
+
"""
|
| 2 |
+
big_modules.py - This file stores higher-level network blocks.
|
| 3 |
+
|
| 4 |
+
x - usually denotes features that are shared between objects.
|
| 5 |
+
g - usually denotes features that are not shared between objects
|
| 6 |
+
with an extra "num_objects" dimension (batch_size * num_objects * num_channels * H * W).
|
| 7 |
+
|
| 8 |
+
The trailing number of a variable usually denotes the stride
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from typing import Iterable
|
| 12 |
+
from omegaconf import DictConfig
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
|
| 17 |
+
from matanyone.model.group_modules import MainToGroupDistributor, GroupFeatureFusionBlock, GConv2d
|
| 18 |
+
from matanyone.model.utils import resnet
|
| 19 |
+
from matanyone.model.modules import SensoryDeepUpdater, SensoryUpdater_fullscale, DecoderFeatureProcessor, MaskUpsampleBlock
|
| 20 |
+
|
| 21 |
+
class UncertPred(nn.Module):
|
| 22 |
+
def __init__(self, model_cfg: DictConfig):
|
| 23 |
+
super().__init__()
|
| 24 |
+
self.conv1x1_v2 = nn.Conv2d(model_cfg.pixel_dim*2 + 1 + model_cfg.value_dim, 64, kernel_size=1, stride=1, bias=False)
|
| 25 |
+
self.bn1 = nn.BatchNorm2d(64)
|
| 26 |
+
self.relu = nn.ReLU(inplace=True)
|
| 27 |
+
self.conv3x3 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1, groups=1, bias=False, dilation=1)
|
| 28 |
+
self.bn2 = nn.BatchNorm2d(32)
|
| 29 |
+
self.conv3x3_out = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1, groups=1, bias=False, dilation=1)
|
| 30 |
+
|
| 31 |
+
def forward(self, last_frame_feat: torch.Tensor, cur_frame_feat: torch.Tensor, last_mask: torch.Tensor, mem_val_diff:torch.Tensor):
|
| 32 |
+
last_mask = F.interpolate(last_mask, size=last_frame_feat.shape[-2:], mode='area')
|
| 33 |
+
x = torch.cat([last_frame_feat, cur_frame_feat, last_mask, mem_val_diff], dim=1)
|
| 34 |
+
x = self.conv1x1_v2(x)
|
| 35 |
+
x = self.bn1(x)
|
| 36 |
+
x = self.relu(x)
|
| 37 |
+
x = self.conv3x3(x)
|
| 38 |
+
x = self.bn2(x)
|
| 39 |
+
x = self.relu(x)
|
| 40 |
+
x = self.conv3x3_out(x)
|
| 41 |
+
return x
|
| 42 |
+
|
| 43 |
+
# override the default train() to freeze BN statistics
|
| 44 |
+
def train(self, mode=True):
|
| 45 |
+
self.training = False
|
| 46 |
+
for module in self.children():
|
| 47 |
+
module.train(False)
|
| 48 |
+
return self
|
| 49 |
+
|
| 50 |
+
class PixelEncoder(nn.Module):
|
| 51 |
+
def __init__(self, model_cfg: DictConfig):
|
| 52 |
+
super().__init__()
|
| 53 |
+
|
| 54 |
+
self.is_resnet = 'resnet' in model_cfg.pixel_encoder.type
|
| 55 |
+
# if model_cfg.pretrained_resnet is set in the model_cfg we get the value
|
| 56 |
+
# else default to True
|
| 57 |
+
is_pretrained_resnet = getattr(model_cfg,"pretrained_resnet",True)
|
| 58 |
+
if self.is_resnet:
|
| 59 |
+
if model_cfg.pixel_encoder.type == 'resnet18':
|
| 60 |
+
network = resnet.resnet18(pretrained=is_pretrained_resnet)
|
| 61 |
+
elif model_cfg.pixel_encoder.type == 'resnet50':
|
| 62 |
+
network = resnet.resnet50(pretrained=is_pretrained_resnet)
|
| 63 |
+
else:
|
| 64 |
+
raise NotImplementedError
|
| 65 |
+
self.conv1 = network.conv1
|
| 66 |
+
self.bn1 = network.bn1
|
| 67 |
+
self.relu = network.relu
|
| 68 |
+
self.maxpool = network.maxpool
|
| 69 |
+
|
| 70 |
+
self.res2 = network.layer1
|
| 71 |
+
self.layer2 = network.layer2
|
| 72 |
+
self.layer3 = network.layer3
|
| 73 |
+
else:
|
| 74 |
+
raise NotImplementedError
|
| 75 |
+
|
| 76 |
+
def forward(self, x: torch.Tensor, seq_length=None) -> (torch.Tensor, torch.Tensor, torch.Tensor):
|
| 77 |
+
f1 = x
|
| 78 |
+
x = self.conv1(x)
|
| 79 |
+
x = self.bn1(x)
|
| 80 |
+
x = self.relu(x)
|
| 81 |
+
f2 = x
|
| 82 |
+
x = self.maxpool(x)
|
| 83 |
+
f4 = self.res2(x)
|
| 84 |
+
f8 = self.layer2(f4)
|
| 85 |
+
f16 = self.layer3(f8)
|
| 86 |
+
|
| 87 |
+
return f16, f8, f4, f2, f1
|
| 88 |
+
|
| 89 |
+
# override the default train() to freeze BN statistics
|
| 90 |
+
def train(self, mode=True):
|
| 91 |
+
self.training = False
|
| 92 |
+
for module in self.children():
|
| 93 |
+
module.train(False)
|
| 94 |
+
return self
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class KeyProjection(nn.Module):
|
| 98 |
+
def __init__(self, model_cfg: DictConfig):
|
| 99 |
+
super().__init__()
|
| 100 |
+
in_dim = model_cfg.pixel_encoder.ms_dims[0]
|
| 101 |
+
mid_dim = model_cfg.pixel_dim
|
| 102 |
+
key_dim = model_cfg.key_dim
|
| 103 |
+
|
| 104 |
+
self.pix_feat_proj = nn.Conv2d(in_dim, mid_dim, kernel_size=1)
|
| 105 |
+
self.key_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1)
|
| 106 |
+
# shrinkage
|
| 107 |
+
self.d_proj = nn.Conv2d(mid_dim, 1, kernel_size=3, padding=1)
|
| 108 |
+
# selection
|
| 109 |
+
self.e_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1)
|
| 110 |
+
|
| 111 |
+
nn.init.orthogonal_(self.key_proj.weight.data)
|
| 112 |
+
nn.init.zeros_(self.key_proj.bias.data)
|
| 113 |
+
|
| 114 |
+
def forward(self, x: torch.Tensor, *, need_s: bool,
|
| 115 |
+
need_e: bool) -> (torch.Tensor, torch.Tensor, torch.Tensor):
|
| 116 |
+
x = self.pix_feat_proj(x)
|
| 117 |
+
shrinkage = self.d_proj(x)**2 + 1 if (need_s) else None
|
| 118 |
+
selection = torch.sigmoid(self.e_proj(x)) if (need_e) else None
|
| 119 |
+
|
| 120 |
+
return self.key_proj(x), shrinkage, selection
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class MaskEncoder(nn.Module):
|
| 124 |
+
def __init__(self, model_cfg: DictConfig, single_object=False):
|
| 125 |
+
super().__init__()
|
| 126 |
+
pixel_dim = model_cfg.pixel_dim
|
| 127 |
+
value_dim = model_cfg.value_dim
|
| 128 |
+
sensory_dim = model_cfg.sensory_dim
|
| 129 |
+
final_dim = model_cfg.mask_encoder.final_dim
|
| 130 |
+
|
| 131 |
+
self.single_object = single_object
|
| 132 |
+
extra_dim = 1 if single_object else 2
|
| 133 |
+
|
| 134 |
+
# if model_cfg.pretrained_resnet is set in the model_cfg we get the value
|
| 135 |
+
# else default to True
|
| 136 |
+
is_pretrained_resnet = getattr(model_cfg,"pretrained_resnet",True)
|
| 137 |
+
if model_cfg.mask_encoder.type == 'resnet18':
|
| 138 |
+
network = resnet.resnet18(pretrained=is_pretrained_resnet, extra_dim=extra_dim)
|
| 139 |
+
elif model_cfg.mask_encoder.type == 'resnet50':
|
| 140 |
+
network = resnet.resnet50(pretrained=is_pretrained_resnet, extra_dim=extra_dim)
|
| 141 |
+
else:
|
| 142 |
+
raise NotImplementedError
|
| 143 |
+
self.conv1 = network.conv1
|
| 144 |
+
self.bn1 = network.bn1
|
| 145 |
+
self.relu = network.relu
|
| 146 |
+
self.maxpool = network.maxpool
|
| 147 |
+
|
| 148 |
+
self.layer1 = network.layer1
|
| 149 |
+
self.layer2 = network.layer2
|
| 150 |
+
self.layer3 = network.layer3
|
| 151 |
+
|
| 152 |
+
self.distributor = MainToGroupDistributor()
|
| 153 |
+
self.fuser = GroupFeatureFusionBlock(pixel_dim, final_dim, value_dim)
|
| 154 |
+
|
| 155 |
+
self.sensory_update = SensoryDeepUpdater(value_dim, sensory_dim)
|
| 156 |
+
|
| 157 |
+
def forward(self,
|
| 158 |
+
image: torch.Tensor,
|
| 159 |
+
pix_feat: torch.Tensor,
|
| 160 |
+
sensory: torch.Tensor,
|
| 161 |
+
masks: torch.Tensor,
|
| 162 |
+
others: torch.Tensor,
|
| 163 |
+
*,
|
| 164 |
+
deep_update: bool = True,
|
| 165 |
+
chunk_size: int = -1) -> (torch.Tensor, torch.Tensor):
|
| 166 |
+
# ms_features are from the key encoder
|
| 167 |
+
# we only use the first one (lowest resolution), following XMem
|
| 168 |
+
if self.single_object:
|
| 169 |
+
g = masks.unsqueeze(2)
|
| 170 |
+
else:
|
| 171 |
+
g = torch.stack([masks, others], dim=2)
|
| 172 |
+
|
| 173 |
+
g = self.distributor(image, g)
|
| 174 |
+
|
| 175 |
+
batch_size, num_objects = g.shape[:2]
|
| 176 |
+
if chunk_size < 1 or chunk_size >= num_objects:
|
| 177 |
+
chunk_size = num_objects
|
| 178 |
+
fast_path = True
|
| 179 |
+
new_sensory = sensory
|
| 180 |
+
else:
|
| 181 |
+
if deep_update:
|
| 182 |
+
new_sensory = torch.empty_like(sensory)
|
| 183 |
+
else:
|
| 184 |
+
new_sensory = sensory
|
| 185 |
+
fast_path = False
|
| 186 |
+
|
| 187 |
+
# chunk-by-chunk inference
|
| 188 |
+
all_g = []
|
| 189 |
+
for i in range(0, num_objects, chunk_size):
|
| 190 |
+
if fast_path:
|
| 191 |
+
g_chunk = g
|
| 192 |
+
else:
|
| 193 |
+
g_chunk = g[:, i:i + chunk_size]
|
| 194 |
+
actual_chunk_size = g_chunk.shape[1]
|
| 195 |
+
g_chunk = g_chunk.flatten(start_dim=0, end_dim=1)
|
| 196 |
+
|
| 197 |
+
g_chunk = self.conv1(g_chunk)
|
| 198 |
+
g_chunk = self.bn1(g_chunk) # 1/2, 64
|
| 199 |
+
g_chunk = self.maxpool(g_chunk) # 1/4, 64
|
| 200 |
+
g_chunk = self.relu(g_chunk)
|
| 201 |
+
|
| 202 |
+
g_chunk = self.layer1(g_chunk) # 1/4
|
| 203 |
+
g_chunk = self.layer2(g_chunk) # 1/8
|
| 204 |
+
g_chunk = self.layer3(g_chunk) # 1/16
|
| 205 |
+
|
| 206 |
+
g_chunk = g_chunk.view(batch_size, actual_chunk_size, *g_chunk.shape[1:])
|
| 207 |
+
g_chunk = self.fuser(pix_feat, g_chunk)
|
| 208 |
+
all_g.append(g_chunk)
|
| 209 |
+
if deep_update:
|
| 210 |
+
if fast_path:
|
| 211 |
+
new_sensory = self.sensory_update(g_chunk, sensory)
|
| 212 |
+
else:
|
| 213 |
+
new_sensory[:, i:i + chunk_size] = self.sensory_update(
|
| 214 |
+
g_chunk, sensory[:, i:i + chunk_size])
|
| 215 |
+
g = torch.cat(all_g, dim=1)
|
| 216 |
+
|
| 217 |
+
return g, new_sensory
|
| 218 |
+
|
| 219 |
+
# override the default train() to freeze BN statistics
|
| 220 |
+
def train(self, mode=True):
|
| 221 |
+
self.training = False
|
| 222 |
+
for module in self.children():
|
| 223 |
+
module.train(False)
|
| 224 |
+
return self
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class PixelFeatureFuser(nn.Module):
|
| 228 |
+
def __init__(self, model_cfg: DictConfig, single_object=False):
|
| 229 |
+
super().__init__()
|
| 230 |
+
value_dim = model_cfg.value_dim
|
| 231 |
+
sensory_dim = model_cfg.sensory_dim
|
| 232 |
+
pixel_dim = model_cfg.pixel_dim
|
| 233 |
+
embed_dim = model_cfg.embed_dim
|
| 234 |
+
self.single_object = single_object
|
| 235 |
+
|
| 236 |
+
self.fuser = GroupFeatureFusionBlock(pixel_dim, value_dim, embed_dim)
|
| 237 |
+
if self.single_object:
|
| 238 |
+
self.sensory_compress = GConv2d(sensory_dim + 1, value_dim, kernel_size=1)
|
| 239 |
+
else:
|
| 240 |
+
self.sensory_compress = GConv2d(sensory_dim + 2, value_dim, kernel_size=1)
|
| 241 |
+
|
| 242 |
+
def forward(self,
|
| 243 |
+
pix_feat: torch.Tensor,
|
| 244 |
+
                pixel_memory: torch.Tensor,
                sensory_memory: torch.Tensor,
                last_mask: torch.Tensor,
                last_others: torch.Tensor,
                *,
                chunk_size: int = -1) -> torch.Tensor:
        batch_size, num_objects = pixel_memory.shape[:2]

        if self.single_object:
            last_mask = last_mask.unsqueeze(2)
        else:
            last_mask = torch.stack([last_mask, last_others], dim=2)

        if chunk_size < 1:
            chunk_size = num_objects

        # chunk-by-chunk inference
        all_p16 = []
        for i in range(0, num_objects, chunk_size):
            sensory_readout = self.sensory_compress(
                torch.cat([sensory_memory[:, i:i + chunk_size], last_mask[:, i:i + chunk_size]], 2))
            p16 = pixel_memory[:, i:i + chunk_size] + sensory_readout
            p16 = self.fuser(pix_feat, p16)
            all_p16.append(p16)
        p16 = torch.cat(all_p16, dim=1)

        return p16


class MaskDecoder(nn.Module):
    def __init__(self, model_cfg: DictConfig):
        super().__init__()
        embed_dim = model_cfg.embed_dim
        sensory_dim = model_cfg.sensory_dim
        ms_image_dims = model_cfg.pixel_encoder.ms_dims
        up_dims = model_cfg.mask_decoder.up_dims

        assert embed_dim == up_dims[0]

        self.sensory_update = SensoryUpdater_fullscale(
            [up_dims[0], up_dims[1], up_dims[2], up_dims[3], up_dims[4] + 1], sensory_dim,
            sensory_dim)

        self.decoder_feat_proc = DecoderFeatureProcessor(ms_image_dims[1:], up_dims[:-1])
        self.up_16_8 = MaskUpsampleBlock(up_dims[0], up_dims[1])
        self.up_8_4 = MaskUpsampleBlock(up_dims[1], up_dims[2])
        # newly added for the alpha matte
        self.up_4_2 = MaskUpsampleBlock(up_dims[2], up_dims[3])
        self.up_2_1 = MaskUpsampleBlock(up_dims[3], up_dims[4])

        self.pred_seg = nn.Conv2d(up_dims[-1], 1, kernel_size=3, padding=1)
        self.pred_mat = nn.Conv2d(up_dims[-1], 1, kernel_size=3, padding=1)

    def forward(self,
                ms_image_feat: Iterable[torch.Tensor],
                memory_readout: torch.Tensor,
                sensory: torch.Tensor,
                *,
                chunk_size: int = -1,
                update_sensory: bool = True,
                seg_pass: bool = False,
                last_mask=None,
                sigmoid_residual=False) -> (torch.Tensor, torch.Tensor):

        batch_size, num_objects = memory_readout.shape[:2]
        f8, f4, f2, f1 = self.decoder_feat_proc(ms_image_feat[1:])
        if chunk_size < 1 or chunk_size >= num_objects:
            chunk_size = num_objects
            fast_path = True
            new_sensory = sensory
        else:
            if update_sensory:
                new_sensory = torch.empty_like(sensory)
            else:
                new_sensory = sensory
            fast_path = False

        # chunk-by-chunk inference
        all_logits = []
        for i in range(0, num_objects, chunk_size):
            if fast_path:
                p16 = memory_readout
            else:
                p16 = memory_readout[:, i:i + chunk_size]
            actual_chunk_size = p16.shape[1]

            p8 = self.up_16_8(p16, f8)
            p4 = self.up_8_4(p8, f4)
            p2 = self.up_4_2(p4, f2)
            p1 = self.up_2_1(p2, f1)
            with torch.amp.autocast("cuda", enabled=False):
                if seg_pass:
                    if last_mask is not None:
                        res = self.pred_seg(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
                        if sigmoid_residual:
                            res = (torch.sigmoid(res) - 0.5) * 2  # regularization: (-1, 1) change on the last mask
                        logits = last_mask + res
                    else:
                        logits = self.pred_seg(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
                else:
                    if last_mask is not None:
                        res = self.pred_mat(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
                        if sigmoid_residual:
                            res = (torch.sigmoid(res) - 0.5) * 2  # regularization: (-1, 1) change on the last mask
                        logits = last_mask + res
                    else:
                        logits = self.pred_mat(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
            # SensoryUpdater_fullscale
            if update_sensory:
                p1 = torch.cat(
                    [p1, logits.view(batch_size, actual_chunk_size, 1, *logits.shape[-2:])], 2)
                if fast_path:
                    new_sensory = self.sensory_update([p16, p8, p4, p2, p1], sensory)
                else:
                    new_sensory[:, i:i + chunk_size] = self.sensory_update(
                        [p16, p8, p4, p2, p1], sensory[:, i:i + chunk_size])
            all_logits.append(logits)
        logits = torch.cat(all_logits, dim=0)
        logits = logits.view(batch_size, num_objects, *logits.shape[-2:])

        return new_sensory, logits
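With sigmoid_residual, the decoder above refines the previous frame's prediction instead of predicting from scratch: the raw head output is squashed into (-1, 1) before being added to last_mask, so each pixel can only move by a bounded amount per frame. A minimal standalone sketch of that update, using hypothetical tensors:

import torch

# hypothetical raw head output and previous-frame prediction, both B*1*H*W
res = torch.randn(2, 1, 64, 64)
last_mask = torch.rand(2, 1, 64, 64)

bounded = (torch.sigmoid(res) - 0.5) * 2   # squashed to (-1, 1)
refined = last_mask + bounded              # same update as in MaskDecoder.forward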
hf_space/third_party/matanyone/model/channel_attn.py
ADDED
@@ -0,0 +1,39 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class CAResBlock(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, residual: bool = True):
        super().__init__()
        self.residual = residual
        self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1)

        t = int((abs(math.log2(out_dim)) + 1) // 2)
        k = t if t % 2 else t + 1
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)

        if self.residual:
            if in_dim == out_dim:
                self.downsample = nn.Identity()
            else:
                self.downsample = nn.Conv2d(in_dim, out_dim, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r = x
        x = self.conv1(F.relu(x))
        x = self.conv2(F.relu(x))

        b, c = x.shape[:2]
        w = self.pool(x).view(b, 1, c)
        w = self.conv(w).transpose(-1, -2).unsqueeze(-1).sigmoid()  # B*C*1*1

        if self.residual:
            x = x * w + self.downsample(r)
        else:
            x = x * w

        return x
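CAResBlock is a residual block with ECA-style channel attention: the feature map is average-pooled to a per-channel descriptor, a 1-D convolution mixes neighboring channels, and the sigmoid output rescales the channels. A minimal usage sketch, assuming the vendored package is importable as matanyone:

import torch
from matanyone.model.channel_attn import CAResBlock

block = CAResBlock(in_dim=256, out_dim=256)
x = torch.randn(4, 256, 32, 32)   # B*C*H*W
y = block(x)                      # channel-reweighted residual output
assert y.shape == x.shape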
hf_space/third_party/matanyone/model/group_modules.py
ADDED
@@ -0,0 +1,126 @@
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from matanyone.model.channel_attn import CAResBlock


def interpolate_groups(g: torch.Tensor, ratio: float, mode: str,
                       align_corners: bool) -> torch.Tensor:
    batch_size, num_objects = g.shape[:2]
    g = F.interpolate(g.flatten(start_dim=0, end_dim=1),
                      scale_factor=ratio,
                      mode=mode,
                      align_corners=align_corners)
    g = g.view(batch_size, num_objects, *g.shape[1:])
    return g


def upsample_groups(g: torch.Tensor,
                    ratio: float = 2,
                    mode: str = 'bilinear',
                    align_corners: bool = False) -> torch.Tensor:
    return interpolate_groups(g, ratio, mode, align_corners)


def downsample_groups(g: torch.Tensor,
                      ratio: float = 1 / 2,
                      mode: str = 'area',
                      align_corners: bool = None) -> torch.Tensor:
    return interpolate_groups(g, ratio, mode, align_corners)


class GConv2d(nn.Conv2d):
    def forward(self, g: torch.Tensor) -> torch.Tensor:
        batch_size, num_objects = g.shape[:2]
        g = super().forward(g.flatten(start_dim=0, end_dim=1))
        return g.view(batch_size, num_objects, *g.shape[1:])


class GroupResBlock(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()

        if in_dim == out_dim:
            self.downsample = nn.Identity()
        else:
            self.downsample = GConv2d(in_dim, out_dim, kernel_size=1)

        self.conv1 = GConv2d(in_dim, out_dim, kernel_size=3, padding=1)
        self.conv2 = GConv2d(out_dim, out_dim, kernel_size=3, padding=1)

    def forward(self, g: torch.Tensor) -> torch.Tensor:
        out_g = self.conv1(F.relu(g))
        out_g = self.conv2(F.relu(out_g))

        g = self.downsample(g)

        return out_g + g


class MainToGroupDistributor(nn.Module):
    def __init__(self,
                 x_transform: Optional[nn.Module] = None,
                 g_transform: Optional[nn.Module] = None,
                 method: str = 'cat',
                 reverse_order: bool = False):
        super().__init__()

        self.x_transform = x_transform
        self.g_transform = g_transform
        self.method = method
        self.reverse_order = reverse_order

    def forward(self, x: torch.Tensor, g: torch.Tensor, skip_expand: bool = False) -> torch.Tensor:
        num_objects = g.shape[1]

        if self.x_transform is not None:
            x = self.x_transform(x)

        if self.g_transform is not None:
            g = self.g_transform(g)

        if not skip_expand:
            x = x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)
        if self.method == 'cat':
            if self.reverse_order:
                g = torch.cat([g, x], 2)
            else:
                g = torch.cat([x, g], 2)
        elif self.method == 'add':
            g = x + g
        elif self.method == 'mulcat':
            g = torch.cat([x * g, g], dim=2)
        elif self.method == 'muladd':
            g = x * g + g
        else:
            raise NotImplementedError

        return g


class GroupFeatureFusionBlock(nn.Module):
    def __init__(self, x_in_dim: int, g_in_dim: int, out_dim: int):
        super().__init__()

        x_transform = nn.Conv2d(x_in_dim, out_dim, kernel_size=1)
        g_transform = GConv2d(g_in_dim, out_dim, kernel_size=1)

        self.distributor = MainToGroupDistributor(x_transform=x_transform,
                                                  g_transform=g_transform,
                                                  method='add')
        self.block1 = CAResBlock(out_dim, out_dim)
        self.block2 = CAResBlock(out_dim, out_dim)

    def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
        batch_size, num_objects = g.shape[:2]

        g = self.distributor(x, g)

        g = g.flatten(start_dim=0, end_dim=1)

        g = self.block1(g)
        g = self.block2(g)

        g = g.view(batch_size, num_objects, *g.shape[1:])

        return g
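All group modules share one convention: the object axis is folded into the batch axis, standard 2-D ops run on (B*N)*C*H*W, and the result is reshaped back to B*N*C*H*W. A minimal shape sketch, assuming the vendored package is importable as matanyone:

import torch
from matanyone.model.group_modules import GConv2d, MainToGroupDistributor

g = torch.randn(2, 3, 64, 32, 32)   # B=2, N=3 objects, 64 channels
conv = GConv2d(64, 128, kernel_size=3, padding=1)
print(conv(g).shape)                # torch.Size([2, 3, 128, 32, 32])

# broadcast one shared image feature to every object, then concatenate
x = torch.randn(2, 16, 32, 32)      # B*C*H*W, no object axis
dist = MainToGroupDistributor(method='cat')
print(dist(x, g).shape)             # torch.Size([2, 3, 80, 32, 32])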
hf_space/third_party/matanyone/model/matanyone.py
ADDED
@@ -0,0 +1,333 @@
from typing import List, Dict, Iterable, Tuple
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import DictConfig, OmegaConf
from huggingface_hub import PyTorchModelHubMixin

from matanyone.model.big_modules import PixelEncoder, UncertPred, KeyProjection, MaskEncoder, PixelFeatureFuser, MaskDecoder
from matanyone.model.aux_modules import AuxComputer
from matanyone.model.utils.memory_utils import get_affinity, readout
from matanyone.model.transformer.object_transformer import QueryTransformer
from matanyone.model.transformer.object_summarizer import ObjectSummarizer
from matanyone.utils.tensor_utils import aggregate

log = logging.getLogger()


class MatAnyone(nn.Module,
                PyTorchModelHubMixin,
                library_name="matanyone",
                repo_url="https://github.com/pq-yang/MatAnyone",
                coders={
                    DictConfig: (
                        lambda x: OmegaConf.to_container(x),
                        lambda data: OmegaConf.create(data),
                    )
                },
                ):

    def __init__(self, cfg: DictConfig, *, single_object=False):
        super().__init__()
        self.cfg = cfg
        model_cfg = cfg.model
        self.ms_dims = model_cfg.pixel_encoder.ms_dims
        self.key_dim = model_cfg.key_dim
        self.value_dim = model_cfg.value_dim
        self.sensory_dim = model_cfg.sensory_dim
        self.pixel_dim = model_cfg.pixel_dim
        self.embed_dim = model_cfg.embed_dim
        self.single_object = single_object

        log.info(f'Single object: {self.single_object}')

        self.pixel_encoder = PixelEncoder(model_cfg)
        self.pix_feat_proj = nn.Conv2d(self.ms_dims[0], self.pixel_dim, kernel_size=1)
        self.key_proj = KeyProjection(model_cfg)
        self.mask_encoder = MaskEncoder(model_cfg, single_object=single_object)
        self.mask_decoder = MaskDecoder(model_cfg)
        self.pixel_fuser = PixelFeatureFuser(model_cfg, single_object=single_object)
        self.object_transformer = QueryTransformer(model_cfg)
        self.object_summarizer = ObjectSummarizer(model_cfg)
        self.aux_computer = AuxComputer(cfg)
        self.temp_sparity = UncertPred(model_cfg)

        self.register_buffer("pixel_mean", torch.Tensor(model_cfg.pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(model_cfg.pixel_std).view(-1, 1, 1), False)

    def _get_others(self, masks: torch.Tensor) -> torch.Tensor:
        # for each object, return the sum of masks of all other objects
        if self.single_object:
            return None

        num_objects = masks.shape[1]
        if num_objects >= 1:
            others = (masks.sum(dim=1, keepdim=True) - masks).clamp(0, 1)
        else:
            others = torch.zeros_like(masks)
        return others

    def pred_uncertainty(self, last_pix_feat: torch.Tensor, cur_pix_feat: torch.Tensor,
                         last_mask: torch.Tensor, mem_val_diff: torch.Tensor):
        logits = self.temp_sparity(last_frame_feat=last_pix_feat,
                                   cur_frame_feat=cur_pix_feat,
                                   last_mask=last_mask,
                                   mem_val_diff=mem_val_diff)

        prob = torch.sigmoid(logits)
        mask = (prob > 0) + 0

        uncert_output = {"logits": logits,
                         "prob": prob,
                         "mask": mask}

        return uncert_output

    def encode_image(self, image: torch.Tensor, seq_length=None, last_feats=None) -> (Iterable[torch.Tensor], torch.Tensor):  # type: ignore
        image = (image - self.pixel_mean) / self.pixel_std
        ms_image_feat = self.pixel_encoder(image, seq_length)  # f16, f8, f4, f2, f1
        return ms_image_feat, self.pix_feat_proj(ms_image_feat[0])

    def encode_mask(
            self,
            image: torch.Tensor,
            ms_features: List[torch.Tensor],
            sensory: torch.Tensor,
            masks: torch.Tensor,
            *,
            deep_update: bool = True,
            chunk_size: int = -1,
            need_weights: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        image = (image - self.pixel_mean) / self.pixel_std
        others = self._get_others(masks)
        mask_value, new_sensory = self.mask_encoder(image,
                                                    ms_features,
                                                    sensory,
                                                    masks,
                                                    others,
                                                    deep_update=deep_update,
                                                    chunk_size=chunk_size)
        object_summaries, object_logits = self.object_summarizer(masks, mask_value, need_weights)
        return mask_value, new_sensory, object_summaries, object_logits

    def transform_key(self,
                      final_pix_feat: torch.Tensor,
                      *,
                      need_sk: bool = True,
                      need_ek: bool = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        key, shrinkage, selection = self.key_proj(final_pix_feat, need_s=need_sk, need_e=need_ek)
        return key, shrinkage, selection

    # Used in training only.
    # This step is replaced by MemoryManager in test time
    def read_memory(self, query_key: torch.Tensor, query_selection: torch.Tensor,
                    memory_key: torch.Tensor, memory_shrinkage: torch.Tensor,
                    msk_value: torch.Tensor, obj_memory: torch.Tensor, pix_feat: torch.Tensor,
                    sensory: torch.Tensor, last_mask: torch.Tensor,
                    selector: torch.Tensor, uncert_output=None, seg_pass=False,
                    last_pix_feat=None, last_pred_mask=None) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        query_key       : B * CK * H * W
        query_selection : B * CK * H * W
        memory_key      : B * CK * T * H * W
        memory_shrinkage: B * 1  * T * H * W
        msk_value       : B * num_objects * CV * T * H * W
        obj_memory      : B * num_objects * T * num_summaries * C
        pixel_feature   : B * C * H * W
        """
        batch_size, num_objects = msk_value.shape[:2]

        uncert_mask = uncert_output["mask"] if uncert_output is not None else None

        # read using visual attention
        with torch.amp.autocast("cuda", enabled=False):
            affinity = get_affinity(memory_key.float(), memory_shrinkage.float(), query_key.float(),
                                    query_selection.float(), uncert_mask=uncert_mask)

            msk_value = msk_value.flatten(start_dim=1, end_dim=2).float()

            # B * (num_objects*CV) * H * W
            pixel_readout = readout(affinity, msk_value, uncert_mask)
            pixel_readout = pixel_readout.view(batch_size, num_objects, self.value_dim,
                                               *pixel_readout.shape[-2:])

            uncert_output = self.pred_uncertainty(last_pix_feat, pix_feat, last_pred_mask,
                                                  pixel_readout[:, 0] - msk_value[:, :, -1])
            uncert_prob = uncert_output["prob"].unsqueeze(1)  # b n 1 h w
            pixel_readout = pixel_readout * uncert_prob + msk_value[:, :, -1].unsqueeze(1) * (1 - uncert_prob)

        pixel_readout = self.pixel_fusion(pix_feat, pixel_readout, sensory, last_mask)

        # read from query transformer
        mem_readout, aux_features = self.readout_query(pixel_readout, obj_memory, selector=selector, seg_pass=seg_pass)

        aux_output = {
            'sensory': sensory,
            'q_logits': aux_features['logits'] if aux_features else None,
            'attn_mask': aux_features['attn_mask'] if aux_features else None,
        }

        return mem_readout, aux_output, uncert_output

    def read_first_frame_memory(self, pixel_readout,
                                obj_memory: torch.Tensor, pix_feat: torch.Tensor,
                                sensory: torch.Tensor, last_mask: torch.Tensor,
                                selector: torch.Tensor, seg_pass=False) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        query_key       : B * CK * H * W
        query_selection : B * CK * H * W
        memory_key      : B * CK * T * H * W
        memory_shrinkage: B * 1  * T * H * W
        msk_value       : B * num_objects * CV * T * H * W
        obj_memory      : B * num_objects * T * num_summaries * C
        pixel_feature   : B * C * H * W
        """

        pixel_readout = self.pixel_fusion(pix_feat, pixel_readout, sensory, last_mask)

        # read from query transformer
        mem_readout, aux_features = self.readout_query(pixel_readout, obj_memory, selector=selector, seg_pass=seg_pass)

        aux_output = {
            'sensory': sensory,
            'q_logits': aux_features['logits'] if aux_features else None,
            'attn_mask': aux_features['attn_mask'] if aux_features else None,
        }

        return mem_readout, aux_output

    def pixel_fusion(self,
                     pix_feat: torch.Tensor,
                     pixel: torch.Tensor,
                     sensory: torch.Tensor,
                     last_mask: torch.Tensor,
                     *,
                     chunk_size: int = -1) -> torch.Tensor:
        last_mask = F.interpolate(last_mask, size=sensory.shape[-2:], mode='area')
        last_others = self._get_others(last_mask)
        fused = self.pixel_fuser(pix_feat,
                                 pixel,
                                 sensory,
                                 last_mask,
                                 last_others,
                                 chunk_size=chunk_size)
        return fused

    def readout_query(self,
                      pixel_readout,
                      obj_memory,
                      *,
                      selector=None,
                      need_weights=False,
                      seg_pass=False) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        return self.object_transformer(pixel_readout,
                                       obj_memory,
                                       selector=selector,
                                       need_weights=need_weights,
                                       seg_pass=seg_pass)

    def segment(self,
                ms_image_feat: List[torch.Tensor],
                memory_readout: torch.Tensor,
                sensory: torch.Tensor,
                *,
                selector: bool = None,
                chunk_size: int = -1,
                update_sensory: bool = True,
                seg_pass: bool = False,
                clamp_mat: bool = True,
                last_mask=None,
                sigmoid_residual=False,
                seg_mat=False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        multi_scale_features is from the key encoder for skip-connection
        memory_readout is from working/long-term memory
        sensory is the sensory memory
        last_mask is the mask from the last frame, supplementing sensory memory
        selector is 1 if an object exists, and 0 otherwise. We use it to filter padded objects
        during training.
        """
        # use the matting head for segmentation data
        if seg_mat:
            assert seg_pass
            seg_pass = False

        sensory, logits = self.mask_decoder(ms_image_feat,
                                            memory_readout,
                                            sensory,
                                            chunk_size=chunk_size,
                                            update_sensory=update_sensory,
                                            seg_pass=seg_pass,
                                            last_mask=last_mask,
                                            sigmoid_residual=sigmoid_residual)
        if seg_pass:
            prob = torch.sigmoid(logits)
            if selector is not None:
                prob = prob * selector

            # softmax over all objects
            logits = aggregate(prob, dim=1)
            prob = F.softmax(logits, dim=1)
        else:
            if clamp_mat:
                logits = logits.clamp(0.0, 1.0)
            logits = torch.cat([torch.prod(1 - logits, dim=1, keepdim=True), logits], 1)
            prob = logits

        return sensory, logits, prob

    def compute_aux(self, pix_feat: torch.Tensor, aux_inputs: Dict[str, torch.Tensor],
                    selector: torch.Tensor, seg_pass=False) -> Dict[str, torch.Tensor]:
        return self.aux_computer(pix_feat, aux_inputs, selector, seg_pass=seg_pass)

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    def load_weights(self, src_dict, init_as_zero_if_needed=False) -> None:
        if not self.single_object:
            # Map single-object weight to multi-object weight (4->5 out channels in conv1)
            for k in list(src_dict.keys()):
                if k == 'mask_encoder.conv1.weight':
                    if src_dict[k].shape[1] == 4:
                        log.info(f'Converting {k} from single object to multiple objects.')
                        pads = torch.zeros((64, 1, 7, 7), device=src_dict[k].device)
                        if not init_as_zero_if_needed:
                            nn.init.orthogonal_(pads)
                            log.info(f'Randomly initialized padding for {k}.')
                        else:
                            log.info(f'Zero-initialized padding for {k}.')
                        src_dict[k] = torch.cat([src_dict[k], pads], 1)
                elif k == 'pixel_fuser.sensory_compress.weight':
                    if src_dict[k].shape[1] == self.sensory_dim + 1:
                        log.info(f'Converting {k} from single object to multiple objects.')
                        pads = torch.zeros((self.value_dim, 1, 1, 1), device=src_dict[k].device)
                        if not init_as_zero_if_needed:
                            nn.init.orthogonal_(pads)
                            log.info(f'Randomly initialized padding for {k}.')
                        else:
                            log.info(f'Zero-initialized padding for {k}.')
                        src_dict[k] = torch.cat([src_dict[k], pads], 1)
        elif self.single_object:
            """
            If the model is multiple-object and we are training in single-object,
            we strip the last channel of conv1.
            This is not supposed to happen in standard training except when users are trying to
            finetune a trained model with single-object datasets.
            """
            if src_dict['mask_encoder.conv1.weight'].shape[1] == 5:
                log.warning('Converting mask_encoder.conv1.weight from multiple objects to single object. '
                            'This is not supposed to happen in standard training.')
                src_dict['mask_encoder.conv1.weight'] = src_dict['mask_encoder.conv1.weight'][:, :-1]
                src_dict['pixel_fuser.sensory_compress.weight'] = src_dict['pixel_fuser.sensory_compress.weight'][:, :-1]

        for k in src_dict:
            if k not in self.state_dict():
                log.info(f'Key {k} found in src_dict but not in self.state_dict()!!!')
        for k in self.state_dict():
            if k not in src_dict:
                log.info(f'Key {k} found in self.state_dict() but not in src_dict!!!')

        self.load_state_dict(src_dict, strict=False)

    @property
    def device(self) -> torch.device:
        return self.pixel_mean.device
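load_weights widens a single-object checkpoint for multi-object use by appending one input channel to mask_encoder.conv1 (4 -> 5 channels). A standalone illustration of that padding step, with hypothetical weight shapes:

import torch
import torch.nn as nn

w = torch.randn(64, 4, 7, 7)        # hypothetical conv1: 64 filters over RGB + 1 mask channel
pads = torch.zeros((64, 1, 7, 7))   # extra channel for the "other objects" mask
nn.init.orthogonal_(pads)           # or keep zeros, as with init_as_zero_if_needed
w_multi = torch.cat([w, pads], 1)   # 64 x 5 x 7 x 7, as in load_weights above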
hf_space/third_party/matanyone/model/modules.py
ADDED
@@ -0,0 +1,149 @@
from typing import List, Iterable
import torch
import torch.nn as nn
import torch.nn.functional as F

from matanyone.model.group_modules import MainToGroupDistributor, GroupResBlock, upsample_groups, GConv2d, downsample_groups


class UpsampleBlock(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, scale_factor: int = 2):
        super().__init__()
        self.out_conv = ResBlock(in_dim, out_dim)
        self.scale_factor = scale_factor

    def forward(self, in_g: torch.Tensor, skip_f: torch.Tensor) -> torch.Tensor:
        g = F.interpolate(in_g,
                          scale_factor=self.scale_factor,
                          mode='bilinear')
        g = self.out_conv(g)
        g = g + skip_f
        return g


class MaskUpsampleBlock(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, scale_factor: int = 2):
        super().__init__()
        self.distributor = MainToGroupDistributor(method='add')
        self.out_conv = GroupResBlock(in_dim, out_dim)
        self.scale_factor = scale_factor

    def forward(self, in_g: torch.Tensor, skip_f: torch.Tensor) -> torch.Tensor:
        g = upsample_groups(in_g, ratio=self.scale_factor)
        g = self.distributor(skip_f, g)
        g = self.out_conv(g)
        return g


class DecoderFeatureProcessor(nn.Module):
    def __init__(self, decoder_dims: List[int], out_dims: List[int]):
        super().__init__()
        self.transforms = nn.ModuleList([
            nn.Conv2d(d_dim, p_dim, kernel_size=1) for d_dim, p_dim in zip(decoder_dims, out_dims)
        ])

    def forward(self, multi_scale_features: Iterable[torch.Tensor]) -> List[torch.Tensor]:
        outputs = [func(x) for x, func in zip(multi_scale_features, self.transforms)]
        return outputs


# @torch.jit.script
def _recurrent_update(h: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    # h: batch_size * num_objects * hidden_dim * h * w
    # values: batch_size * num_objects * (hidden_dim*3) * h * w
    dim = values.shape[2] // 3
    forget_gate = torch.sigmoid(values[:, :, :dim])
    update_gate = torch.sigmoid(values[:, :, dim:dim * 2])
    new_value = torch.tanh(values[:, :, dim * 2:])
    new_h = forget_gate * h * (1 - update_gate) + update_gate * new_value
    return new_h


class SensoryUpdater_fullscale(nn.Module):
    # Used in the decoder, multi-scale feature + GRU
    def __init__(self, g_dims: List[int], mid_dim: int, sensory_dim: int):
        super().__init__()
        self.g16_conv = GConv2d(g_dims[0], mid_dim, kernel_size=1)
        self.g8_conv = GConv2d(g_dims[1], mid_dim, kernel_size=1)
        self.g4_conv = GConv2d(g_dims[2], mid_dim, kernel_size=1)
        self.g2_conv = GConv2d(g_dims[3], mid_dim, kernel_size=1)
        self.g1_conv = GConv2d(g_dims[4], mid_dim, kernel_size=1)

        self.transform = GConv2d(mid_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1)

        nn.init.xavier_normal_(self.transform.weight)

    def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \
            self.g4_conv(downsample_groups(g[2], ratio=1/4)) + \
            self.g2_conv(downsample_groups(g[3], ratio=1/8)) + \
            self.g1_conv(downsample_groups(g[4], ratio=1/16))

        with torch.amp.autocast("cuda", enabled=False):
            g = g.float()
            h = h.float()
            values = self.transform(torch.cat([g, h], dim=2))
            new_h = _recurrent_update(h, values)

        return new_h


class SensoryUpdater(nn.Module):
    # Used in the decoder, multi-scale feature + GRU
    def __init__(self, g_dims: List[int], mid_dim: int, sensory_dim: int):
        super().__init__()
        self.g16_conv = GConv2d(g_dims[0], mid_dim, kernel_size=1)
        self.g8_conv = GConv2d(g_dims[1], mid_dim, kernel_size=1)
        self.g4_conv = GConv2d(g_dims[2], mid_dim, kernel_size=1)

        self.transform = GConv2d(mid_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1)

        nn.init.xavier_normal_(self.transform.weight)

    def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \
            self.g4_conv(downsample_groups(g[2], ratio=1/4))

        with torch.amp.autocast("cuda", enabled=False):
            g = g.float()
            h = h.float()
            values = self.transform(torch.cat([g, h], dim=2))
            new_h = _recurrent_update(h, values)

        return new_h


class SensoryDeepUpdater(nn.Module):
    def __init__(self, f_dim: int, sensory_dim: int):
        super().__init__()
        self.transform = GConv2d(f_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1)

        nn.init.xavier_normal_(self.transform.weight)

    def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        with torch.amp.autocast("cuda", enabled=False):
            g = g.float()
            h = h.float()
            values = self.transform(torch.cat([g, h], dim=2))
            new_h = _recurrent_update(h, values)

        return new_h


class ResBlock(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()

        if in_dim == out_dim:
            self.downsample = nn.Identity()
        else:
            self.downsample = nn.Conv2d(in_dim, out_dim, kernel_size=1)

        self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1)

    def forward(self, g: torch.Tensor) -> torch.Tensor:
        out_g = self.conv1(F.relu(g))
        out_g = self.conv2(F.relu(out_g))

        g = self.downsample(g)

        return out_g + g
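_recurrent_update implements a GRU-style gate: the projected features are split channel-wise into sigmoid-activated forget and update gates f and u plus a tanh candidate v, giving new_h = f * h * (1 - u) + u * v. A minimal shape check, assuming the vendored package is importable as matanyone:

import torch
from matanyone.model.modules import _recurrent_update

h = torch.zeros(1, 2, 8, 4, 4)        # B*N*hidden*H*W
values = torch.randn(1, 2, 24, 4, 4)  # B*N*(hidden*3)*H*W: forget | update | candidate
new_h = _recurrent_update(h, values)
print(new_h.shape)                    # torch.Size([1, 2, 8, 4, 4])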
hf_space/third_party/matanyone/model/transformer/__init__.py
ADDED
File without changes
hf_space/third_party/matanyone/model/transformer/__pycache__/__init__.cpython-313.pyc
ADDED
Binary file (207 Bytes)
hf_space/third_party/matanyone/model/transformer/__pycache__/object_summarizer.cpython-313.pyc
ADDED
Binary file (5.11 kB)
hf_space/third_party/matanyone/model/transformer/__pycache__/object_transformer.cpython-313.pyc
ADDED
Binary file (12.1 kB)
hf_space/third_party/matanyone/model/transformer/__pycache__/positional_encoding.cpython-313.pyc
ADDED
Binary file (5.97 kB)
hf_space/third_party/matanyone/model/transformer/__pycache__/transformer_layers.cpython-313.pyc
ADDED
Binary file (8.94 kB)
hf_space/third_party/matanyone/model/transformer/object_summarizer.py
ADDED
@@ -0,0 +1,89 @@
from typing import Optional
from omegaconf import DictConfig

import torch
import torch.nn as nn
import torch.nn.functional as F
from matanyone.model.transformer.positional_encoding import PositionalEncoding


# @torch.jit.script
def _weighted_pooling(masks: torch.Tensor, value: torch.Tensor,
                      logits: torch.Tensor) -> (torch.Tensor, torch.Tensor):
    # value: B*num_objects*H*W*value_dim
    # logits: B*num_objects*H*W*num_summaries
    # masks: B*num_objects*H*W*num_summaries: 1 if allowed
    weights = logits.sigmoid() * masks
    # B*num_objects*num_summaries*value_dim
    sums = torch.einsum('bkhwq,bkhwc->bkqc', weights, value)
    # B*num_objects*H*W*num_summaries -> B*num_objects*num_summaries*1
    area = weights.flatten(start_dim=2, end_dim=3).sum(2).unsqueeze(-1)

    # B*num_objects*num_summaries*value_dim
    return sums, area


class ObjectSummarizer(nn.Module):
    def __init__(self, model_cfg: DictConfig):
        super().__init__()

        this_cfg = model_cfg.object_summarizer
        self.value_dim = model_cfg.value_dim
        self.embed_dim = this_cfg.embed_dim
        self.num_summaries = this_cfg.num_summaries
        self.add_pe = this_cfg.add_pe
        self.pixel_pe_scale = model_cfg.pixel_pe_scale
        self.pixel_pe_temperature = model_cfg.pixel_pe_temperature

        if self.add_pe:
            self.pos_enc = PositionalEncoding(self.embed_dim,
                                              scale=self.pixel_pe_scale,
                                              temperature=self.pixel_pe_temperature)

        self.input_proj = nn.Linear(self.value_dim, self.embed_dim)
        self.feature_pred = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dim, self.embed_dim),
        )
        self.weights_pred = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dim, self.num_summaries),
        )

    def forward(self,
                masks: torch.Tensor,
                value: torch.Tensor,
                need_weights: bool = False) -> (torch.Tensor, Optional[torch.Tensor]):
        # masks: B*num_objects*(H0)*(W0)
        # value: B*num_objects*value_dim*H*W
        # -> B*num_objects*H*W*value_dim
        h, w = value.shape[-2:]
        masks = F.interpolate(masks, size=(h, w), mode='area')
        masks = masks.unsqueeze(-1)
        inv_masks = 1 - masks
        repeated_masks = torch.cat([
            masks.expand(-1, -1, -1, -1, self.num_summaries // 2),
            inv_masks.expand(-1, -1, -1, -1, self.num_summaries // 2),
        ], dim=-1)

        value = value.permute(0, 1, 3, 4, 2)
        value = self.input_proj(value)
        if self.add_pe:
            pe = self.pos_enc(value)
            value = value + pe

        with torch.amp.autocast("cuda", enabled=False):
            value = value.float()
            feature = self.feature_pred(value)
            logits = self.weights_pred(value)
            sums, area = _weighted_pooling(repeated_masks, feature, logits)

        summaries = torch.cat([sums, area], dim=-1)

        if need_weights:
            return summaries, logits
        else:
            return summaries, None
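_weighted_pooling turns per-pixel features into a fixed number of summary vectors: the sigmoid logits, gated by the repeated masks, act as soft pooling weights, with half the num_summaries slots restricted to the object and half to its complement. A minimal shape check, assuming the vendored package is importable as matanyone:

import torch
from matanyone.model.transformer.object_summarizer import _weighted_pooling

B, N, H, W, C, Q = 1, 2, 16, 16, 32, 8
masks = torch.randint(0, 2, (B, N, H, W, Q)).float()
value = torch.randn(B, N, H, W, C)
logits = torch.randn(B, N, H, W, Q)

sums, area = _weighted_pooling(masks, value, logits)
print(sums.shape, area.shape)   # torch.Size([1, 2, 8, 32]) torch.Size([1, 2, 8, 1])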
hf_space/third_party/matanyone/model/transformer/object_transformer.py
ADDED
@@ -0,0 +1,206 @@
from typing import Dict, Optional
from omegaconf import DictConfig

import torch
import torch.nn as nn
from matanyone.model.group_modules import GConv2d
from matanyone.utils.tensor_utils import aggregate
from matanyone.model.transformer.positional_encoding import PositionalEncoding
from matanyone.model.transformer.transformer_layers import CrossAttention, SelfAttention, FFN, PixelFFN


class QueryTransformerBlock(nn.Module):
    def __init__(self, model_cfg: DictConfig):
        super().__init__()

        this_cfg = model_cfg.object_transformer
        self.embed_dim = this_cfg.embed_dim
        self.num_heads = this_cfg.num_heads
        self.num_queries = this_cfg.num_queries
        self.ff_dim = this_cfg.ff_dim

        self.read_from_pixel = CrossAttention(self.embed_dim,
                                              self.num_heads,
                                              add_pe_to_qkv=this_cfg.read_from_pixel.add_pe_to_qkv)
        self.self_attn = SelfAttention(self.embed_dim,
                                       self.num_heads,
                                       add_pe_to_qkv=this_cfg.query_self_attention.add_pe_to_qkv)
        self.ffn = FFN(self.embed_dim, self.ff_dim)
        self.read_from_query = CrossAttention(self.embed_dim,
                                              self.num_heads,
                                              add_pe_to_qkv=this_cfg.read_from_query.add_pe_to_qkv,
                                              norm=this_cfg.read_from_query.output_norm)
        self.pixel_ffn = PixelFFN(self.embed_dim)

    def forward(
            self,
            x: torch.Tensor,
            pixel: torch.Tensor,
            query_pe: torch.Tensor,
            pixel_pe: torch.Tensor,
            attn_mask: torch.Tensor,
            need_weights: bool = False) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
        # x: (bs*num_objects)*num_queries*embed_dim
        # pixel: bs*num_objects*C*H*W
        # query_pe: (bs*num_objects)*num_queries*embed_dim
        # pixel_pe: (bs*num_objects)*(H*W)*C
        # attn_mask: (bs*num_objects*num_heads)*num_queries*(H*W)

        # bs*num_objects*C*H*W -> (bs*num_objects)*(H*W)*C
        pixel_flat = pixel.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous()
        x, q_weights = self.read_from_pixel(x,
                                            pixel_flat,
                                            query_pe,
                                            pixel_pe,
                                            attn_mask=attn_mask,
                                            need_weights=need_weights)
        x = self.self_attn(x, query_pe)
        x = self.ffn(x)

        pixel_flat, p_weights = self.read_from_query(pixel_flat,
                                                     x,
                                                     pixel_pe,
                                                     query_pe,
                                                     need_weights=need_weights)
        pixel = self.pixel_ffn(pixel, pixel_flat)

        if need_weights:
            bs, num_objects, _, h, w = pixel.shape
            q_weights = q_weights.view(bs, num_objects, self.num_heads, self.num_queries, h, w)
            p_weights = p_weights.transpose(2, 3).view(bs, num_objects, self.num_heads,
                                                       self.num_queries, h, w)

        return x, pixel, q_weights, p_weights


class QueryTransformer(nn.Module):
    def __init__(self, model_cfg: DictConfig):
        super().__init__()

        this_cfg = model_cfg.object_transformer
        self.value_dim = model_cfg.value_dim
        self.embed_dim = this_cfg.embed_dim
        self.num_heads = this_cfg.num_heads
        self.num_queries = this_cfg.num_queries

        # query initialization and embedding
        self.query_init = nn.Embedding(self.num_queries, self.embed_dim)
        self.query_emb = nn.Embedding(self.num_queries, self.embed_dim)

        # projection from object summaries to query initialization and embedding
        self.summary_to_query_init = nn.Linear(self.embed_dim, self.embed_dim)
        self.summary_to_query_emb = nn.Linear(self.embed_dim, self.embed_dim)

        self.pixel_pe_scale = model_cfg.pixel_pe_scale
        self.pixel_pe_temperature = model_cfg.pixel_pe_temperature
        self.pixel_init_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1)
        self.pixel_emb_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1)
        self.spatial_pe = PositionalEncoding(self.embed_dim,
                                             scale=self.pixel_pe_scale,
                                             temperature=self.pixel_pe_temperature,
                                             channel_last=False,
                                             transpose_output=True)

        # transformer blocks
        self.num_blocks = this_cfg.num_blocks
        self.blocks = nn.ModuleList(
            QueryTransformerBlock(model_cfg) for _ in range(self.num_blocks))
        self.mask_pred = nn.ModuleList(
            nn.Sequential(nn.ReLU(), GConv2d(self.embed_dim, 1, kernel_size=1))
            for _ in range(self.num_blocks + 1))

        self.act = nn.ReLU(inplace=True)

    def forward(self,
                pixel: torch.Tensor,
                obj_summaries: torch.Tensor,
                selector: Optional[torch.Tensor] = None,
                need_weights: bool = False,
                seg_pass=False) -> (torch.Tensor, Dict[str, torch.Tensor]):
        # pixel: B*num_objects*embed_dim*H*W
        # obj_summaries: B*num_objects*T*num_queries*embed_dim
        T = obj_summaries.shape[2]
        bs, num_objects, _, H, W = pixel.shape

        # normalize object values
        # the last channel is the cumulative area of the object
        obj_summaries = obj_summaries.view(bs * num_objects, T, self.num_queries,
                                           self.embed_dim + 1)
        # sum over time
        # during inference, T=1 as we already did streaming average in memory_manager
        obj_sums = obj_summaries[:, :, :, :-1].sum(dim=1)
        obj_area = obj_summaries[:, :, :, -1:].sum(dim=1)
        obj_values = obj_sums / (obj_area + 1e-4)
        obj_init = self.summary_to_query_init(obj_values)
        obj_emb = self.summary_to_query_emb(obj_values)

        # positional embeddings for object queries
        query = self.query_init.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_init
        query_emb = self.query_emb.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_emb

        # positional embeddings for pixel features
        pixel_init = self.pixel_init_proj(pixel)
        pixel_emb = self.pixel_emb_proj(pixel)
        pixel_pe = self.spatial_pe(pixel.flatten(0, 1))
        pixel_emb = pixel_emb.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous()
        pixel_pe = pixel_pe.flatten(1, 2) + pixel_emb

        pixel = pixel_init

        # run the transformer
        aux_features = {'logits': []}

        # first aux output
        aux_logits = self.mask_pred[0](pixel).squeeze(2)
        attn_mask = self._get_aux_mask(aux_logits, selector, seg_pass=seg_pass)
        aux_features['logits'].append(aux_logits)
        for i in range(self.num_blocks):
            query, pixel, q_weights, p_weights = self.blocks[i](query,
                                                                pixel,
                                                                query_emb,
                                                                pixel_pe,
                                                                attn_mask,
                                                                need_weights=need_weights)

            if self.training or i <= self.num_blocks - 1 or need_weights:
                aux_logits = self.mask_pred[i + 1](pixel).squeeze(2)
                attn_mask = self._get_aux_mask(aux_logits, selector, seg_pass=seg_pass)
                aux_features['logits'].append(aux_logits)

        aux_features['q_weights'] = q_weights  # last layer only
        aux_features['p_weights'] = p_weights  # last layer only

        if self.training:
            # no need to save all heads
            aux_features['attn_mask'] = attn_mask.view(bs, num_objects, self.num_heads,
                                                       self.num_queries, H, W)[:, :, 0]

        return pixel, aux_features

    def _get_aux_mask(self, logits: torch.Tensor, selector: torch.Tensor, seg_pass=False) -> torch.Tensor:
        # logits: batch_size*num_objects*H*W
        # selector: batch_size*num_objects*1*1
        # returns a mask of shape (batch_size*num_objects*num_heads)*num_queries*(H*W)
        # where True means the attention is blocked

        if selector is None:
            prob = logits.sigmoid()
        else:
            prob = logits.sigmoid() * selector
        logits = aggregate(prob, dim=1)

        is_foreground = (logits[:, 1:] >= logits.max(dim=1, keepdim=True)[0])
        foreground_mask = is_foreground.bool().flatten(start_dim=2)
        inv_foreground_mask = ~foreground_mask
        inv_background_mask = foreground_mask

        aux_foreground_mask = inv_foreground_mask.unsqueeze(2).unsqueeze(2).repeat(
            1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2)
        aux_background_mask = inv_background_mask.unsqueeze(2).unsqueeze(2).repeat(
            1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2)

        aux_mask = torch.cat([aux_foreground_mask, aux_background_mask], dim=1)

        aux_mask[torch.where(aux_mask.sum(-1) == aux_mask.shape[-1])] = False

        return aux_mask
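_get_aux_mask follows the nn.MultiheadAttention convention that True marks a blocked position; its last line re-enables any query row that would otherwise be fully blocked, since an all-True row makes the attention softmax produce NaNs. A standalone sketch of that guard:

import torch

aux_mask = torch.tensor([[True, True, True],    # fully blocked row -> would NaN
                         [True, False, True]])

# unblock fully-blocked rows, mirroring the last line of _get_aux_mask
aux_mask[torch.where(aux_mask.sum(-1) == aux_mask.shape[-1])] = False
print(aux_mask)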
hf_space/third_party/matanyone/model/transformer/positional_encoding.py
ADDED
@@ -0,0 +1,108 @@
# Reference:
# https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/transformer_decoder/position_encoding.py
# https://github.com/tatp22/multidim-positional-encoding/blob/master/positional_encodings/torch_encodings.py

import math

import numpy as np
import torch
from torch import nn


def get_emb(sin_inp: torch.Tensor) -> torch.Tensor:
    """
    Gets a base embedding for one dimension with sin and cos intertwined
    """
    emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)
    return torch.flatten(emb, -2, -1)


class PositionalEncoding(nn.Module):
    def __init__(self,
                 dim: int,
                 scale: float = math.pi * 2,
                 temperature: float = 10000,
                 normalize: bool = True,
                 channel_last: bool = True,
                 transpose_output: bool = False):
        super().__init__()
        dim = int(np.ceil(dim / 4) * 2)
        self.dim = dim
        inv_freq = 1.0 / (temperature**(torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.normalize = normalize
        self.scale = scale
        self.eps = 1e-6
        self.channel_last = channel_last
        self.transpose_output = transpose_output

        self.cached_penc = None  # the cache is irrespective of the number of objects

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        :param tensor: A 4/5d tensor of size
            channel_last=True: (batch_size, h, w, c) or (batch_size, k, h, w, c)
            channel_last=False: (batch_size, c, h, w) or (batch_size, k, c, h, w)
        :return: positional encoding tensor that has the same shape as the input if the input is 4d;
                 if the input is 5d, the output is broadcastable along the k-dimension
        """
        if len(tensor.shape) != 4 and len(tensor.shape) != 5:
            raise RuntimeError(f'The input tensor has to be 4/5d, got {tensor.shape}!')

        if len(tensor.shape) == 5:
            # take a sample from the k dimension
            num_objects = tensor.shape[1]
            tensor = tensor[:, 0]
        else:
            num_objects = None

        if self.channel_last:
            batch_size, h, w, c = tensor.shape
        else:
            batch_size, c, h, w = tensor.shape

        if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
            if num_objects is None:
                return self.cached_penc
            else:
                return self.cached_penc.unsqueeze(1)

        self.cached_penc = None

        pos_y = torch.arange(h, device=tensor.device, dtype=self.inv_freq.dtype)
        pos_x = torch.arange(w, device=tensor.device, dtype=self.inv_freq.dtype)
        if self.normalize:
            pos_y = pos_y / (pos_y[-1] + self.eps) * self.scale
            pos_x = pos_x / (pos_x[-1] + self.eps) * self.scale

        sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
        sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
        emb_y = get_emb(sin_inp_y).unsqueeze(1)
        emb_x = get_emb(sin_inp_x)

        emb = torch.zeros((h, w, self.dim * 2), device=tensor.device, dtype=tensor.dtype)
        emb[:, :, :self.dim] = emb_x
        emb[:, :, self.dim:] = emb_y

        if not self.channel_last and self.transpose_output:
            # cancelled out
            pass
        elif (not self.channel_last) or (self.transpose_output):
            emb = emb.permute(2, 0, 1)

        self.cached_penc = emb.unsqueeze(0).repeat(batch_size, 1, 1, 1)
        if num_objects is None:
            return self.cached_penc
        else:
            return self.cached_penc.unsqueeze(1)


if __name__ == '__main__':
    pe = PositionalEncoding(8).cuda()
    input = torch.ones((1, 8, 8, 8)).cuda()
    output = pe(input)
    # print(output)
    print(output[0, :, 0, 0])
    print(output[0, :, 0, 5])
    print(output[0, 0, :, 0])
    print(output[0, 0, 0, :])
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Modified from PyTorch nn.Transformer

from typing import List, Callable

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from matanyone.model.channel_attn import CAResBlock


class SelfAttention(nn.Module):
    def __init__(self,
                 dim: int,
                 nhead: int,
                 dropout: float = 0.0,
                 batch_first: bool = True,
                 add_pe_to_qkv: List[bool] = [True, True, False]):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, nhead, dropout=dropout, batch_first=batch_first)
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)
        self.add_pe_to_qkv = add_pe_to_qkv

    def forward(self,
                x: torch.Tensor,
                pe: torch.Tensor,
                attn_mask: torch.Tensor = None,
                key_padding_mask: torch.Tensor = None) -> torch.Tensor:
        x = self.norm(x)
        if any(self.add_pe_to_qkv):
            x_with_pe = x + pe
            q = x_with_pe if self.add_pe_to_qkv[0] else x
            k = x_with_pe if self.add_pe_to_qkv[1] else x
            v = x_with_pe if self.add_pe_to_qkv[2] else x
        else:
            q = k = v = x

        r = x
        x = self.self_attn(q, k, v, attn_mask=attn_mask, key_padding_mask=key_padding_mask)[0]
        return r + self.dropout(x)


# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
class CrossAttention(nn.Module):
    def __init__(self,
                 dim: int,
                 nhead: int,
                 dropout: float = 0.0,
                 batch_first: bool = True,
                 add_pe_to_qkv: List[bool] = [True, True, False],
                 residual: bool = True,
                 norm: bool = True):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(dim,
                                                nhead,
                                                dropout=dropout,
                                                batch_first=batch_first)
        if norm:
            self.norm = nn.LayerNorm(dim)
        else:
            self.norm = nn.Identity()
        self.dropout = nn.Dropout(dropout)
        self.add_pe_to_qkv = add_pe_to_qkv
        self.residual = residual

    def forward(self,
                x: torch.Tensor,
                mem: torch.Tensor,
                x_pe: torch.Tensor,
                mem_pe: torch.Tensor,
                attn_mask: torch.Tensor = None,
                *,
                need_weights: bool = False) -> (torch.Tensor, torch.Tensor):
        x = self.norm(x)
        if self.add_pe_to_qkv[0]:
            q = x + x_pe
        else:
            q = x

        if any(self.add_pe_to_qkv[1:]):
            mem_with_pe = mem + mem_pe
            k = mem_with_pe if self.add_pe_to_qkv[1] else mem
            v = mem_with_pe if self.add_pe_to_qkv[2] else mem
        else:
            k = v = mem
        r = x
        x, weights = self.cross_attn(q,
                                     k,
                                     v,
                                     attn_mask=attn_mask,
                                     need_weights=need_weights,
                                     average_attn_weights=False)

        if self.residual:
            return r + self.dropout(x), weights
        else:
            return self.dropout(x), weights


class FFN(nn.Module):
    def __init__(self, dim_in: int, dim_ff: int, activation=F.relu):
        super().__init__()
        self.linear1 = nn.Linear(dim_in, dim_ff)
        self.linear2 = nn.Linear(dim_ff, dim_in)
        self.norm = nn.LayerNorm(dim_in)

        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        else:
            self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r = x
        x = self.norm(x)
        x = self.linear2(self.activation(self.linear1(x)))
        x = r + x
        return x


class PixelFFN(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        self.conv = CAResBlock(dim, dim)

    def forward(self, pixel: torch.Tensor, pixel_flat: torch.Tensor) -> torch.Tensor:
        # pixel: batch_size * num_objects * dim * H * W
        # pixel_flat: (batch_size*num_objects) * (H*W) * dim
        bs, num_objects, _, h, w = pixel.shape
        pixel_flat = pixel_flat.view(bs * num_objects, h, w, self.dim)
        pixel_flat = pixel_flat.permute(0, 3, 1, 2).contiguous()

        x = self.conv(pixel_flat)
        x = x.view(bs, num_objects, self.dim, h, w)
        return x


class OutputFFN(nn.Module):
    def __init__(self, dim_in: int, dim_out: int, activation=F.relu):
        super().__init__()
        self.linear1 = nn.Linear(dim_in, dim_out)
        self.linear2 = nn.Linear(dim_out, dim_out)

        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        else:
            self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear2(self.activation(self.linear1(x)))
        return x


def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
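For orientation, here is a short smoke-test sketch of how these layers compose: pre-norm self-attention with a residual, cross-attention against a memory bank (returning per-head weights because average_attn_weights=False), then an FFN. The dimensions, token counts, and wiring order are illustrative assumptions, not values taken from the repo's configs; run it in this file's namespace or after importing the classes above.

import torch

dim, nhead = 64, 4                   # illustrative sizes only
queries = torch.randn(2, 10, dim)    # (batch, query tokens, dim); batch_first layout
memory = torch.randn(2, 50, dim)     # (batch, memory tokens, dim)
q_pe = torch.randn(2, 10, dim)       # positional encodings, same shape as their inputs
m_pe = torch.randn(2, 50, dim)

self_attn = SelfAttention(dim, nhead)
cross_attn = CrossAttention(dim, nhead)
ffn = FFN(dim, dim * 4, activation='gelu')

x = self_attn(queries, q_pe)                                # pre-norm attention + residual
x, w = cross_attn(x, memory, q_pe, m_pe, need_weights=True)
x = ffn(x)
print(x.shape)  # torch.Size([2, 10, 64])
print(w.shape)  # torch.Size([2, 4, 10, 50]): per-head weights, heads not averaged

# PixelFFN round-trip: flat (B*N, H*W, dim) tokens back onto a (B, N, dim, H, W) grid,
# processed by the CAResBlock imported at the top of this file.
pix = torch.randn(2, 3, dim, 8, 8)                           # 2 frames, 3 objects
pix_flat = pix.flatten(0, 1).flatten(2, 3).transpose(1, 2)   # (6, 64, dim)
print(PixelFFN(dim)(pix, pix_flat).shape)  # torch.Size([2, 3, 64, 8, 8])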
hf_space/third_party/matanyone/model/utils/__init__.py
ADDED
File without changes

hf_space/third_party/matanyone/model/utils/__pycache__/__init__.cpython-313.pyc
ADDED
Binary file (201 Bytes)

hf_space/third_party/matanyone/model/utils/__pycache__/memory_utils.cpython-313.pyc
ADDED
Binary file (4.66 kB)