MogensR committed
Commit 8811fd7 (verified)
1 Parent(s): 456bf70

Upload 61 files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the remaining files.
Files changed (50)
  1. hf_space/third_party/matanyone/__init__.py +2 -2
  2. hf_space/third_party/matanyone/__pycache__/__init__.cpython-313.pyc +0 -0
  3. hf_space/third_party/matanyone/config/__init__.py +0 -0
  4. hf_space/third_party/matanyone/config/eval_matanyone_config.yaml +47 -0
  5. hf_space/third_party/matanyone/config/hydra/job_logging/custom-no-rank.yaml +22 -0
  6. hf_space/third_party/matanyone/config/hydra/job_logging/custom.yaml +22 -0
  7. hf_space/third_party/matanyone/config/model/base.yaml +58 -0
  8. hf_space/third_party/matanyone/inference/__init__.py +0 -0
  9. hf_space/third_party/matanyone/inference/__pycache__/__init__.cpython-313.pyc +0 -0
  10. hf_space/third_party/matanyone/inference/__pycache__/image_feature_store.cpython-313.pyc +0 -0
  11. hf_space/third_party/matanyone/inference/__pycache__/inference_core.cpython-313.pyc +0 -0
  12. hf_space/third_party/matanyone/inference/__pycache__/kv_memory_store.cpython-313.pyc +0 -0
  13. hf_space/third_party/matanyone/inference/__pycache__/memory_manager.cpython-313.pyc +0 -0
  14. hf_space/third_party/matanyone/inference/__pycache__/object_info.cpython-313.pyc +0 -0
  15. hf_space/third_party/matanyone/inference/__pycache__/object_manager.cpython-313.pyc +0 -0
  16. hf_space/third_party/matanyone/inference/image_feature_store.py +56 -0
  17. hf_space/third_party/matanyone/inference/inference_core.py +545 -0
  18. hf_space/third_party/matanyone/inference/kv_memory_store.py +348 -0
  19. hf_space/third_party/matanyone/inference/memory_manager.py +453 -0
  20. hf_space/third_party/matanyone/inference/object_info.py +24 -0
  21. hf_space/third_party/matanyone/inference/object_manager.py +149 -0
  22. hf_space/third_party/matanyone/inference/utils/__init__.py +0 -0
  23. hf_space/third_party/matanyone/inference/utils/args_utils.py +30 -0
  24. hf_space/third_party/matanyone/model/__init__.py +0 -0
  25. hf_space/third_party/matanyone/model/__pycache__/__init__.cpython-313.pyc +0 -0
  26. hf_space/third_party/matanyone/model/__pycache__/aux_modules.cpython-313.pyc +0 -0
  27. hf_space/third_party/matanyone/model/__pycache__/big_modules.cpython-313.pyc +0 -0
  28. hf_space/third_party/matanyone/model/__pycache__/channel_attn.cpython-313.pyc +0 -0
  29. hf_space/third_party/matanyone/model/__pycache__/group_modules.cpython-313.pyc +0 -0
  30. hf_space/third_party/matanyone/model/__pycache__/matanyone.cpython-313.pyc +0 -0
  31. hf_space/third_party/matanyone/model/__pycache__/modules.cpython-313.pyc +0 -0
  32. hf_space/third_party/matanyone/model/aux_modules.py +93 -0
  33. hf_space/third_party/matanyone/model/big_modules.py +365 -0
  34. hf_space/third_party/matanyone/model/channel_attn.py +39 -0
  35. hf_space/third_party/matanyone/model/group_modules.py +126 -0
  36. hf_space/third_party/matanyone/model/matanyone.py +333 -0
  37. hf_space/third_party/matanyone/model/modules.py +149 -0
  38. hf_space/third_party/matanyone/model/transformer/__init__.py +0 -0
  39. hf_space/third_party/matanyone/model/transformer/__pycache__/__init__.cpython-313.pyc +0 -0
  40. hf_space/third_party/matanyone/model/transformer/__pycache__/object_summarizer.cpython-313.pyc +0 -0
  41. hf_space/third_party/matanyone/model/transformer/__pycache__/object_transformer.cpython-313.pyc +0 -0
  42. hf_space/third_party/matanyone/model/transformer/__pycache__/positional_encoding.cpython-313.pyc +0 -0
  43. hf_space/third_party/matanyone/model/transformer/__pycache__/transformer_layers.cpython-313.pyc +0 -0
  44. hf_space/third_party/matanyone/model/transformer/object_summarizer.py +89 -0
  45. hf_space/third_party/matanyone/model/transformer/object_transformer.py +206 -0
  46. hf_space/third_party/matanyone/model/transformer/positional_encoding.py +108 -0
  47. hf_space/third_party/matanyone/model/transformer/transformer_layers.py +161 -0
  48. hf_space/third_party/matanyone/model/utils/__init__.py +0 -0
  49. hf_space/third_party/matanyone/model/utils/__pycache__/__init__.cpython-313.pyc +0 -0
  50. hf_space/third_party/matanyone/model/utils/__pycache__/memory_utils.cpython-313.pyc +0 -0
hf_space/third_party/matanyone/__init__.py CHANGED
@@ -1,2 +1,2 @@
- # Placeholder for vendored MatAnyone package. Replace with the real 'matanyone' package from https://github.com/pq-yang/MatAnyone.
- __all__ = []
+ from matanyone.inference.inference_core import InferenceCore
+ from matanyone.model.matanyone import MatAnyone
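
The package-level __init__.py now re-exports the two main entry points. As a quick illustration (a minimal sketch, not part of the diff; it assumes hf_space/third_party is on sys.path so the absolute matanyone imports resolve):

    import sys
    sys.path.insert(0, "hf_space/third_party")  # illustrative path setup

    # Both names are re-exported by the new __init__.py shown above.
    from matanyone import InferenceCore, MatAnyone
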
hf_space/third_party/matanyone/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (324 Bytes).
 
hf_space/third_party/matanyone/config/__init__.py ADDED
File without changes
hf_space/third_party/matanyone/config/eval_matanyone_config.yaml ADDED
@@ -0,0 +1,47 @@
+ defaults:
+   - _self_
+   - model: base
+   - override hydra/job_logging: custom-no-rank.yaml
+
+ hydra:
+   run:
+     dir: ../output/${exp_id}/${dataset}
+   output_subdir: ${now:%Y-%m-%d_%H-%M-%S}-hydra
+
+ amp: False
+ weights: pretrained_models/matanyone.pth # default (can be modified from outside)
+ output_dir: null # defaults to run_dir; specify this to override
+ flip_aug: False
+
+
+ # maximum shortest side of the input; -1 means no resizing
+ # With eval_vos.py, we usually just use the dataset's size (resizing done in dataloader)
+ # this parameter is added for the sole purpose of the GUI in the current codebase
+ # InferenceCore will downsize the input and restore the output to the original size if needed
+ # if you are using this code for some other project, you can also utilize this parameter
+ max_internal_size: -1
+
+ # these parameters, when set, override the dataset's default; useful for debugging
+ save_all: True
+ use_all_masks: False
+ use_long_term: False
+ mem_every: 5
+
+ # only relevant when long_term is not enabled
+ max_mem_frames: 5
+
+ # only relevant when long_term is enabled
+ long_term:
+   count_usage: True
+   max_mem_frames: 10
+   min_mem_frames: 5
+   num_prototypes: 128
+   max_num_tokens: 10000
+   buffer_tokens: 2000
+
+ top_k: 30
+ stagger_updates: 5
+ chunk_size: -1 # number of objects to process in parallel; -1 means unlimited
+ save_scores: False
+ save_aux: False
+ visualize: False
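
The defaults list above composes this file with model/base.yaml and swaps Hydra's job logging for the custom config. A minimal sketch of loading the composed config with Hydra's compose API (illustrative, not part of the diff; config_path is relative to the calling module, and the job_logging override only takes effect under @hydra.main, not under compose):

    from hydra import compose, initialize

    # Compose eval_matanyone_config.yaml together with model/base.yaml.
    with initialize(version_base=None, config_path="."):
        cfg = compose(config_name="eval_matanyone_config")

    print(cfg.mem_every)                 # 5
    print(cfg.long_term.max_mem_frames)  # 10
    print(cfg.model.pixel_encoder.type)  # resnet50
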
hf_space/third_party/matanyone/config/hydra/job_logging/custom-no-rank.yaml ADDED
@@ -0,0 +1,22 @@
+ # python logging configuration for tasks
+ version: 1
+ formatters:
+   simple:
+     format: '[%(asctime)s][%(levelname)s] - %(message)s'
+     datefmt: '%Y-%m-%d %H:%M:%S'
+ handlers:
+   console:
+     class: logging.StreamHandler
+     formatter: simple
+     stream: ext://sys.stdout
+   file:
+     class: logging.FileHandler
+     formatter: simple
+     # absolute file path
+     filename: ${hydra.runtime.output_dir}/${now:%Y-%m-%d_%H-%M-%S}-eval.log
+     mode: w
+ root:
+   level: INFO
+   handlers: [console, file]
+
+ disable_existing_loggers: false
hf_space/third_party/matanyone/config/hydra/job_logging/custom.yaml ADDED
@@ -0,0 +1,22 @@
+ # python logging configuration for tasks
+ version: 1
+ formatters:
+   simple:
+     format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s'
+     datefmt: '%Y-%m-%d %H:%M:%S'
+ handlers:
+   console:
+     class: logging.StreamHandler
+     formatter: simple
+     stream: ext://sys.stdout
+   file:
+     class: logging.FileHandler
+     formatter: simple
+     # absolute file path
+     filename: ${hydra.runtime.output_dir}/${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log
+     mode: w
+ root:
+   level: INFO
+   handlers: [console, file]
+
+ disable_existing_loggers: false
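
Both job_logging files follow the standard logging.config.dictConfig schema; Hydra applies them at job startup after resolving the ${hydra.runtime.output_dir} and ${now:...} interpolations. A standalone sketch of the same console layout (illustrative; the file handler and the Hydra-specific interpolations are omitted because they only resolve inside a Hydra run):

    import logging
    import logging.config

    logging.config.dictConfig({
        "version": 1,
        "formatters": {
            "simple": {
                "format": "[%(asctime)s][%(levelname)s] - %(message)s",
                "datefmt": "%Y-%m-%d %H:%M:%S",
            }
        },
        "handlers": {
            "console": {
                "class": "logging.StreamHandler",
                "formatter": "simple",
                "stream": "ext://sys.stdout",
            }
        },
        "root": {"level": "INFO", "handlers": ["console"]},
        "disable_existing_loggers": False,
    })
    logging.getLogger(__name__).info("hello")  # e.g. [2025-01-01 12:00:00][INFO] - hello
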
hf_space/third_party/matanyone/config/model/base.yaml ADDED
@@ -0,0 +1,58 @@
+ pixel_mean: [0.485, 0.456, 0.406]
+ pixel_std: [0.229, 0.224, 0.225]
+
+ pixel_dim: 256
+ key_dim: 64
+ value_dim: 256
+ sensory_dim: 256
+ embed_dim: 256
+
+ pixel_encoder:
+   type: resnet50
+   ms_dims: [1024, 512, 256, 64, 3] # f16, f8, f4, f2, f1
+
+ mask_encoder:
+   type: resnet18
+   final_dim: 256
+
+ pixel_pe_scale: 32
+ pixel_pe_temperature: 128
+
+ object_transformer:
+   embed_dim: ${model.embed_dim}
+   ff_dim: 2048
+   num_heads: 8
+   num_blocks: 3
+   num_queries: 16
+   read_from_pixel:
+     input_norm: False
+     input_add_pe: False
+     add_pe_to_qkv: [True, True, False]
+   read_from_past:
+     add_pe_to_qkv: [True, True, False]
+   read_from_memory:
+     add_pe_to_qkv: [True, True, False]
+   read_from_query:
+     add_pe_to_qkv: [True, True, False]
+     output_norm: False
+   query_self_attention:
+     add_pe_to_qkv: [True, True, False]
+   pixel_self_attention:
+     add_pe_to_qkv: [True, True, False]
+
+ object_summarizer:
+   embed_dim: ${model.object_transformer.embed_dim}
+   num_summaries: ${model.object_transformer.num_queries}
+   add_pe: True
+
+ aux_loss:
+   sensory:
+     enabled: True
+     weight: 0.01
+   query:
+     enabled: True
+     weight: 0.01
+
+ mask_decoder:
+   # first value must equal embed_dim
+   up_dims: [256, 128, 128, 64, 16]
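
The ${...} entries are OmegaConf interpolations; because this file is composed into the config under the model key (via the "model: base" default), the references are written as model.*. A small self-contained illustration of how such interpolations resolve (not part of the diff):

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        "model": {
            "embed_dim": 256,
            "object_transformer": {
                "embed_dim": "${model.embed_dim}",
                "num_queries": 16,
            },
            "object_summarizer": {
                "num_summaries": "${model.object_transformer.num_queries}",
            },
        }
    })
    assert cfg.model.object_transformer.embed_dim == 256   # resolved on access
    assert cfg.model.object_summarizer.num_summaries == 16
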
hf_space/third_party/matanyone/inference/__init__.py ADDED
File without changes
hf_space/third_party/matanyone/inference/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (199 Bytes).
 
hf_space/third_party/matanyone/inference/__pycache__/image_feature_store.cpython-313.pyc ADDED
Binary file (4.71 kB).
 
hf_space/third_party/matanyone/inference/__pycache__/inference_core.cpython-313.pyc ADDED
Binary file (26.7 kB).
 
hf_space/third_party/matanyone/inference/__pycache__/kv_memory_store.cpython-313.pyc ADDED
Binary file (18.8 kB).
 
hf_space/third_party/matanyone/inference/__pycache__/memory_manager.cpython-313.pyc ADDED
Binary file (22.5 kB).
 
hf_space/third_party/matanyone/inference/__pycache__/object_info.cpython-313.pyc ADDED
Binary file (1.68 kB).
 
hf_space/third_party/matanyone/inference/__pycache__/object_manager.cpython-313.pyc ADDED
Binary file (8.07 kB).
 
hf_space/third_party/matanyone/inference/image_feature_store.py ADDED
@@ -0,0 +1,56 @@
+ import warnings
+ from typing import Iterable
+ import torch
+ from matanyone.model.matanyone import MatAnyone
+
+
+ class ImageFeatureStore:
+     """
+     A cache for image features.
+     These features might be reused at different parts of the inference pipeline.
+     This class provides an interface for reusing these features.
+     It is the user's responsibility to delete redundant features.
+
+     Features of a frame should be associated with a unique index -- typically the frame id.
+     """
+     def __init__(self, network: MatAnyone, no_warning: bool = False):
+         self.network = network
+         self._store = {}
+         self.no_warning = no_warning
+
+     def _encode_feature(self, index: int, image: torch.Tensor, last_feats=None) -> None:
+         ms_features, pix_feat = self.network.encode_image(image, last_feats=last_feats)
+         key, shrinkage, selection = self.network.transform_key(ms_features[0])
+         self._store[index] = (ms_features, pix_feat, key, shrinkage, selection)
+
+     def get_all_features(self, images: torch.Tensor) -> (Iterable[torch.Tensor], torch.Tensor):
+         seq_length = images.shape[0]
+         ms_features, pix_feat = self.network.encode_image(images, seq_length)
+         key, shrinkage, selection = self.network.transform_key(ms_features[0])
+         for index in range(seq_length):
+             self._store[index] = ([f[index].unsqueeze(0) for f in ms_features], pix_feat[index].unsqueeze(0), key[index].unsqueeze(0), shrinkage[index].unsqueeze(0), selection[index].unsqueeze(0))
+
+     def get_features(self, index: int,
+                      image: torch.Tensor, last_feats=None) -> (Iterable[torch.Tensor], torch.Tensor):
+         if index not in self._store:
+             self._encode_feature(index, image, last_feats)
+
+         return self._store[index][:2]
+
+     def get_key(self, index: int,
+                 image: torch.Tensor, last_feats=None) -> (torch.Tensor, torch.Tensor, torch.Tensor):
+         if index not in self._store:
+             self._encode_feature(index, image, last_feats)
+
+         return self._store[index][2:]
+
+     def delete(self, index: int) -> None:
+         if index in self._store:
+             del self._store[index]
+
+     def __len__(self):
+         return len(self._store)
+
+     def __del__(self):
+         if len(self._store) > 0 and not self.no_warning:
+             warnings.warn(f'Leaking {self._store.keys()} in the image feature store')
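
A brief sketch of the store's contract (illustrative only; net and frame stand in for a loaded MatAnyone network and a preprocessed 1x3xHxW frame tensor):

    # Features are cached per frame index; callers delete entries they no longer need,
    # otherwise __del__ warns about leaked features.
    store = ImageFeatureStore(net)
    ms_feat, pix_feat = store.get_features(0, frame)     # encodes frame 0 and caches it
    key, shrinkage, selection = store.get_key(0, frame)  # served from the same cache entry
    store.delete(0)                                      # explicit cleanup
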
hf_space/third_party/matanyone/inference/inference_core.py ADDED
@@ -0,0 +1,545 @@
1
+ import logging
2
+ from omegaconf import DictConfig
3
+ from typing import List, Optional, Iterable, Union,Tuple
4
+
5
+ import os
6
+ import cv2
7
+ import torch
8
+ import imageio
9
+ import tempfile
10
+ import numpy as np
11
+ from tqdm import tqdm
12
+ from PIL import Image
13
+ import torch.nn.functional as F
14
+
15
+ from matanyone.inference.memory_manager import MemoryManager
16
+ from matanyone.inference.object_manager import ObjectManager
17
+ from matanyone.inference.image_feature_store import ImageFeatureStore
18
+ from matanyone.model.matanyone import MatAnyone
19
+ from matanyone.utils.tensor_utils import pad_divide_by, unpad, aggregate
20
+ from matanyone.utils.inference_utils import gen_dilate, gen_erosion, read_frame_from_videos
21
+
22
+ log = logging.getLogger()
23
+
24
+
25
+ class InferenceCore:
26
+
27
+ def __init__(self,
28
+ network: Union[MatAnyone,str],
29
+ cfg: DictConfig = None,
30
+ *,
31
+ image_feature_store: ImageFeatureStore = None,
32
+ device: Union[str, torch.device] = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
+ ):
34
+ if isinstance(network, str):
35
+ network = MatAnyone.from_pretrained(network)
36
+ network.to(device)
37
+ network.eval()
38
+ self.network = network
39
+ cfg = cfg if cfg is not None else network.cfg
40
+ self.cfg = cfg
41
+ self.mem_every = cfg.mem_every
42
+ stagger_updates = cfg.stagger_updates
43
+ self.chunk_size = cfg.chunk_size
44
+ self.save_aux = cfg.save_aux
45
+ self.max_internal_size = cfg.max_internal_size
46
+ self.flip_aug = cfg.flip_aug
47
+
48
+ self.curr_ti = -1
49
+ self.last_mem_ti = 0
50
+ # at which time indices should we update the sensory memory
51
+ if stagger_updates >= self.mem_every:
52
+ self.stagger_ti = set(range(1, self.mem_every + 1))
53
+ else:
54
+ self.stagger_ti = set(
55
+ np.round(np.linspace(1, self.mem_every, stagger_updates)).astype(int))
56
+ self.object_manager = ObjectManager()
57
+ self.memory = MemoryManager(cfg=cfg, object_manager=self.object_manager)
58
+
59
+ if image_feature_store is None:
60
+ self.image_feature_store = ImageFeatureStore(self.network)
61
+ else:
62
+ self.image_feature_store = image_feature_store
63
+
64
+ self.last_mask = None
65
+ self.last_pix_feat = None
66
+ self.last_msk_value = None
67
+
68
+ def clear_memory(self):
69
+ self.curr_ti = -1
70
+ self.last_mem_ti = 0
71
+ self.memory = MemoryManager(cfg=self.cfg, object_manager=self.object_manager)
72
+
73
+ def clear_non_permanent_memory(self):
74
+ self.curr_ti = -1
75
+ self.last_mem_ti = 0
76
+ self.memory.clear_non_permanent_memory()
77
+
78
+ def clear_sensory_memory(self):
79
+ self.curr_ti = -1
80
+ self.last_mem_ti = 0
81
+ self.memory.clear_sensory_memory()
82
+
83
+ def update_config(self, cfg):
84
+ self.mem_every = cfg['mem_every']
85
+ self.memory.update_config(cfg)
86
+
87
+ def clear_temp_mem(self):
88
+ self.memory.clear_work_mem()
89
+ # self.object_manager = ObjectManager()
90
+ self.memory.clear_obj_mem()
91
+ # self.memory.clear_sensory_memory()
92
+
93
+ def _add_memory(self,
94
+ image: torch.Tensor,
95
+ pix_feat: torch.Tensor,
96
+ prob: torch.Tensor,
97
+ key: torch.Tensor,
98
+ shrinkage: torch.Tensor,
99
+ selection: torch.Tensor,
100
+ *,
101
+ is_deep_update: bool = True,
102
+ force_permanent: bool = False) -> None:
103
+ """
104
+ Memorize the given segmentation in all memory stores.
105
+
106
+ The batch dimension is 1 if flip augmentation is not used.
107
+ image: RGB image, (1/2)*3*H*W
108
+ pix_feat: from the key encoder, (1/2)*_*H*W
109
+ prob: (1/2)*num_objects*H*W, in [0, 1]
110
+ key/shrinkage/selection: for anisotropic l2, (1/2)*_*H*W
111
+ selection can be None if not using long-term memory
112
+ is_deep_update: whether to use deep update (e.g. with the mask encoder)
113
+ force_permanent: whether to force the memory to be permanent
114
+ """
115
+ if prob.shape[1] == 0:
116
+ # nothing to add
117
+ log.warn('Trying to add an empty object mask to memory!')
118
+ return
119
+
120
+ if force_permanent:
121
+ as_permanent = 'all'
122
+ else:
123
+ as_permanent = 'first'
124
+
125
+ self.memory.initialize_sensory_if_needed(key, self.object_manager.all_obj_ids)
126
+ msk_value, sensory, obj_value, _ = self.network.encode_mask(
127
+ image,
128
+ pix_feat,
129
+ self.memory.get_sensory(self.object_manager.all_obj_ids),
130
+ prob,
131
+ deep_update=is_deep_update,
132
+ chunk_size=self.chunk_size,
133
+ need_weights=self.save_aux)
134
+ self.memory.add_memory(key,
135
+ shrinkage,
136
+ msk_value,
137
+ obj_value,
138
+ self.object_manager.all_obj_ids,
139
+ selection=selection,
140
+ as_permanent=as_permanent)
141
+ self.last_mem_ti = self.curr_ti
142
+ if is_deep_update:
143
+ self.memory.update_sensory(sensory, self.object_manager.all_obj_ids)
144
+ self.last_msk_value = msk_value
145
+
146
+ def _segment(self,
147
+ key: torch.Tensor,
148
+ selection: torch.Tensor,
149
+ pix_feat: torch.Tensor,
150
+ ms_features: Iterable[torch.Tensor],
151
+ update_sensory: bool = True) -> torch.Tensor:
152
+ """
153
+ Produce a segmentation using the given features and the memory
154
+
155
+ The batch dimension is 1 if flip augmentation is not used.
156
+ key/selection: for anisotropic l2: (1/2) * _ * H * W
157
+ pix_feat: from the key encoder, (1/2) * _ * H * W
158
+ ms_features: an iterable of multiscale features from the encoder, each is (1/2)*_*H*W
159
+ with strides 16, 8, and 4 respectively
160
+ update_sensory: whether to update the sensory memory
161
+
162
+ Returns: (num_objects+1)*H*W normalized probability; the first channel is the background
163
+ """
164
+ bs = key.shape[0]
165
+ if self.flip_aug:
166
+ assert bs == 2
167
+ else:
168
+ assert bs == 1
169
+
170
+ if not self.memory.engaged:
171
+ log.warn('Trying to segment without any memory!')
172
+ return torch.zeros((1, key.shape[-2] * 16, key.shape[-1] * 16),
173
+ device=key.device,
174
+ dtype=key.dtype)
175
+
176
+ uncert_output = None
177
+
178
+ if self.curr_ti == 0: # ONLY for the first frame for prediction
179
+ memory_readout = self.memory.read_first_frame(self.last_msk_value, pix_feat, self.last_mask, self.network, uncert_output=uncert_output)
180
+ else:
181
+ memory_readout = self.memory.read(pix_feat, key, selection, self.last_mask, self.network, uncert_output=uncert_output, last_msk_value=self.last_msk_value, ti=self.curr_ti,
182
+ last_pix_feat=self.last_pix_feat, last_pred_mask=self.last_mask)
183
+ memory_readout = self.object_manager.realize_dict(memory_readout)
184
+
185
+ sensory, _, pred_prob_with_bg = self.network.segment(ms_features,
186
+ memory_readout,
187
+ self.memory.get_sensory(
188
+ self.object_manager.all_obj_ids),
189
+ chunk_size=self.chunk_size,
190
+ update_sensory=update_sensory)
191
+ # remove batch dim
192
+ if self.flip_aug:
193
+ # average predictions of the non-flipped and flipped version
194
+ pred_prob_with_bg = (pred_prob_with_bg[0] +
195
+ torch.flip(pred_prob_with_bg[1], dims=[-1])) / 2
196
+ else:
197
+ pred_prob_with_bg = pred_prob_with_bg[0]
198
+ if update_sensory:
199
+ self.memory.update_sensory(sensory, self.object_manager.all_obj_ids)
200
+ return pred_prob_with_bg
201
+
202
+ def pred_all_flow(self, images):
203
+ self.total_len = images.shape[0]
204
+ images, self.pad = pad_divide_by(images, 16)
205
+ images = images.unsqueeze(0) # add the batch dimension: (1,t,c,h,w)
206
+
207
+ self.flows_forward, self.flows_backward = self.network.pred_forward_backward_flow(images)
208
+
209
+ def encode_all_images(self, images):
210
+ images, self.pad = pad_divide_by(images, 16)
211
+ self.image_feature_store.get_all_features(images) # t c h w
212
+ return images
213
+
214
+ def step(self,
215
+ image: torch.Tensor,
216
+ mask: Optional[torch.Tensor] = None,
217
+ objects: Optional[List[int]] = None,
218
+ *,
219
+ idx_mask: bool = False,
220
+ end: bool = False,
221
+ delete_buffer: bool = True,
222
+ force_permanent: bool = False,
223
+ matting: bool = True,
224
+ first_frame_pred: bool = False) -> torch.Tensor:
225
+ """
226
+ Take a step with a new incoming image.
227
+ If there is an incoming mask with new objects, we will memorize them.
228
+ If there is no incoming mask, we will segment the image using the memory.
229
+ In both cases, we will update the memory and return a segmentation.
230
+
231
+ image: 3*H*W
232
+ mask: H*W (if idx mask) or len(objects)*H*W or None
233
+ objects: list of object ids that are valid in the mask Tensor.
234
+ The ids themselves do not need to be consecutive/in order, but they need to be
235
+ in the same position in the list as the corresponding mask
236
+ in the tensor in non-idx-mask mode.
237
+ objects is ignored if the mask is None.
238
+ If idx_mask is False and objects is None, we sequentially infer the object ids.
239
+ idx_mask: if True, mask is expected to contain an object id at every pixel.
240
+ If False, mask should have multiple channels with each channel representing one object.
241
+ end: if we are at the end of the sequence, we do not need to update memory
242
+ if unsure just set it to False
243
+ delete_buffer: whether to delete the image feature buffer after this step
244
+ force_permanent: the memory recorded this frame will be added to the permanent memory
245
+ """
246
+ if objects is None and mask is not None:
247
+ assert not idx_mask
248
+ objects = list(range(1, mask.shape[0] + 1))
249
+
250
+ # resize input if needed -- currently only used for the GUI
251
+ resize_needed = False
252
+ if self.max_internal_size > 0:
253
+ h, w = image.shape[-2:]
254
+ min_side = min(h, w)
255
+ if min_side > self.max_internal_size:
256
+ resize_needed = True
257
+ new_h = int(h / min_side * self.max_internal_size)
258
+ new_w = int(w / min_side * self.max_internal_size)
259
+ image = F.interpolate(image.unsqueeze(0),
260
+ size=(new_h, new_w),
261
+ mode='bilinear',
262
+ align_corners=False)[0]
263
+ if mask is not None:
264
+ if idx_mask:
265
+ mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).float(),
266
+ size=(new_h, new_w),
267
+ mode='nearest-exact',
268
+ align_corners=False)[0, 0].round().long()
269
+ else:
270
+ mask = F.interpolate(mask.unsqueeze(0),
271
+ size=(new_h, new_w),
272
+ mode='bilinear',
273
+ align_corners=False)[0]
274
+
275
+ self.curr_ti += 1
276
+
277
+ image, self.pad = pad_divide_by(image, 16) # DONE already for 3DCNN!!
278
+ image = image.unsqueeze(0) # add the batch dimension
279
+ if self.flip_aug:
280
+ image = torch.cat([image, torch.flip(image, dims=[-1])], dim=0)
281
+
282
+ # whether to update the working memory
283
+ is_mem_frame = ((self.curr_ti - self.last_mem_ti >= self.mem_every) or
284
+ (mask is not None)) and (not end)
285
+ # segment when there is no input mask or when the input mask is incomplete
286
+ need_segment = (mask is None) or (self.object_manager.num_obj > 0
287
+ and not self.object_manager.has_all(objects))
288
+ update_sensory = ((self.curr_ti - self.last_mem_ti) in self.stagger_ti) and (not end)
289
+
290
+ # reinit if it is the first frame for prediction
291
+ if first_frame_pred:
292
+ self.curr_ti = 0
293
+ self.last_mem_ti = 0
294
+ is_mem_frame = True
295
+ need_segment = True
296
+ update_sensory = True
297
+
298
+ # encoding the image
299
+ ms_feat, pix_feat = self.image_feature_store.get_features(self.curr_ti, image)
300
+ key, shrinkage, selection = self.image_feature_store.get_key(self.curr_ti, image)
301
+
302
+ # segmentation from memory if needed
303
+ if need_segment:
304
+ pred_prob_with_bg = self._segment(key,
305
+ selection,
306
+ pix_feat,
307
+ ms_feat,
308
+ update_sensory=update_sensory)
309
+
310
+ # use the input mask if provided
311
+ if mask is not None:
312
+ # inform the manager of the new objects, and get a list of temporary id
313
+ # temporary ids -- indicates the position of objects in the tensor
314
+ # (starts with 1 due to the background channel)
315
+ corresponding_tmp_ids, _ = self.object_manager.add_new_objects(objects)
316
+
317
+ mask, _ = pad_divide_by(mask, 16)
318
+ if need_segment:
319
+ # merge predicted mask with the incomplete input mask
320
+ pred_prob_no_bg = pred_prob_with_bg[1:]
321
+ # use the mutual exclusivity of segmentation
322
+ if idx_mask:
323
+ pred_prob_no_bg[:, mask > 0] = 0
324
+ else:
325
+ pred_prob_no_bg[:, mask.max(0) > 0.5] = 0
326
+
327
+ new_masks = []
328
+ for mask_id, tmp_id in enumerate(corresponding_tmp_ids):
329
+ if idx_mask:
330
+ this_mask = (mask == objects[mask_id]).type_as(pred_prob_no_bg)
331
+ else:
332
+ this_mask = mask[tmp_id]
333
+ if tmp_id > pred_prob_no_bg.shape[0]:
334
+ new_masks.append(this_mask.unsqueeze(0))
335
+ else:
336
+ # +1 for padding the background channel
337
+ pred_prob_no_bg[tmp_id - 1] = this_mask
338
+ # new_masks are always in the order of tmp_id
339
+ mask = torch.cat([pred_prob_no_bg, *new_masks], dim=0)
340
+ elif idx_mask:
341
+ # simply convert cls to one-hot representation
342
+ if len(objects) == 0:
343
+ if delete_buffer:
344
+ self.image_feature_store.delete(self.curr_ti)
345
+ log.warn('Trying to insert an empty mask as memory!')
346
+ return torch.zeros((1, key.shape[-2] * 16, key.shape[-1] * 16),
347
+ device=key.device,
348
+ dtype=key.dtype)
349
+ mask = torch.stack(
350
+ [mask == objects[mask_id] for mask_id, _ in enumerate(corresponding_tmp_ids)],
351
+ dim=0)
352
+ if matting:
353
+ mask = mask.unsqueeze(0).float() / 255.
354
+ pred_prob_with_bg = torch.cat([1-mask, mask], 0)
355
+ else:
356
+ pred_prob_with_bg = aggregate(mask, dim=0)
357
+ pred_prob_with_bg = torch.softmax(pred_prob_with_bg, dim=0)
358
+
359
+ self.last_mask = pred_prob_with_bg[1:].unsqueeze(0)
360
+ if self.flip_aug:
361
+ self.last_mask = torch.cat(
362
+ [self.last_mask, torch.flip(self.last_mask, dims=[-1])], dim=0)
363
+ self.last_pix_feat = pix_feat
364
+
365
+ # save as memory if needed
366
+ if is_mem_frame or force_permanent:
367
+ # clear the memory for given mask and add the first predicted mask
368
+ if first_frame_pred:
369
+ self.clear_temp_mem()
370
+ self._add_memory(image,
371
+ pix_feat,
372
+ self.last_mask,
373
+ key,
374
+ shrinkage,
375
+ selection,
376
+ force_permanent=force_permanent,
377
+ is_deep_update=True)
378
+ else: # compute self.last_msk_value for non-memory frame
379
+ msk_value, _, _, _ = self.network.encode_mask(
380
+ image,
381
+ pix_feat,
382
+ self.memory.get_sensory(self.object_manager.all_obj_ids),
383
+ self.last_mask,
384
+ deep_update=False,
385
+ chunk_size=self.chunk_size,
386
+ need_weights=self.save_aux)
387
+ self.last_msk_value = msk_value
388
+
389
+ if delete_buffer:
390
+ self.image_feature_store.delete(self.curr_ti)
391
+
392
+ output_prob = unpad(pred_prob_with_bg, self.pad)
393
+ if resize_needed:
394
+ # restore output to the original size
395
+ output_prob = F.interpolate(output_prob.unsqueeze(0),
396
+ size=(h, w),
397
+ mode='bilinear',
398
+ align_corners=False)[0]
399
+
400
+ return output_prob
401
+
402
+ def delete_objects(self, objects: List[int]) -> None:
403
+ """
404
+ Delete the given objects from the memory.
405
+ """
406
+ self.object_manager.delete_objects(objects)
407
+ self.memory.purge_except(self.object_manager.all_obj_ids)
408
+
409
+ def output_prob_to_mask(self, output_prob: torch.Tensor, matting: bool = True) -> torch.Tensor:
410
+ if matting:
411
+ new_mask = output_prob[1:].squeeze(0)
412
+ else:
413
+ mask = torch.argmax(output_prob, dim=0)
414
+
415
+ # index in tensor != object id -- remap the ids here
416
+ new_mask = torch.zeros_like(mask)
417
+ for tmp_id, obj in self.object_manager.tmp_id_to_obj.items():
418
+ new_mask[mask == tmp_id] = obj.id
419
+
420
+ return new_mask
421
+
422
+ @torch.inference_mode()
423
+ @torch.amp.autocast("cuda")
424
+ def process_video(
425
+ self,
426
+ input_path: str,
427
+ mask_path: str,
428
+ output_path: str = None,
429
+ n_warmup: int = 10,
430
+ r_erode: int = 10,
431
+ r_dilate: int = 10,
432
+ suffix: str = "",
433
+ save_image: bool = False,
434
+ max_size: int = -1,
435
+ ) -> Tuple:
436
+ """
437
+ Process a video for object segmentation and matting.
438
+ This method processes a video file by performing object segmentation and matting on each frame.
439
+ It supports warmup frames, mask erosion/dilation, and various output options.
440
+ Args:
441
+ input_path (str): Path to the input video file
442
+ mask_path (str): Path to the mask image file used for initial segmentation
443
+ output_path (str, optional): Directory path where output files will be saved. Defaults to a temporary directory
444
+ n_warmup (int, optional): Number of warmup frames to use. Defaults to 10
445
+ r_erode (int, optional): Erosion radius for mask processing. Defaults to 10
446
+ r_dilate (int, optional): Dilation radius for mask processing. Defaults to 10
447
+ suffix (str, optional): Suffix to append to output filename. Defaults to ""
448
+ save_image (bool, optional): Whether to save individual frames. Defaults to False
449
+ max_size (int, optional): Maximum size for frame dimension. Use -1 for no limit. Defaults to -1
450
+ Returns:
451
+ Tuple[str, str]: A tuple containing:
452
+ - Path to the output foreground video file (str)
453
+ - Path to the output alpha matte video file (str)
454
+ Output:
455
+ - Saves processed video files with foreground (_fgr) and alpha matte (_pha)
456
+ - If save_image=True, saves individual frames in separate directories
457
+ """
458
+ output_path = output_path if output_path is not None else tempfile.TemporaryDirectory().name
459
+ r_erode = int(r_erode)
460
+ r_dilate = int(r_dilate)
461
+ n_warmup = int(n_warmup)
462
+ max_size = int(max_size)
463
+
464
+ vframes, fps, length, video_name = read_frame_from_videos(input_path)
465
+ repeated_frames = vframes[0].unsqueeze(0).repeat(n_warmup, 1, 1, 1)
466
+ vframes = torch.cat([repeated_frames, vframes], dim=0).float()
467
+ length += n_warmup
468
+
469
+ new_h, new_w = vframes.shape[-2:]
470
+ if max_size > 0:
471
+ h, w = new_h, new_w
472
+ min_side = min(h, w)
473
+ if min_side > max_size:
474
+ new_h = int(h / min_side * max_size)
475
+ new_w = int(w / min_side * max_size)
476
+ vframes = F.interpolate(vframes, size=(new_h, new_w), mode="area")
477
+
478
+ os.makedirs(output_path, exist_ok=True)
479
+ if suffix:
480
+ video_name = f"{video_name}_{suffix}"
481
+ if save_image:
482
+ os.makedirs(f"{output_path}/{video_name}", exist_ok=True)
483
+ os.makedirs(f"{output_path}/{video_name}/pha", exist_ok=True)
484
+ os.makedirs(f"{output_path}/{video_name}/fgr", exist_ok=True)
485
+
486
+ mask = np.array(Image.open(mask_path).convert("L"))
487
+ if r_dilate > 0:
488
+ mask = gen_dilate(mask, r_dilate, r_dilate)
489
+ if r_erode > 0:
490
+ mask = gen_erosion(mask, r_erode, r_erode)
491
+
492
+ mask = torch.from_numpy(mask).cuda()
493
+ if max_size > 0:
494
+ mask = F.interpolate(
495
+ mask.unsqueeze(0).unsqueeze(0), size=(new_h, new_w), mode="nearest"
496
+ )[0, 0]
497
+
498
+ bgr = (np.array([120, 255, 155], dtype=np.float32) / 255).reshape((1, 1, 3))
499
+ objects = [1]
500
+
501
+ phas = []
502
+ fgrs = []
503
+ for ti in tqdm(range(length)):
504
+ image = vframes[ti]
505
+ image_np = np.array(image.permute(1, 2, 0))
506
+ image = (image / 255.0).cuda().float()
507
+
508
+ if ti == 0:
509
+ output_prob = self.step(image, mask, objects=objects)
510
+ output_prob = self.step(image, first_frame_pred=True)
511
+ else:
512
+ if ti <= n_warmup:
513
+ output_prob = self.step(image, first_frame_pred=True)
514
+ else:
515
+ output_prob = self.step(image)
516
+
517
+ mask = self.output_prob_to_mask(output_prob)
518
+ pha = mask.unsqueeze(2).cpu().numpy()
519
+ com_np = image_np / 255.0 * pha + bgr * (1 - pha)
520
+
521
+ if ti > (n_warmup - 1):
522
+ com_np = (com_np * 255).astype(np.uint8)
523
+ pha = (pha * 255).astype(np.uint8)
524
+ fgrs.append(com_np)
525
+ phas.append(pha)
526
+ if save_image:
527
+ cv2.imwrite(
528
+ f"{output_path}/{video_name}/pha/{str(ti - n_warmup).zfill(5)}.png",
529
+ pha,
530
+ )
531
+ cv2.imwrite(
532
+ f"{output_path}/{video_name}/fgr/{str(ti - n_warmup).zfill(5)}.png",
533
+ com_np[..., [2, 1, 0]],
534
+ )
535
+
536
+ fgrs = np.array(fgrs)
537
+ phas = np.array(phas)
538
+
539
+ fgr_filename = f"{output_path}/{video_name}_fgr.mp4"
540
+ alpha_filename = f"{output_path}/{video_name}_pha.mp4"
541
+
542
+ imageio.mimwrite(fgr_filename, fgrs, fps=fps, quality=7)
543
+ imageio.mimwrite(alpha_filename, phas, fps=fps, quality=7)
544
+
545
+ return (fgr_filename,alpha_filename)
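
The high-level entry point added here is InferenceCore.process_video, which takes a video plus a first-frame mask and writes the _fgr (composited foreground) and _pha (alpha matte) videos. A minimal driving sketch (illustrative, not part of the diff; the checkpoint identifier and paths are placeholders, and a CUDA device is assumed because the method moves the mask to the GPU):

    from matanyone.inference.inference_core import InferenceCore

    # Passing a string makes the constructor call MatAnyone.from_pretrained(...) internally.
    processor = InferenceCore("<matanyone-checkpoint-or-hub-id>")

    fgr_path, pha_path = processor.process_video(
        input_path="inputs/video.mp4",            # placeholder path
        mask_path="inputs/first_frame_mask.png",  # placeholder path
        output_path="outputs",
        n_warmup=10,
        r_erode=10,
        r_dilate=10,
    )
    print(fgr_path, pha_path)
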
hf_space/third_party/matanyone/inference/kv_memory_store.py ADDED
@@ -0,0 +1,348 @@
1
+ from typing import Dict, List, Optional, Literal
2
+ from collections import defaultdict
3
+ import torch
4
+
5
+
6
+ def _add_last_dim(dictionary, key, new_value, prepend=False):
7
+ # append/prepend a new value to the last dimension of a tensor in a dictionary
8
+ # if the key does not exist, put the new value in
9
+ # append by default
10
+ if key in dictionary:
11
+ dictionary[key] = torch.cat([dictionary[key], new_value], -1)
12
+ else:
13
+ dictionary[key] = new_value
14
+
15
+
16
+ class KeyValueMemoryStore:
17
+ """
18
+ Works for key/value pairs type storage
19
+ e.g., working and long-term memory
20
+ """
21
+ def __init__(self, save_selection: bool = False, save_usage: bool = False):
22
+ """
23
+ We store keys and values of objects that first appear in the same frame in a bucket.
24
+ Each bucket contains a set of object ids.
25
+ Each bucket is associated with a single key tensor
26
+ and a dictionary of value tensors indexed by object id.
27
+
28
+ The keys and values are stored as the concatenation of a permanent part and a temporary part.
29
+ """
30
+ self.save_selection = save_selection
31
+ self.save_usage = save_usage
32
+
33
+ self.global_bucket_id = 0 # does not reduce even if buckets are removed
34
+ self.buckets: Dict[int, List[int]] = {} # indexed by bucket id
35
+ self.k: Dict[int, torch.Tensor] = {} # indexed by bucket id
36
+ self.v: Dict[int, torch.Tensor] = {} # indexed by object id
37
+
38
+ # indexed by bucket id; the end point of permanent memory
39
+ self.perm_end_pt: Dict[int, int] = defaultdict(int)
40
+
41
+ # shrinkage and selection are just like the keys
42
+ self.s = {}
43
+ if self.save_selection:
44
+ self.e = {} # does not contain the permanent memory part
45
+
46
+ # usage
47
+ if self.save_usage:
48
+ self.use_cnt = {} # indexed by bucket id, does not contain the permanent memory part
49
+ self.life_cnt = {} # indexed by bucket id, does not contain the permanent memory part
50
+
51
+ def add(self,
52
+ key: torch.Tensor,
53
+ values: Dict[int, torch.Tensor],
54
+ shrinkage: torch.Tensor,
55
+ selection: torch.Tensor,
56
+ supposed_bucket_id: int = -1,
57
+ as_permanent: Literal['no', 'first', 'all'] = 'no') -> None:
58
+ """
59
+ key: (1/2)*C*N
60
+ values: dict of values ((1/2)*C*N), object ids are used as keys
61
+ shrinkage: (1/2)*1*N
62
+ selection: (1/2)*C*N
63
+
64
+ supposed_bucket_id: used to sync the bucket id between working and long-term memory
65
+ if provided, the input should all be in a single bucket indexed by this id
66
+ as_permanent: whether to store the input as permanent memory
67
+ 'no': don't
68
+ 'first': only store it as permanent memory if the bucket is empty
69
+ 'all': always store it as permanent memory
70
+ """
71
+ bs = key.shape[0]
72
+ ne = key.shape[-1]
73
+ assert len(key.shape) == 3
74
+ assert len(shrinkage.shape) == 3
75
+ assert not self.save_selection or len(selection.shape) == 3
76
+ assert as_permanent in ['no', 'first', 'all']
77
+
78
+ # add the value and create new buckets if necessary
79
+ if supposed_bucket_id >= 0:
80
+ enabled_buckets = [supposed_bucket_id]
81
+ bucket_exist = supposed_bucket_id in self.buckets
82
+ for obj, value in values.items():
83
+ if bucket_exist:
84
+ assert obj in self.v
85
+ assert obj in self.buckets[supposed_bucket_id]
86
+ _add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all'))
87
+ else:
88
+ assert obj not in self.v
89
+ self.v[obj] = value
90
+ self.buckets[supposed_bucket_id] = list(values.keys())
91
+ else:
92
+ new_bucket_id = None
93
+ enabled_buckets = set()
94
+ for obj, value in values.items():
95
+ assert len(value.shape) == 3
96
+ if obj in self.v:
97
+ _add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all'))
98
+ bucket_used = [
99
+ bucket_id for bucket_id, object_ids in self.buckets.items()
100
+ if obj in object_ids
101
+ ]
102
+ assert len(bucket_used) == 1 # each object should only be in one bucket
103
+ enabled_buckets.add(bucket_used[0])
104
+ else:
105
+ self.v[obj] = value
106
+ if new_bucket_id is None:
107
+ # create new bucket
108
+ new_bucket_id = self.global_bucket_id
109
+ self.global_bucket_id += 1
110
+ self.buckets[new_bucket_id] = []
111
+ # put the new object into the corresponding bucket
112
+ self.buckets[new_bucket_id].append(obj)
113
+ enabled_buckets.add(new_bucket_id)
114
+
115
+ # increment the permanent size if necessary
116
+ add_as_permanent = {} # indexed by bucket id
117
+ for bucket_id in enabled_buckets:
118
+ add_as_permanent[bucket_id] = False
119
+ if as_permanent == 'all':
120
+ self.perm_end_pt[bucket_id] += ne
121
+ add_as_permanent[bucket_id] = True
122
+ elif as_permanent == 'first':
123
+ if self.perm_end_pt[bucket_id] == 0:
124
+ self.perm_end_pt[bucket_id] = ne
125
+ add_as_permanent[bucket_id] = True
126
+
127
+ # create new counters for usage if necessary
128
+ if self.save_usage and as_permanent != 'all':
129
+ new_count = torch.zeros((bs, ne), device=key.device, dtype=torch.float32)
130
+ new_life = torch.zeros((bs, ne), device=key.device, dtype=torch.float32) + 1e-7
131
+
132
+ # add the key to every bucket
133
+ for bucket_id in self.buckets:
134
+ if bucket_id not in enabled_buckets:
135
+ # if we are not adding new values to a bucket, we should skip it
136
+ continue
137
+
138
+ _add_last_dim(self.k, bucket_id, key, prepend=add_as_permanent[bucket_id])
139
+ _add_last_dim(self.s, bucket_id, shrinkage, prepend=add_as_permanent[bucket_id])
140
+ if not add_as_permanent[bucket_id]:
141
+ if self.save_selection:
142
+ _add_last_dim(self.e, bucket_id, selection)
143
+ if self.save_usage:
144
+ _add_last_dim(self.use_cnt, bucket_id, new_count)
145
+ _add_last_dim(self.life_cnt, bucket_id, new_life)
146
+
147
+ def update_bucket_usage(self, bucket_id: int, usage: torch.Tensor) -> None:
148
+ # increase all life count by 1
149
+ # increase use of indexed elements
150
+ if not self.save_usage:
151
+ return
152
+
153
+ usage = usage[:, self.perm_end_pt[bucket_id]:]
154
+ if usage.shape[-1] == 0:
155
+ # if there is no temporary memory, we don't need to update
156
+ return
157
+ self.use_cnt[bucket_id] += usage.view_as(self.use_cnt[bucket_id])
158
+ self.life_cnt[bucket_id] += 1
159
+
160
+ def sieve_by_range(self, bucket_id: int, start: int, end: int, min_size: int) -> None:
161
+ # keep only the temporary elements *outside* of this range (with some boundary conditions)
162
+ # the permanent elements are ignored in this computation
163
+ # i.e., concat (a[:start], a[end:])
164
+ # bucket with size <= min_size are not modified
165
+
166
+ assert start >= 0
167
+ assert end <= 0
168
+
169
+ object_ids = self.buckets[bucket_id]
170
+ bucket_num_elements = self.k[bucket_id].shape[-1] - self.perm_end_pt[bucket_id]
171
+ if bucket_num_elements <= min_size:
172
+ return
173
+
174
+ if end == 0:
175
+ # negative 0 would not work as the end index!
176
+ # effectively make the second part an empty slice
177
+ end = self.k[bucket_id].shape[-1] + 1
178
+
179
+ p_size = self.perm_end_pt[bucket_id]
180
+ start = start + p_size
181
+
182
+ k = self.k[bucket_id]
183
+ s = self.s[bucket_id]
184
+ if self.save_selection:
185
+ e = self.e[bucket_id]
186
+ if self.save_usage:
187
+ use_cnt = self.use_cnt[bucket_id]
188
+ life_cnt = self.life_cnt[bucket_id]
189
+
190
+ self.k[bucket_id] = torch.cat([k[:, :, :start], k[:, :, end:]], -1)
191
+ self.s[bucket_id] = torch.cat([s[:, :, :start], s[:, :, end:]], -1)
192
+ if self.save_selection:
193
+ self.e[bucket_id] = torch.cat([e[:, :, :start - p_size], e[:, :, end:]], -1)
194
+ if self.save_usage:
195
+ self.use_cnt[bucket_id] = torch.cat([use_cnt[:, :start - p_size], use_cnt[:, end:]], -1)
196
+ self.life_cnt[bucket_id] = torch.cat([life_cnt[:, :start - p_size], life_cnt[:, end:]],
197
+ -1)
198
+ for obj_id in object_ids:
199
+ v = self.v[obj_id]
200
+ self.v[obj_id] = torch.cat([v[:, :, :start], v[:, :, end:]], -1)
201
+
202
+ def remove_old_memory(self, bucket_id: int, max_len: int) -> None:
203
+ self.sieve_by_range(bucket_id, 0, -max_len, max_len)
204
+
205
+ def remove_obsolete_features(self, bucket_id: int, max_size: int) -> None:
206
+ # for long-term memory only
207
+ object_ids = self.buckets[bucket_id]
208
+
209
+ assert self.perm_end_pt[bucket_id] == 0 # permanent memory should be empty in LT memory
210
+
211
+ # normalize with life duration
212
+ usage = self.get_usage(bucket_id)
213
+ bs = usage.shape[0]
214
+
215
+ survivals = []
216
+
217
+ for bi in range(bs):
218
+ _, survived = torch.topk(usage[bi], k=max_size)
219
+ survivals.append(survived.flatten())
220
+ assert survived.shape[-1] == survivals[0].shape[-1]
221
+
222
+ self.k[bucket_id] = torch.stack(
223
+ [self.k[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
224
+ self.s[bucket_id] = torch.stack(
225
+ [self.s[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
226
+
227
+ if self.save_selection:
228
+ # Long-term memory does not store selection so this should not be needed
229
+ self.e[bucket_id] = torch.stack(
230
+ [self.e[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
231
+ for obj_id in object_ids:
232
+ self.v[obj_id] = torch.stack(
233
+ [self.v[obj_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
234
+
235
+ self.use_cnt[bucket_id] = torch.stack(
236
+ [self.use_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0)
237
+ self.life_cnt[bucket_id] = torch.stack(
238
+ [self.life_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0)
239
+
240
+ def get_usage(self, bucket_id: int) -> torch.Tensor:
241
+ # return normalized usage
242
+ if not self.save_usage:
243
+ raise RuntimeError('I did not count usage!')
244
+ else:
245
+ usage = self.use_cnt[bucket_id] / self.life_cnt[bucket_id]
246
+ return usage
247
+
248
+ def get_all_sliced(
249
+ self, bucket_id: int, start: int, end: int
250
+ ) -> (torch.Tensor, torch.Tensor, torch.Tensor, Dict[int, torch.Tensor], torch.Tensor):
251
+ # return k, sk, ek, value, normalized usage in order, sliced by start and end
252
+ # this only queries the temporary memory
253
+
254
+ assert start >= 0
255
+ assert end <= 0
256
+
257
+ p_size = self.perm_end_pt[bucket_id]
258
+ start = start + p_size
259
+
260
+ if end == 0:
261
+ # negative 0 would not work as the end index!
262
+ k = self.k[bucket_id][:, :, start:]
263
+ sk = self.s[bucket_id][:, :, start:]
264
+ ek = self.e[bucket_id][:, :, start - p_size:] if self.save_selection else None
265
+ value = {obj_id: self.v[obj_id][:, :, start:] for obj_id in self.buckets[bucket_id]}
266
+ usage = self.get_usage(bucket_id)[:, start - p_size:] if self.save_usage else None
267
+ else:
268
+ k = self.k[bucket_id][:, :, start:end]
269
+ sk = self.s[bucket_id][:, :, start:end]
270
+ ek = self.e[bucket_id][:, :, start - p_size:end] if self.save_selection else None
271
+ value = {obj_id: self.v[obj_id][:, :, start:end] for obj_id in self.buckets[bucket_id]}
272
+ usage = self.get_usage(bucket_id)[:, start - p_size:end] if self.save_usage else None
273
+
274
+ return k, sk, ek, value, usage
275
+
276
+ def purge_except(self, obj_keep_idx: List[int]):
277
+ # purge certain objects from the memory except the one listed
278
+ obj_keep_idx = set(obj_keep_idx)
279
+
280
+ # remove objects that are not in the keep list from the buckets
281
+ buckets_to_remove = []
282
+ for bucket_id, object_ids in self.buckets.items():
283
+ self.buckets[bucket_id] = [obj_id for obj_id in object_ids if obj_id in obj_keep_idx]
284
+ if len(self.buckets[bucket_id]) == 0:
285
+ buckets_to_remove.append(bucket_id)
286
+
287
+ # remove object values that are not in the keep list
288
+ self.v = {k: v for k, v in self.v.items() if k in obj_keep_idx}
289
+
290
+ # remove buckets that are empty
291
+ for bucket_id in buckets_to_remove:
292
+ del self.buckets[bucket_id]
293
+ del self.k[bucket_id]
294
+ del self.s[bucket_id]
295
+ if self.save_selection:
296
+ del self.e[bucket_id]
297
+ if self.save_usage:
298
+ del self.use_cnt[bucket_id]
299
+ del self.life_cnt[bucket_id]
300
+
301
+ def clear_non_permanent_memory(self):
302
+ # clear all non-permanent memory
303
+ for bucket_id in self.buckets:
304
+ self.sieve_by_range(bucket_id, 0, 0, 0)
305
+
306
+ def get_v_size(self, obj_id: int) -> int:
307
+ return self.v[obj_id].shape[-1]
308
+
309
+ def size(self, bucket_id: int) -> int:
310
+ if bucket_id not in self.k:
311
+ return 0
312
+ else:
313
+ return self.k[bucket_id].shape[-1]
314
+
315
+ def perm_size(self, bucket_id: int) -> int:
316
+ return self.perm_end_pt[bucket_id]
317
+
318
+ def non_perm_size(self, bucket_id: int) -> int:
319
+ return self.size(bucket_id) - self.perm_size(bucket_id)
320
+
321
+ def engaged(self, bucket_id: Optional[int] = None) -> bool:
322
+ if bucket_id is None:
323
+ return len(self.buckets) > 0
324
+ else:
325
+ return bucket_id in self.buckets
326
+
327
+ @property
328
+ def num_objects(self) -> int:
329
+ return len(self.v)
330
+
331
+ @property
332
+ def key(self) -> Dict[int, torch.Tensor]:
333
+ return self.k
334
+
335
+ @property
336
+ def value(self) -> Dict[int, torch.Tensor]:
337
+ return self.v
338
+
339
+ @property
340
+ def shrinkage(self) -> Dict[int, torch.Tensor]:
341
+ return self.s
342
+
343
+ @property
344
+ def selection(self) -> Dict[int, torch.Tensor]:
345
+ return self.e
346
+
347
+ def __contains__(self, key):
348
+ return key in self.v
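
To make the bucket bookkeeping above concrete, here is a tiny exercise of the store with dummy shapes (illustrative only, not part of the diff):

    import torch

    store = KeyValueMemoryStore(save_selection=False, save_usage=False)
    key = torch.randn(1, 64, 100)        # (batch, C^k, N) memory key
    shrinkage = torch.randn(1, 1, 100)
    values = {1: torch.randn(1, 256, 100),   # per-object values, keyed by object id
              2: torch.randn(1, 256, 100)}

    # First add into an empty bucket with as_permanent='first' becomes permanent memory.
    store.add(key, values, shrinkage, selection=None, as_permanent='first')
    assert store.size(0) == 100 and store.perm_size(0) == 100

    # Subsequent adds default to temporary memory, appended after the permanent part.
    store.add(key, values, shrinkage, selection=None)
    assert store.non_perm_size(0) == 100
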
hf_space/third_party/matanyone/inference/memory_manager.py ADDED
@@ -0,0 +1,453 @@
1
+ import logging
2
+ from omegaconf import DictConfig
3
+ from typing import List, Dict
4
+ import torch
5
+
6
+ from matanyone.inference.object_manager import ObjectManager
7
+ from matanyone.inference.kv_memory_store import KeyValueMemoryStore
8
+ from matanyone.model.matanyone import MatAnyone
9
+ from matanyone.model.utils.memory_utils import get_similarity, do_softmax
10
+
11
+ log = logging.getLogger()
12
+
13
+
14
+ class MemoryManager:
15
+ """
16
+ Manages all three memory stores and the transition between working/long-term memory
17
+ """
18
+ def __init__(self, cfg: DictConfig, object_manager: ObjectManager):
19
+ self.object_manager = object_manager
20
+ self.sensory_dim = cfg.model.sensory_dim
21
+ self.top_k = cfg.top_k
22
+ self.chunk_size = cfg.chunk_size
23
+
24
+ self.save_aux = cfg.save_aux
25
+
26
+ self.use_long_term = cfg.use_long_term
27
+ self.count_long_term_usage = cfg.long_term.count_usage
28
+ # subtract 1 because the first-frame is now counted as "permanent memory"
29
+ # and is not counted towards max_mem_frames
30
+ # but we want to keep the hyperparameters consistent as before for the same behavior
31
+ if self.use_long_term:
32
+ self.max_mem_frames = cfg.long_term.max_mem_frames - 1
33
+ self.min_mem_frames = cfg.long_term.min_mem_frames - 1
34
+ self.num_prototypes = cfg.long_term.num_prototypes
35
+ self.max_long_tokens = cfg.long_term.max_num_tokens
36
+ self.buffer_tokens = cfg.long_term.buffer_tokens
37
+ else:
38
+ self.max_mem_frames = cfg.max_mem_frames - 1
39
+
40
+ # dimensions will be inferred from input later
41
+ self.CK = self.CV = None
42
+ self.H = self.W = None
43
+
44
+ # The sensory memory is stored as a dictionary indexed by object ids
45
+ # each of shape bs * C^h * H * W
46
+ self.sensory = {}
47
+
48
+ # a dictionary indexed by object ids, each of shape bs * T * Q * C
49
+ self.obj_v = {}
50
+
51
+ self.work_mem = KeyValueMemoryStore(save_selection=self.use_long_term,
52
+ save_usage=self.use_long_term)
53
+ if self.use_long_term:
54
+ self.long_mem = KeyValueMemoryStore(save_usage=self.count_long_term_usage)
55
+
56
+ self.config_stale = True
57
+ self.engaged = False
58
+
59
+ def update_config(self, cfg: DictConfig) -> None:
60
+ self.config_stale = True
61
+ self.top_k = cfg['top_k']
62
+
63
+ assert self.use_long_term == cfg.use_long_term, 'cannot update this'
64
+ assert self.count_long_term_usage == cfg.long_term.count_usage, 'cannot update this'
65
+
66
+ self.use_long_term = cfg.use_long_term
67
+ self.count_long_term_usage = cfg.long_term.count_usage
68
+ if self.use_long_term:
69
+ self.max_mem_frames = cfg.long_term.max_mem_frames - 1
70
+ self.min_mem_frames = cfg.long_term.min_mem_frames - 1
71
+ self.num_prototypes = cfg.long_term.num_prototypes
72
+ self.max_long_tokens = cfg.long_term.max_num_tokens
73
+ self.buffer_tokens = cfg.long_term.buffer_tokens
74
+ else:
75
+ self.max_mem_frames = cfg.max_mem_frames - 1
76
+
77
+ def _readout(self, affinity, v, uncert_mask=None) -> torch.Tensor:
78
+ # affinity: bs*N*HW
79
+ # v: bs*C*N or bs*num_objects*C*N
80
+ # returns bs*C*HW or bs*num_objects*C*HW
81
+ if len(v.shape) == 3:
82
+ # single object
83
+ if uncert_mask is not None:
84
+ return v @ affinity * uncert_mask
85
+ else:
86
+ return v @ affinity
87
+ else:
88
+ bs, num_objects, C, N = v.shape
89
+ v = v.view(bs, num_objects * C, N)
90
+ out = v @ affinity
91
+ if uncert_mask is not None:
92
+ uncert_mask = uncert_mask.flatten(start_dim=2).expand(-1, C, -1)
93
+ out = out * uncert_mask
94
+ return out.view(bs, num_objects, C, -1)
95
+
96
+ def _get_mask_by_ids(self, mask: torch.Tensor, obj_ids: List[int]) -> torch.Tensor:
97
+ # -1 because the mask does not contain the background channel
98
+ return mask[:, [self.object_manager.find_tmp_by_id(obj) - 1 for obj in obj_ids]]
99
+
100
+ def _get_sensory_by_ids(self, obj_ids: List[int]) -> torch.Tensor:
101
+ return torch.stack([self.sensory[obj] for obj in obj_ids], dim=1)
102
+
103
+ def _get_object_mem_by_ids(self, obj_ids: List[int]) -> torch.Tensor:
104
+ return torch.stack([self.obj_v[obj] for obj in obj_ids], dim=1)
105
+
106
+ def _get_visual_values_by_ids(self, obj_ids: List[int]) -> torch.Tensor:
107
+ # All the values that the object ids refer to should have the same shape
108
+ value = torch.stack([self.work_mem.value[obj] for obj in obj_ids], dim=1)
109
+ if self.use_long_term and obj_ids[0] in self.long_mem.value:
110
+ lt_value = torch.stack([self.long_mem.value[obj] for obj in obj_ids], dim=1)
111
+ value = torch.cat([lt_value, value], dim=-1)
112
+
113
+ return value
114
+
115
+ def read_first_frame(self, last_msk_value, pix_feat: torch.Tensor,
116
+ last_mask: torch.Tensor, network: MatAnyone, uncert_output=None) -> Dict[int, torch.Tensor]:
117
+ """
118
+ Read from all memory stores and returns a single memory readout tensor for each object
119
+
120
+ pix_feat: (1/2) x C x H x W
121
+ query_key: (1/2) x C^k x H x W
122
+ selection: (1/2) x C^k x H x W
123
+ last_mask: (1/2) x num_objects x H x W (at stride 16)
124
+ return a dict of memory readouts, indexed by object indices. Each readout is C*H*W
125
+ """
126
+ h, w = pix_feat.shape[-2:]
127
+ bs = pix_feat.shape[0]
128
+ assert last_mask.shape[0] == bs
129
+
130
+ """
131
+ Compute affinity and perform readout
132
+ """
133
+ all_readout_mem = {}
134
+ buckets = self.work_mem.buckets
135
+ for bucket_id, bucket in buckets.items():
136
+
137
+ if self.chunk_size < 1:
138
+ object_chunks = [bucket]
139
+ else:
140
+ object_chunks = [
141
+ bucket[i:i + self.chunk_size] for i in range(0, len(bucket), self.chunk_size)
142
+ ]
143
+
144
+ for objects in object_chunks:
145
+ this_sensory = self._get_sensory_by_ids(objects)
146
+ this_last_mask = self._get_mask_by_ids(last_mask, objects)
147
+ this_msk_value = self._get_visual_values_by_ids(objects) # (1/2)*num_objects*C*N
148
+ pixel_readout = network.pixel_fusion(pix_feat, last_msk_value, this_sensory,
149
+ this_last_mask)
150
+ this_obj_mem = self._get_object_mem_by_ids(objects).unsqueeze(2)
151
+ readout_memory, aux_features = network.readout_query(pixel_readout, this_obj_mem)
152
+ for i, obj in enumerate(objects):
153
+ all_readout_mem[obj] = readout_memory[:, i]
154
+
155
+ if self.save_aux:
156
+ aux_output = {
157
+ # 'sensory': this_sensory,
158
+ # 'pixel_readout': pixel_readout,
159
+ 'q_logits': aux_features['logits'] if aux_features else None,
160
+ # 'q_weights': aux_features['q_weights'] if aux_features else None,
161
+ # 'p_weights': aux_features['p_weights'] if aux_features else None,
162
+ # 'attn_mask': aux_features['attn_mask'].float() if aux_features else None,
163
+ }
164
+ self.aux = aux_output
165
+
166
+ return all_readout_mem
167
+
168
+ def read(self, pix_feat: torch.Tensor, query_key: torch.Tensor, selection: torch.Tensor,
169
+ last_mask: torch.Tensor, network: MatAnyone, uncert_output=None, last_msk_value=None, ti=None,
170
+ last_pix_feat=None, last_pred_mask=None) -> Dict[int, torch.Tensor]:
171
+ """
172
+ Read from all memory stores and return a single memory readout tensor for each object
173
+
174
+ pix_feat: (1/2) x C x H x W
175
+ query_key: (1/2) x C^k x H x W
176
+ selection: (1/2) x C^k x H x W
177
+ last_mask: (1/2) x num_objects x H x W (at stride 16)
178
+ return a dict of memory readouts, indexed by object indices. Each readout is C*H*W
179
+ """
180
+ h, w = pix_feat.shape[-2:]
181
+ bs = pix_feat.shape[0]
182
+ assert query_key.shape[0] == bs
183
+ assert selection.shape[0] == bs
184
+ assert last_mask.shape[0] == bs
185
+
186
+ uncert_mask = uncert_output["mask"] if uncert_output is not None else None
187
+
188
+ query_key = query_key.flatten(start_dim=2) # bs*C^k*HW
189
+ selection = selection.flatten(start_dim=2) # bs*C^k*HW
190
+ """
191
+ Compute affinity and perform readout
192
+ """
193
+ all_readout_mem = {}
194
+ buckets = self.work_mem.buckets
195
+ for bucket_id, bucket in buckets.items():
196
+ if self.use_long_term and self.long_mem.engaged(bucket_id):
197
+ # Use long-term memory
198
+ long_mem_size = self.long_mem.size(bucket_id)
199
+ memory_key = torch.cat([self.long_mem.key[bucket_id], self.work_mem.key[bucket_id]],
200
+ -1)
201
+ shrinkage = torch.cat(
202
+ [self.long_mem.shrinkage[bucket_id], self.work_mem.shrinkage[bucket_id]], -1)
203
+
204
+ similarity = get_similarity(memory_key, shrinkage, query_key, selection)
205
+ affinity, usage = do_softmax(similarity,
206
+ top_k=self.top_k,
207
+ inplace=True,
208
+ return_usage=True)
209
+ """
210
+ Record memory usage for working and long-term memory
211
+ """
212
+ # ignore the index return for long-term memory
213
+ work_usage = usage[:, long_mem_size:]
214
+ self.work_mem.update_bucket_usage(bucket_id, work_usage)
215
+
216
+ if self.count_long_term_usage:
217
+ # ignore the index return for working memory
218
+ long_usage = usage[:, :long_mem_size]
219
+ self.long_mem.update_bucket_usage(bucket_id, long_usage)
220
+ else:
221
+ # no long-term memory
222
+ memory_key = self.work_mem.key[bucket_id]
223
+ shrinkage = self.work_mem.shrinkage[bucket_id]
224
+ similarity = get_similarity(memory_key, shrinkage, query_key, selection, uncert_mask=uncert_mask)
225
+
226
+ if self.use_long_term:
227
+ affinity, usage = do_softmax(similarity,
228
+ top_k=self.top_k,
229
+ inplace=True,
230
+ return_usage=True)
231
+ self.work_mem.update_bucket_usage(bucket_id, usage)
232
+ else:
233
+ affinity = do_softmax(similarity, top_k=self.top_k, inplace=True)
234
+
235
+ if self.chunk_size < 1:
236
+ object_chunks = [bucket]
237
+ else:
238
+ object_chunks = [
239
+ bucket[i:i + self.chunk_size] for i in range(0, len(bucket), self.chunk_size)
240
+ ]
241
+
242
+ for objects in object_chunks:
243
+ this_sensory = self._get_sensory_by_ids(objects)
244
+ this_last_mask = self._get_mask_by_ids(last_mask, objects)
245
+ this_msk_value = self._get_visual_values_by_ids(objects) # (1/2)*num_objects*C*N
246
+ visual_readout = self._readout(affinity,
247
+ this_msk_value, uncert_mask).view(bs, len(objects), self.CV, h, w)
248
+
249
+ uncert_output = network.pred_uncertainty(last_pix_feat, pix_feat, last_pred_mask, visual_readout[:,0]-last_msk_value[:,0])
250
+
251
+ if uncert_output is not None:
252
+ uncert_prob = uncert_output["prob"].unsqueeze(1) # b n 1 h w
253
+ visual_readout = visual_readout*uncert_prob + last_msk_value*(1-uncert_prob)
254
+
255
+ pixel_readout = network.pixel_fusion(pix_feat, visual_readout, this_sensory,
256
+ this_last_mask)
257
+ this_obj_mem = self._get_object_mem_by_ids(objects).unsqueeze(2)
258
+ readout_memory, aux_features = network.readout_query(pixel_readout, this_obj_mem)
259
+ for i, obj in enumerate(objects):
260
+ all_readout_mem[obj] = readout_memory[:, i]
261
+
262
+ if self.save_aux:
263
+ aux_output = {
264
+ # 'sensory': this_sensory,
265
+ # 'pixel_readout': pixel_readout,
266
+ 'q_logits': aux_features['logits'] if aux_features else None,
267
+ # 'q_weights': aux_features['q_weights'] if aux_features else None,
268
+ # 'p_weights': aux_features['p_weights'] if aux_features else None,
269
+ # 'attn_mask': aux_features['attn_mask'].float() if aux_features else None,
270
+ }
271
+ self.aux = aux_output
272
+
273
+ return all_readout_mem
274
+
275
+ def add_memory(self,
276
+ key: torch.Tensor,
277
+ shrinkage: torch.Tensor,
278
+ msk_value: torch.Tensor,
279
+ obj_value: torch.Tensor,
280
+ objects: List[int],
281
+ selection: torch.Tensor = None,
282
+ *,
283
+ as_permanent: bool = False) -> None:
284
+ # key: (1/2)*C*H*W
285
+ # msk_value: (1/2)*num_objects*C*H*W
286
+ # obj_value: (1/2)*num_objects*Q*C
287
+ # objects contains a list of object ids corresponding to the objects in msk_value/obj_value
288
+ bs = key.shape[0]
289
+ assert shrinkage.shape[0] == bs
290
+ assert msk_value.shape[0] == bs
291
+ assert obj_value.shape[0] == bs
292
+
293
+ self.engaged = True
294
+ if self.H is None or self.config_stale:
295
+ self.config_stale = False
296
+ self.H, self.W = msk_value.shape[-2:]
297
+ self.HW = self.H * self.W
298
+ # convert from num. frames to num. tokens
299
+ self.max_work_tokens = self.max_mem_frames * self.HW
300
+ if self.use_long_term:
301
+ self.min_work_tokens = self.min_mem_frames * self.HW
302
+
303
+ # key: bs*C*N
304
+ # value: bs*num_objects*C*N
305
+ key = key.flatten(start_dim=2)
306
+ shrinkage = shrinkage.flatten(start_dim=2)
307
+ self.CK = key.shape[1]
308
+
309
+ msk_value = msk_value.flatten(start_dim=3)
310
+ self.CV = msk_value.shape[2]
311
+
312
+ if selection is not None:
313
+ # not used in non-long-term mode
314
+ selection = selection.flatten(start_dim=2)
315
+
316
+ # insert object values into object memory
317
+ for obj_id, obj in enumerate(objects):
318
+ if obj in self.obj_v:
319
+ """streaming average
320
+ each self.obj_v[obj] is (1/2)*num_summaries*(embed_dim+1)
321
+ first embed_dim keeps track of the sum of embeddings
322
+ the last dim keeps the total count
323
+ averaging is done inside the object transformer
324
+
325
+ incoming obj_value is (1/2)*num_objects*num_summaries*(embed_dim+1)
326
+ self.obj_v[obj] = torch.cat([self.obj_v[obj], obj_value[:, obj_id]], dim=0)
327
+ """
328
+ last_acc = self.obj_v[obj][:, :, -1]
329
+ new_acc = last_acc + obj_value[:, obj_id, :, -1]
330
+
331
+ self.obj_v[obj][:, :, :-1] = (self.obj_v[obj][:, :, :-1] +
332
+ obj_value[:, obj_id, :, :-1])
333
+ self.obj_v[obj][:, :, -1] = new_acc
334
+ else:
335
+ self.obj_v[obj] = obj_value[:, obj_id]
336
+
337
+ # convert mask value tensor into a dict for insertion
338
+ msk_values = {obj: msk_value[:, obj_id] for obj_id, obj in enumerate(objects)}
339
+ self.work_mem.add(key,
340
+ msk_values,
341
+ shrinkage,
342
+ selection=selection,
343
+ as_permanent=as_permanent)
344
+
345
+ for bucket_id in self.work_mem.buckets.keys():
346
+ # long-term memory cleanup
347
+ if self.use_long_term:
348
+ # Compress the working memory if needed
349
+ if self.work_mem.non_perm_size(bucket_id) >= self.max_work_tokens:
350
+ # Remove obsolete features if needed
351
+ if self.long_mem.non_perm_size(bucket_id) >= (self.max_long_tokens -
352
+ self.num_prototypes):
353
+ self.long_mem.remove_obsolete_features(
354
+ bucket_id,
355
+ self.max_long_tokens - self.num_prototypes - self.buffer_tokens)
356
+
357
+ self.compress_features(bucket_id)
358
+ else:
359
+ # FIFO
360
+ self.work_mem.remove_old_memory(bucket_id, self.max_work_tokens)
361
+
362
+ def purge_except(self, obj_keep_idx: List[int]) -> None:
363
+ # purge certain objects from the memory except the ones listed
364
+ self.work_mem.purge_except(obj_keep_idx)
365
+ if self.use_long_term and self.long_mem.engaged():
366
+ self.long_mem.purge_except(obj_keep_idx)
367
+ self.sensory = {k: v for k, v in self.sensory.items() if k in obj_keep_idx}
368
+
369
+ if not self.work_mem.engaged():
370
+ # everything is removed!
371
+ self.engaged = False
372
+
373
+ def compress_features(self, bucket_id: int) -> None:
374
+
375
+ # perform memory consolidation
376
+ prototype_key, prototype_value, prototype_shrinkage = self.consolidation(
377
+ *self.work_mem.get_all_sliced(bucket_id, 0, -self.min_work_tokens))
378
+
379
+ # remove consolidated working memory
380
+ self.work_mem.sieve_by_range(bucket_id,
381
+ 0,
382
+ -self.min_work_tokens,
383
+ min_size=self.min_work_tokens)
384
+
385
+ # add to long-term memory
386
+ self.long_mem.add(prototype_key,
387
+ prototype_value,
388
+ prototype_shrinkage,
389
+ selection=None,
390
+ supposed_bucket_id=bucket_id)
391
+
392
+ def consolidation(self, candidate_key: torch.Tensor, candidate_shrinkage: torch.Tensor,
393
+ candidate_selection: torch.Tensor, candidate_value: Dict[int, torch.Tensor],
394
+ usage: torch.Tensor) -> (torch.Tensor, Dict[int, torch.Tensor], torch.Tensor):
395
+ # find the indices with max usage
396
+ bs = candidate_key.shape[0]
397
+ assert bs in [1, 2]
398
+
399
+ prototype_key = []
400
+ prototype_selection = []
401
+ for bi in range(bs):
402
+ _, max_usage_indices = torch.topk(usage[bi], k=self.num_prototypes, dim=-1, sorted=True)
403
+ prototype_indices = max_usage_indices.flatten()
404
+ prototype_key.append(candidate_key[bi, :, prototype_indices])
405
+ prototype_selection.append(candidate_selection[bi, :, prototype_indices])
406
+ prototype_key = torch.stack(prototype_key, dim=0)
407
+ prototype_selection = torch.stack(prototype_selection, dim=0)
408
+ """
409
+ Potentiation step
410
+ """
411
+ similarity = get_similarity(candidate_key, candidate_shrinkage, prototype_key,
412
+ prototype_selection)
413
+ affinity = do_softmax(similarity)
414
+
415
+ # readout the values
416
+ prototype_value = {k: self._readout(affinity, v) for k, v in candidate_value.items()}
417
+
418
+ # readout the shrinkage term
419
+ prototype_shrinkage = self._readout(affinity, candidate_shrinkage)
420
+
421
+ return prototype_key, prototype_value, prototype_shrinkage
422
+
423
+ def initialize_sensory_if_needed(self, sample_key: torch.Tensor, ids: List[int]):
424
+ for obj in ids:
425
+ if obj not in self.sensory:
426
+ # also initializes the sensory memory
427
+ bs, _, h, w = sample_key.shape
428
+ self.sensory[obj] = torch.zeros((bs, self.sensory_dim, h, w),
429
+ device=sample_key.device)
430
+
431
+ def update_sensory(self, sensory: torch.Tensor, ids: List[int]):
432
+ # sensory: 1*num_objects*C*H*W
433
+ for obj_id, obj in enumerate(ids):
434
+ self.sensory[obj] = sensory[:, obj_id]
435
+
436
+ def get_sensory(self, ids: List[int]):
437
+ # returns (1/2)*num_objects*C*H*W
438
+ return self._get_sensory_by_ids(ids)
439
+
440
+ def clear_non_permanent_memory(self):
441
+ self.work_mem.clear_non_permanent_memory()
442
+ if self.use_long_term:
443
+ self.long_mem.clear_non_permanent_memory()
444
+
445
+ def clear_sensory_memory(self):
446
+ self.sensory = {}
447
+
448
+ def clear_work_mem(self):
449
+ self.work_mem = KeyValueMemoryStore(save_selection=self.use_long_term,
450
+ save_usage=self.use_long_term)
451
+
452
+ def clear_obj_mem(self):
453
+ self.obj_v = {}
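For orientation, the visual readout in read() is at heart key-query attention: a similarity between memory keys and the query key, a (top-k) softmax over memory positions, and a weighted sum of the stored values. The real implementation adds shrinkage/selection terms, optional long-term memory, and the uncertainty mask; the following is only a minimal pure-PyTorch sketch with made-up sizes, not the module's API:

import torch

B, Ck, Cv, N_mem, HW = 1, 64, 512, 1024, 1296       # illustrative sizes only

mem_key = torch.randn(B, Ck, N_mem)                  # flattened memory keys
mem_val = torch.randn(B, Cv, N_mem)                  # memory values for one object
qry_key = torch.randn(B, Ck, HW)                     # flattened query-frame keys

similarity = torch.einsum('bcn,bcq->bnq', mem_key, qry_key) / (Ck ** 0.5)
affinity = torch.softmax(similarity, dim=1)          # normalize over memory positions
readout = torch.bmm(mem_val, affinity)               # B x Cv x HW, analogous to _readout()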
hf_space/third_party/matanyone/inference/object_info.py ADDED
@@ -0,0 +1,24 @@
1
+ class ObjectInfo:
2
+ """
3
+ Store meta information for an object
4
+ """
5
+ def __init__(self, id: int):
6
+ self.id = id
7
+ self.poke_count = 0 # count number of detections missed
8
+
9
+ def poke(self) -> None:
10
+ self.poke_count += 1
11
+
12
+ def unpoke(self) -> None:
13
+ self.poke_count = 0
14
+
15
+ def __hash__(self):
16
+ return hash(self.id)
17
+
18
+ def __eq__(self, other):
19
+ if type(other) == int:
20
+ return self.id == other
21
+ return self.id == other.id
22
+
23
+ def __repr__(self):
24
+ return f'(ID: {self.id})'
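A quick note on why ObjectInfo defines __hash__ and __eq__ this way: instances hash by their id and compare equal to plain ints, so integer ids can be used interchangeably with ObjectInfo keys (ObjectManager relies on this). A small sketch, assuming the package is importable as matanyone:

from matanyone.inference.object_info import ObjectInfo

info = ObjectInfo(id=3)
registry = {info: 'tmp-1'}

assert info == 3                    # __eq__ accepts plain ints
assert registry[3] == 'tmp-1'       # hash(ObjectInfo(3)) == hash(3), so int lookup works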
hf_space/third_party/matanyone/inference/object_manager.py ADDED
@@ -0,0 +1,149 @@
1
+ from typing import Union, List, Dict
2
+
3
+ import torch
4
+ from matanyone.inference.object_info import ObjectInfo
5
+
6
+
7
+ class ObjectManager:
8
+ """
9
+ Object IDs are immutable. The same ID always represent the same object.
10
+ Temporary IDs are the positions of each object in the tensor. It changes as objects get removed.
11
+ Temporary IDs start from 1.
12
+ """
13
+
14
+ def __init__(self):
15
+ self.obj_to_tmp_id: Dict[ObjectInfo, int] = {}
16
+ self.tmp_id_to_obj: Dict[int, ObjectInfo] = {}
17
+ self.obj_id_to_obj: Dict[int, ObjectInfo] = {}
18
+
19
+ self.all_historical_object_ids: List[int] = []
20
+
21
+ def _recompute_obj_id_to_obj_mapping(self) -> None:
22
+ self.obj_id_to_obj = {obj.id: obj for obj in self.obj_to_tmp_id}
23
+
24
+ def add_new_objects(
25
+ self, objects: Union[List[ObjectInfo], ObjectInfo,
26
+ List[int]]) -> (List[int], List[int]):
27
+ if not isinstance(objects, list):
28
+ objects = [objects]
29
+
30
+ corresponding_tmp_ids = []
31
+ corresponding_obj_ids = []
32
+ for obj in objects:
33
+ if isinstance(obj, int):
34
+ obj = ObjectInfo(id=obj)
35
+
36
+ if obj in self.obj_to_tmp_id:
37
+ # old object
38
+ corresponding_tmp_ids.append(self.obj_to_tmp_id[obj])
39
+ corresponding_obj_ids.append(obj.id)
40
+ else:
41
+ # new object
42
+ new_obj = ObjectInfo(id=obj.id)
43
+
44
+ # new object
45
+ new_tmp_id = len(self.obj_to_tmp_id) + 1
46
+ self.obj_to_tmp_id[new_obj] = new_tmp_id
47
+ self.tmp_id_to_obj[new_tmp_id] = new_obj
48
+ self.all_historical_object_ids.append(new_obj.id)
49
+ corresponding_tmp_ids.append(new_tmp_id)
50
+ corresponding_obj_ids.append(new_obj.id)
51
+
52
+ self._recompute_obj_id_to_obj_mapping()
53
+ assert corresponding_tmp_ids == sorted(corresponding_tmp_ids)
54
+ return corresponding_tmp_ids, corresponding_obj_ids
55
+
56
+ def delete_objects(self, obj_ids_to_remove: Union[int, List[int]]) -> None:
57
+ # delete an object or a list of objects
58
+ # re-sort the tmp ids
59
+ if isinstance(obj_ids_to_remove, int):
60
+ obj_ids_to_remove = [obj_ids_to_remove]
61
+
62
+ new_tmp_id = 1
63
+ total_num_id = len(self.obj_to_tmp_id)
64
+
65
+ local_obj_to_tmp_id = {}
66
+ local_tmp_to_obj_id = {}
67
+
68
+ for tmp_iter in range(1, total_num_id + 1):
69
+ obj = self.tmp_id_to_obj[tmp_iter]
70
+ if obj.id not in obj_ids_to_remove:
71
+ local_obj_to_tmp_id[obj] = new_tmp_id
72
+ local_tmp_to_obj_id[new_tmp_id] = obj
73
+ new_tmp_id += 1
74
+
75
+ self.obj_to_tmp_id = local_obj_to_tmp_id
76
+ self.tmp_id_to_obj = local_tmp_to_obj_id
77
+ self._recompute_obj_id_to_obj_mapping()
78
+
79
+ def purge_inactive_objects(self,
80
+ max_missed_detection_count: int) -> (bool, List[int], List[int]):
81
+ # remove tmp ids of objects that are removed
82
+ obj_id_to_be_deleted = []
83
+ tmp_id_to_be_deleted = []
84
+ tmp_id_to_keep = []
85
+ obj_id_to_keep = []
86
+
87
+ for obj in self.obj_to_tmp_id:
88
+ if obj.poke_count > max_missed_detection_count:
89
+ obj_id_to_be_deleted.append(obj.id)
90
+ tmp_id_to_be_deleted.append(self.obj_to_tmp_id[obj])
91
+ else:
92
+ tmp_id_to_keep.append(self.obj_to_tmp_id[obj])
93
+ obj_id_to_keep.append(obj.id)
94
+
95
+ purge_activated = len(obj_id_to_be_deleted) > 0
96
+ if purge_activated:
97
+ self.delete_objects(obj_id_to_be_deleted)
98
+ return purge_activated, tmp_id_to_keep, obj_id_to_keep
99
+
100
+ def tmp_to_obj_cls(self, mask) -> torch.Tensor:
101
+ # remap tmp id cls representation to the true object id representation
102
+ new_mask = torch.zeros_like(mask)
103
+ for tmp_id, obj in self.tmp_id_to_obj.items():
104
+ new_mask[mask == tmp_id] = obj.id
105
+ return new_mask
106
+
107
+ def get_tmp_to_obj_mapping(self) -> Dict[int, ObjectInfo]:
108
+ # returns the mapping in a dict format for saving it with pickle
109
+ return {obj.id: tmp_id for obj, tmp_id in self.obj_to_tmp_id.items()}
110
+
111
+ def realize_dict(self, obj_dict, dim=1) -> torch.Tensor:
112
+ # turns a dict indexed by obj id into a tensor, ordered by tmp IDs
113
+ output = []
114
+ for _, obj in self.tmp_id_to_obj.items():
115
+ if obj.id not in obj_dict:
116
+ raise NotImplementedError
117
+ output.append(obj_dict[obj.id])
118
+ output = torch.stack(output, dim=dim)
119
+ return output
120
+
121
+ def make_one_hot(self, cls_mask) -> torch.Tensor:
122
+ output = []
123
+ for _, obj in self.tmp_id_to_obj.items():
124
+ output.append(cls_mask == obj.id)
125
+ if len(output) == 0:
126
+ output = torch.zeros((0, *cls_mask.shape), dtype=torch.bool, device=cls_mask.device)
127
+ else:
128
+ output = torch.stack(output, dim=0)
129
+ return output
130
+
131
+ @property
132
+ def all_obj_ids(self) -> List[int]:
133
+ return [k.id for k in self.obj_to_tmp_id]
134
+
135
+ @property
136
+ def num_obj(self) -> int:
137
+ return len(self.obj_to_tmp_id)
138
+
139
+ def has_all(self, objects: List[int]) -> bool:
140
+ for obj in objects:
141
+ if obj not in self.obj_to_tmp_id:
142
+ return False
143
+ return True
144
+
145
+ def find_object_by_id(self, obj_id) -> ObjectInfo:
146
+ return self.obj_id_to_obj[obj_id]
147
+
148
+ def find_tmp_by_id(self, obj_id) -> int:
149
+ return self.obj_to_tmp_id[self.obj_id_to_obj[obj_id]]
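A short usage sketch of the bookkeeping described in the class docstring (object ids stay fixed while temporary ids are re-packed to 1..N after deletions); it assumes the package is importable as matanyone:

from matanyone.inference.object_manager import ObjectManager

om = ObjectManager()
tmp_ids, obj_ids = om.add_new_objects([3, 7, 9])     # tmp ids are assigned in order
assert tmp_ids == [1, 2, 3] and obj_ids == [3, 7, 9]

om.delete_objects(7)                                 # remaining objects get re-packed tmp ids
assert om.find_tmp_by_id(9) == 2
assert om.all_obj_ids == [3, 9]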
hf_space/third_party/matanyone/inference/utils/__init__.py ADDED
File without changes
hf_space/third_party/matanyone/inference/utils/args_utils.py ADDED
@@ -0,0 +1,30 @@
1
+ import logging
2
+ from omegaconf import DictConfig
3
+
4
+ log = logging.getLogger()
5
+
6
+
7
+ def get_dataset_cfg(cfg: DictConfig):
8
+ dataset_name = cfg.dataset
9
+ data_cfg = cfg.datasets[dataset_name]
10
+
11
+ potential_overrides = [
12
+ 'image_directory',
13
+ 'mask_directory',
14
+ 'json_directory',
15
+ 'size',
16
+ 'save_all',
17
+ 'use_all_masks',
18
+ 'use_long_term',
19
+ 'mem_every',
20
+ ]
21
+
22
+ for override in potential_overrides:
23
+ if cfg[override] is not None:
24
+ log.info(f'Overriding config {override} from {data_cfg[override]} to {cfg[override]}')
25
+ data_cfg[override] = cfg[override]
26
+ # escalate all potential overrides to the top-level config
27
+ if override in data_cfg:
28
+ cfg[override] = data_cfg[override]
29
+
30
+ return data_cfg
hf_space/third_party/matanyone/model/__init__.py ADDED
File without changes
hf_space/third_party/matanyone/model/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (195 Bytes). View file
 
hf_space/third_party/matanyone/model/__pycache__/aux_modules.cpython-313.pyc ADDED
Binary file (5.64 kB). View file
 
hf_space/third_party/matanyone/model/__pycache__/big_modules.cpython-313.pyc ADDED
Binary file (19.8 kB). View file
 
hf_space/third_party/matanyone/model/__pycache__/channel_attn.cpython-313.pyc ADDED
Binary file (2.85 kB). View file
 
hf_space/third_party/matanyone/model/__pycache__/group_modules.cpython-313.pyc ADDED
Binary file (7.43 kB). View file
 
hf_space/third_party/matanyone/model/__pycache__/matanyone.cpython-313.pyc ADDED
Binary file (17.5 kB). View file
 
hf_space/third_party/matanyone/model/__pycache__/modules.cpython-313.pyc ADDED
Binary file (11.8 kB). View file
 
hf_space/third_party/matanyone/model/aux_modules.py ADDED
@@ -0,0 +1,93 @@
1
+ """
2
+ For computing auxiliary outputs for auxiliary losses
3
+ """
4
+ from typing import Dict
5
+ from omegaconf import DictConfig
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from matanyone.model.group_modules import GConv2d
10
+ from matanyone.utils.tensor_utils import aggregate
11
+
12
+
13
+ class LinearPredictor(nn.Module):
14
+ def __init__(self, x_dim: int, pix_dim: int):
15
+ super().__init__()
16
+ self.projection = GConv2d(x_dim, pix_dim + 1, kernel_size=1)
17
+
18
+ def forward(self, pix_feat: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
19
+ # pixel_feat: B*pix_dim*H*W
20
+ # x: B*num_objects*x_dim*H*W
21
+ num_objects = x.shape[1]
22
+ x = self.projection(x)
23
+
24
+ pix_feat = pix_feat.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)
25
+ logits = (pix_feat * x[:, :, :-1]).sum(dim=2) + x[:, :, -1]
26
+ return logits
27
+
28
+
29
+ class DirectPredictor(nn.Module):
30
+ def __init__(self, x_dim: int):
31
+ super().__init__()
32
+ self.projection = GConv2d(x_dim, 1, kernel_size=1)
33
+
34
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
35
+ # x: B*num_objects*x_dim*H*W
36
+ logits = self.projection(x).squeeze(2)
37
+ return logits
38
+
39
+
40
+ class AuxComputer(nn.Module):
41
+ def __init__(self, cfg: DictConfig):
42
+ super().__init__()
43
+
44
+ use_sensory_aux = cfg.model.aux_loss.sensory.enabled
45
+ self.use_query_aux = cfg.model.aux_loss.query.enabled
46
+ self.use_sensory_aux = use_sensory_aux
47
+
48
+ sensory_dim = cfg.model.sensory_dim
49
+ embed_dim = cfg.model.embed_dim
50
+
51
+ if use_sensory_aux:
52
+ self.sensory_aux = LinearPredictor(sensory_dim, embed_dim)
53
+
54
+ def _aggregate_with_selector(self, logits: torch.Tensor, selector: torch.Tensor) -> torch.Tensor:
55
+ prob = torch.sigmoid(logits)
56
+ if selector is not None:
57
+ prob = prob * selector
58
+ logits = aggregate(prob, dim=1)
59
+ return logits
60
+
61
+ def forward(self, pix_feat: torch.Tensor, aux_input: Dict[str, torch.Tensor],
62
+ selector: torch.Tensor, seg_pass=False) -> Dict[str, torch.Tensor]:
63
+ sensory = aux_input['sensory']
64
+ q_logits = aux_input['q_logits']
65
+
66
+ aux_output = {}
67
+ aux_output['attn_mask'] = aux_input['attn_mask']
68
+
69
+ if self.use_sensory_aux:
70
+ # B*num_objects*H*W
71
+ logits = self.sensory_aux(pix_feat, sensory)
72
+ aux_output['sensory_logits'] = self._aggregate_with_selector(logits, selector)
73
+ if self.use_query_aux:
74
+ # B*num_objects*num_levels*H*W
75
+ aux_output['q_logits'] = self._aggregate_with_selector(
76
+ torch.stack(q_logits, dim=2),
77
+ selector.unsqueeze(2) if selector is not None else None)
78
+
79
+ return aux_output
80
+
81
+ def compute_mask(self, aux_input: Dict[str, torch.Tensor],
82
+ selector: torch.Tensor) -> Dict[str, torch.Tensor]:
83
+ # sensory = aux_input['sensory']
84
+ q_logits = aux_input['q_logits']
85
+
86
+ aux_output = {}
87
+
88
+ # B*num_objects*num_levels*H*W
89
+ aux_output['q_logits'] = self._aggregate_with_selector(
90
+ torch.stack(q_logits, dim=2),
91
+ selector.unsqueeze(2) if selector is not None else None)
92
+
93
+ return aux_output
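A minimal shape check for LinearPredictor (the dimensions below are made up for illustration): the 1x1 group convolution produces pix_dim weights plus one bias per object, and the dot product with the shared pixel features gives one logit map per object.

import torch
from matanyone.model.aux_modules import LinearPredictor

pred = LinearPredictor(x_dim=256, pix_dim=256)
pix_feat = torch.randn(2, 256, 16, 16)       # shared pixel features: B x pix_dim x H x W
x = torch.randn(2, 3, 256, 16, 16)           # per-object features:   B x num_objects x x_dim x H x W
logits = pred(pix_feat, x)
assert logits.shape == (2, 3, 16, 16)        # one logit map per object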
hf_space/third_party/matanyone/model/big_modules.py ADDED
@@ -0,0 +1,365 @@
1
+ """
2
+ big_modules.py - This file stores higher-level network blocks.
3
+
4
+ x - usually denotes features that are shared between objects.
5
+ g - usually denotes features that are not shared between objects
6
+ with an extra "num_objects" dimension (batch_size * num_objects * num_channels * H * W).
7
+
8
+ The trailing number of a variable usually denotes the stride
9
+ """
10
+
11
+ from typing import Iterable
12
+ from omegaconf import DictConfig
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+
17
+ from matanyone.model.group_modules import MainToGroupDistributor, GroupFeatureFusionBlock, GConv2d
18
+ from matanyone.model.utils import resnet
19
+ from matanyone.model.modules import SensoryDeepUpdater, SensoryUpdater_fullscale, DecoderFeatureProcessor, MaskUpsampleBlock
20
+
21
+ class UncertPred(nn.Module):
22
+ def __init__(self, model_cfg: DictConfig):
23
+ super().__init__()
24
+ self.conv1x1_v2 = nn.Conv2d(model_cfg.pixel_dim*2 + 1 + model_cfg.value_dim, 64, kernel_size=1, stride=1, bias=False)
25
+ self.bn1 = nn.BatchNorm2d(64)
26
+ self.relu = nn.ReLU(inplace=True)
27
+ self.conv3x3 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1, groups=1, bias=False, dilation=1)
28
+ self.bn2 = nn.BatchNorm2d(32)
29
+ self.conv3x3_out = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1, groups=1, bias=False, dilation=1)
30
+
31
+ def forward(self, last_frame_feat: torch.Tensor, cur_frame_feat: torch.Tensor, last_mask: torch.Tensor, mem_val_diff:torch.Tensor):
32
+ last_mask = F.interpolate(last_mask, size=last_frame_feat.shape[-2:], mode='area')
33
+ x = torch.cat([last_frame_feat, cur_frame_feat, last_mask, mem_val_diff], dim=1)
34
+ x = self.conv1x1_v2(x)
35
+ x = self.bn1(x)
36
+ x = self.relu(x)
37
+ x = self.conv3x3(x)
38
+ x = self.bn2(x)
39
+ x = self.relu(x)
40
+ x = self.conv3x3_out(x)
41
+ return x
42
+
43
+ # override the default train() to freeze BN statistics
44
+ def train(self, mode=True):
45
+ self.training = False
46
+ for module in self.children():
47
+ module.train(False)
48
+ return self
49
+
50
+ class PixelEncoder(nn.Module):
51
+ def __init__(self, model_cfg: DictConfig):
52
+ super().__init__()
53
+
54
+ self.is_resnet = 'resnet' in model_cfg.pixel_encoder.type
55
+ # if model_cfg.pretrained_resnet is set in the model_cfg we get the value
56
+ # else default to True
57
+ is_pretrained_resnet = getattr(model_cfg,"pretrained_resnet",True)
58
+ if self.is_resnet:
59
+ if model_cfg.pixel_encoder.type == 'resnet18':
60
+ network = resnet.resnet18(pretrained=is_pretrained_resnet)
61
+ elif model_cfg.pixel_encoder.type == 'resnet50':
62
+ network = resnet.resnet50(pretrained=is_pretrained_resnet)
63
+ else:
64
+ raise NotImplementedError
65
+ self.conv1 = network.conv1
66
+ self.bn1 = network.bn1
67
+ self.relu = network.relu
68
+ self.maxpool = network.maxpool
69
+
70
+ self.res2 = network.layer1
71
+ self.layer2 = network.layer2
72
+ self.layer3 = network.layer3
73
+ else:
74
+ raise NotImplementedError
75
+
76
+ def forward(self, x: torch.Tensor, seq_length=None) -> (torch.Tensor, torch.Tensor, torch.Tensor):
77
+ f1 = x
78
+ x = self.conv1(x)
79
+ x = self.bn1(x)
80
+ x = self.relu(x)
81
+ f2 = x
82
+ x = self.maxpool(x)
83
+ f4 = self.res2(x)
84
+ f8 = self.layer2(f4)
85
+ f16 = self.layer3(f8)
86
+
87
+ return f16, f8, f4, f2, f1
88
+
89
+ # override the default train() to freeze BN statistics
90
+ def train(self, mode=True):
91
+ self.training = False
92
+ for module in self.children():
93
+ module.train(False)
94
+ return self
95
+
96
+
97
+ class KeyProjection(nn.Module):
98
+ def __init__(self, model_cfg: DictConfig):
99
+ super().__init__()
100
+ in_dim = model_cfg.pixel_encoder.ms_dims[0]
101
+ mid_dim = model_cfg.pixel_dim
102
+ key_dim = model_cfg.key_dim
103
+
104
+ self.pix_feat_proj = nn.Conv2d(in_dim, mid_dim, kernel_size=1)
105
+ self.key_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1)
106
+ # shrinkage
107
+ self.d_proj = nn.Conv2d(mid_dim, 1, kernel_size=3, padding=1)
108
+ # selection
109
+ self.e_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1)
110
+
111
+ nn.init.orthogonal_(self.key_proj.weight.data)
112
+ nn.init.zeros_(self.key_proj.bias.data)
113
+
114
+ def forward(self, x: torch.Tensor, *, need_s: bool,
115
+ need_e: bool) -> (torch.Tensor, torch.Tensor, torch.Tensor):
116
+ x = self.pix_feat_proj(x)
117
+ shrinkage = self.d_proj(x)**2 + 1 if (need_s) else None
118
+ selection = torch.sigmoid(self.e_proj(x)) if (need_e) else None
119
+
120
+ return self.key_proj(x), shrinkage, selection
121
+
122
+
123
+ class MaskEncoder(nn.Module):
124
+ def __init__(self, model_cfg: DictConfig, single_object=False):
125
+ super().__init__()
126
+ pixel_dim = model_cfg.pixel_dim
127
+ value_dim = model_cfg.value_dim
128
+ sensory_dim = model_cfg.sensory_dim
129
+ final_dim = model_cfg.mask_encoder.final_dim
130
+
131
+ self.single_object = single_object
132
+ extra_dim = 1 if single_object else 2
133
+
134
+ # if model_cfg.pretrained_resnet is set in the model_cfg we get the value
135
+ # else default to True
136
+ is_pretrained_resnet = getattr(model_cfg,"pretrained_resnet",True)
137
+ if model_cfg.mask_encoder.type == 'resnet18':
138
+ network = resnet.resnet18(pretrained=is_pretrained_resnet, extra_dim=extra_dim)
139
+ elif model_cfg.mask_encoder.type == 'resnet50':
140
+ network = resnet.resnet50(pretrained=is_pretrained_resnet, extra_dim=extra_dim)
141
+ else:
142
+ raise NotImplementedError
143
+ self.conv1 = network.conv1
144
+ self.bn1 = network.bn1
145
+ self.relu = network.relu
146
+ self.maxpool = network.maxpool
147
+
148
+ self.layer1 = network.layer1
149
+ self.layer2 = network.layer2
150
+ self.layer3 = network.layer3
151
+
152
+ self.distributor = MainToGroupDistributor()
153
+ self.fuser = GroupFeatureFusionBlock(pixel_dim, final_dim, value_dim)
154
+
155
+ self.sensory_update = SensoryDeepUpdater(value_dim, sensory_dim)
156
+
157
+ def forward(self,
158
+ image: torch.Tensor,
159
+ pix_feat: torch.Tensor,
160
+ sensory: torch.Tensor,
161
+ masks: torch.Tensor,
162
+ others: torch.Tensor,
163
+ *,
164
+ deep_update: bool = True,
165
+ chunk_size: int = -1) -> (torch.Tensor, torch.Tensor):
166
+ # ms_features are from the key encoder
167
+ # we only use the first one (lowest resolution), following XMem
168
+ if self.single_object:
169
+ g = masks.unsqueeze(2)
170
+ else:
171
+ g = torch.stack([masks, others], dim=2)
172
+
173
+ g = self.distributor(image, g)
174
+
175
+ batch_size, num_objects = g.shape[:2]
176
+ if chunk_size < 1 or chunk_size >= num_objects:
177
+ chunk_size = num_objects
178
+ fast_path = True
179
+ new_sensory = sensory
180
+ else:
181
+ if deep_update:
182
+ new_sensory = torch.empty_like(sensory)
183
+ else:
184
+ new_sensory = sensory
185
+ fast_path = False
186
+
187
+ # chunk-by-chunk inference
188
+ all_g = []
189
+ for i in range(0, num_objects, chunk_size):
190
+ if fast_path:
191
+ g_chunk = g
192
+ else:
193
+ g_chunk = g[:, i:i + chunk_size]
194
+ actual_chunk_size = g_chunk.shape[1]
195
+ g_chunk = g_chunk.flatten(start_dim=0, end_dim=1)
196
+
197
+ g_chunk = self.conv1(g_chunk)
198
+ g_chunk = self.bn1(g_chunk) # 1/2, 64
199
+ g_chunk = self.maxpool(g_chunk) # 1/4, 64
200
+ g_chunk = self.relu(g_chunk)
201
+
202
+ g_chunk = self.layer1(g_chunk) # 1/4
203
+ g_chunk = self.layer2(g_chunk) # 1/8
204
+ g_chunk = self.layer3(g_chunk) # 1/16
205
+
206
+ g_chunk = g_chunk.view(batch_size, actual_chunk_size, *g_chunk.shape[1:])
207
+ g_chunk = self.fuser(pix_feat, g_chunk)
208
+ all_g.append(g_chunk)
209
+ if deep_update:
210
+ if fast_path:
211
+ new_sensory = self.sensory_update(g_chunk, sensory)
212
+ else:
213
+ new_sensory[:, i:i + chunk_size] = self.sensory_update(
214
+ g_chunk, sensory[:, i:i + chunk_size])
215
+ g = torch.cat(all_g, dim=1)
216
+
217
+ return g, new_sensory
218
+
219
+ # override the default train() to freeze BN statistics
220
+ def train(self, mode=True):
221
+ self.training = False
222
+ for module in self.children():
223
+ module.train(False)
224
+ return self
225
+
226
+
227
+ class PixelFeatureFuser(nn.Module):
228
+ def __init__(self, model_cfg: DictConfig, single_object=False):
229
+ super().__init__()
230
+ value_dim = model_cfg.value_dim
231
+ sensory_dim = model_cfg.sensory_dim
232
+ pixel_dim = model_cfg.pixel_dim
233
+ embed_dim = model_cfg.embed_dim
234
+ self.single_object = single_object
235
+
236
+ self.fuser = GroupFeatureFusionBlock(pixel_dim, value_dim, embed_dim)
237
+ if self.single_object:
238
+ self.sensory_compress = GConv2d(sensory_dim + 1, value_dim, kernel_size=1)
239
+ else:
240
+ self.sensory_compress = GConv2d(sensory_dim + 2, value_dim, kernel_size=1)
241
+
242
+ def forward(self,
243
+ pix_feat: torch.Tensor,
244
+ pixel_memory: torch.Tensor,
245
+ sensory_memory: torch.Tensor,
246
+ last_mask: torch.Tensor,
247
+ last_others: torch.Tensor,
248
+ *,
249
+ chunk_size: int = -1) -> torch.Tensor:
250
+ batch_size, num_objects = pixel_memory.shape[:2]
251
+
252
+ if self.single_object:
253
+ last_mask = last_mask.unsqueeze(2)
254
+ else:
255
+ last_mask = torch.stack([last_mask, last_others], dim=2)
256
+
257
+ if chunk_size < 1:
258
+ chunk_size = num_objects
259
+
260
+ # chunk-by-chunk inference
261
+ all_p16 = []
262
+ for i in range(0, num_objects, chunk_size):
263
+ sensory_readout = self.sensory_compress(
264
+ torch.cat([sensory_memory[:, i:i + chunk_size], last_mask[:, i:i + chunk_size]], 2))
265
+ p16 = pixel_memory[:, i:i + chunk_size] + sensory_readout
266
+ p16 = self.fuser(pix_feat, p16)
267
+ all_p16.append(p16)
268
+ p16 = torch.cat(all_p16, dim=1)
269
+
270
+ return p16
271
+
272
+
273
+ class MaskDecoder(nn.Module):
274
+ def __init__(self, model_cfg: DictConfig):
275
+ super().__init__()
276
+ embed_dim = model_cfg.embed_dim
277
+ sensory_dim = model_cfg.sensory_dim
278
+ ms_image_dims = model_cfg.pixel_encoder.ms_dims
279
+ up_dims = model_cfg.mask_decoder.up_dims
280
+
281
+ assert embed_dim == up_dims[0]
282
+
283
+ self.sensory_update = SensoryUpdater_fullscale([up_dims[0], up_dims[1], up_dims[2], up_dims[3], up_dims[4] + 1], sensory_dim,
284
+ sensory_dim)
285
+
286
+ self.decoder_feat_proc = DecoderFeatureProcessor(ms_image_dims[1:], up_dims[:-1])
287
+ self.up_16_8 = MaskUpsampleBlock(up_dims[0], up_dims[1])
288
+ self.up_8_4 = MaskUpsampleBlock(up_dims[1], up_dims[2])
289
+ # newly add for alpha matte
290
+ self.up_4_2 = MaskUpsampleBlock(up_dims[2], up_dims[3])
291
+ self.up_2_1 = MaskUpsampleBlock(up_dims[3], up_dims[4])
292
+
293
+ self.pred_seg = nn.Conv2d(up_dims[-1], 1, kernel_size=3, padding=1)
294
+ self.pred_mat = nn.Conv2d(up_dims[-1], 1, kernel_size=3, padding=1)
295
+
296
+ def forward(self,
297
+ ms_image_feat: Iterable[torch.Tensor],
298
+ memory_readout: torch.Tensor,
299
+ sensory: torch.Tensor,
300
+ *,
301
+ chunk_size: int = -1,
302
+ update_sensory: bool = True,
303
+ seg_pass: bool = False,
304
+ last_mask=None,
305
+ sigmoid_residual=False) -> (torch.Tensor, torch.Tensor):
306
+
307
+ batch_size, num_objects = memory_readout.shape[:2]
308
+ f8, f4, f2, f1 = self.decoder_feat_proc(ms_image_feat[1:])
309
+ if chunk_size < 1 or chunk_size >= num_objects:
310
+ chunk_size = num_objects
311
+ fast_path = True
312
+ new_sensory = sensory
313
+ else:
314
+ if update_sensory:
315
+ new_sensory = torch.empty_like(sensory)
316
+ else:
317
+ new_sensory = sensory
318
+ fast_path = False
319
+
320
+ # chunk-by-chunk inference
321
+ all_logits = []
322
+ for i in range(0, num_objects, chunk_size):
323
+ if fast_path:
324
+ p16 = memory_readout
325
+ else:
326
+ p16 = memory_readout[:, i:i + chunk_size]
327
+ actual_chunk_size = p16.shape[1]
328
+
329
+ p8 = self.up_16_8(p16, f8)
330
+ p4 = self.up_8_4(p8, f4)
331
+ p2 = self.up_4_2(p4, f2)
332
+ p1 = self.up_2_1(p2, f1)
333
+ with torch.amp.autocast("cuda",enabled=False):
334
+ if seg_pass:
335
+ if last_mask is not None:
336
+ res = self.pred_seg(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
337
+ if sigmoid_residual:
338
+ res = (torch.sigmoid(res) - 0.5) * 2 # regularization: (-1, 1) change on last mask
339
+ logits = last_mask + res
340
+ else:
341
+ logits = self.pred_seg(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
342
+ else:
343
+ if last_mask is not None:
344
+ res = self.pred_mat(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
345
+ if sigmoid_residual:
346
+ res = (torch.sigmoid(res) - 0.5) * 2 # regularization: (-1, 1) change on last mask
347
+ logits = last_mask + res
348
+ else:
349
+ logits = self.pred_mat(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
350
+ ## SensoryUpdater_fullscale
351
+ if update_sensory:
352
+ p1 = torch.cat(
353
+ [p1, logits.view(batch_size, actual_chunk_size, 1, *logits.shape[-2:])], 2)
354
+ if fast_path:
355
+ new_sensory = self.sensory_update([p16, p8, p4, p2, p1], sensory)
356
+ else:
357
+ new_sensory[:,
358
+ i:i + chunk_size] = self.sensory_update([p16, p8, p4, p2, p1],
359
+ sensory[:,
360
+ i:i + chunk_size])
361
+ all_logits.append(logits)
362
+ logits = torch.cat(all_logits, dim=0)
363
+ logits = logits.view(batch_size, num_objects, *logits.shape[-2:])
364
+
365
+ return new_sensory, logits
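A rough shape sketch for UncertPred (registered as temp_sparity in MatAnyone and called via pred_uncertainty()): it concatenates last-frame features, current-frame features, the area-downsampled previous mask, and the memory-value difference, and predicts a one-channel logit map at the same stride. The 256/256 config values below are placeholders, not necessarily the shipped defaults:

import torch
from omegaconf import OmegaConf
from matanyone.model.big_modules import UncertPred

model_cfg = OmegaConf.create({'pixel_dim': 256, 'value_dim': 256})   # assumed dims
net = UncertPred(model_cfg).eval()

last_feat = torch.randn(1, 256, 30, 54)    # stride-16 features, previous frame
cur_feat = torch.randn(1, 256, 30, 54)     # stride-16 features, current frame
last_mask = torch.randn(1, 1, 480, 864)    # full-resolution previous mask
mem_diff = torch.randn(1, 256, 30, 54)     # readout minus last memory value

with torch.no_grad():
    logits = net(last_feat, cur_feat, last_mask, mem_diff)
assert logits.shape == (1, 1, 30, 54)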
hf_space/third_party/matanyone/model/channel_attn.py ADDED
@@ -0,0 +1,39 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+
7
+ class CAResBlock(nn.Module):
8
+ def __init__(self, in_dim: int, out_dim: int, residual: bool = True):
9
+ super().__init__()
10
+ self.residual = residual
11
+ self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1)
12
+ self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1)
13
+
14
+ t = int((abs(math.log2(out_dim)) + 1) // 2)
15
+ k = t if t % 2 else t + 1
16
+ self.pool = nn.AdaptiveAvgPool2d(1)
17
+ self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
18
+
19
+ if self.residual:
20
+ if in_dim == out_dim:
21
+ self.downsample = nn.Identity()
22
+ else:
23
+ self.downsample = nn.Conv2d(in_dim, out_dim, kernel_size=1)
24
+
25
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
26
+ r = x
27
+ x = self.conv1(F.relu(x))
28
+ x = self.conv2(F.relu(x))
29
+
30
+ b, c = x.shape[:2]
31
+ w = self.pool(x).view(b, 1, c)
32
+ w = self.conv(w).transpose(-1, -2).unsqueeze(-1).sigmoid() # B*C*1*1
33
+
34
+ if self.residual:
35
+ x = x * w + self.downsample(r)
36
+ else:
37
+ x = x * w
38
+
39
+ return x
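CAResBlock is an ECA-style channel-attention residual block: global average pooling, a small 1-D convolution across channels (kernel size derived from the channel count), and a sigmoid gate applied before the residual is added back. A minimal shape check with illustrative sizes:

import torch
from matanyone.model.channel_attn import CAResBlock

block = CAResBlock(in_dim=64, out_dim=128)
x = torch.randn(2, 64, 32, 32)
y = block(x)
assert y.shape == (2, 128, 32, 32)    # the 1x1 downsample projects the residual to out_dim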
hf_space/third_party/matanyone/model/group_modules.py ADDED
@@ -0,0 +1,126 @@
1
+ from typing import Optional
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from matanyone.model.channel_attn import CAResBlock
6
+
7
+ def interpolate_groups(g: torch.Tensor, ratio: float, mode: str,
8
+ align_corners: bool) -> torch.Tensor:
9
+ batch_size, num_objects = g.shape[:2]
10
+ g = F.interpolate(g.flatten(start_dim=0, end_dim=1),
11
+ scale_factor=ratio,
12
+ mode=mode,
13
+ align_corners=align_corners)
14
+ g = g.view(batch_size, num_objects, *g.shape[1:])
15
+ return g
16
+
17
+
18
+ def upsample_groups(g: torch.Tensor,
19
+ ratio: float = 2,
20
+ mode: str = 'bilinear',
21
+ align_corners: bool = False) -> torch.Tensor:
22
+ return interpolate_groups(g, ratio, mode, align_corners)
23
+
24
+
25
+ def downsample_groups(g: torch.Tensor,
26
+ ratio: float = 1 / 2,
27
+ mode: str = 'area',
28
+ align_corners: bool = None) -> torch.Tensor:
29
+ return interpolate_groups(g, ratio, mode, align_corners)
30
+
31
+
32
+ class GConv2d(nn.Conv2d):
33
+ def forward(self, g: torch.Tensor) -> torch.Tensor:
34
+ batch_size, num_objects = g.shape[:2]
35
+ g = super().forward(g.flatten(start_dim=0, end_dim=1))
36
+ return g.view(batch_size, num_objects, *g.shape[1:])
37
+
38
+
39
+ class GroupResBlock(nn.Module):
40
+ def __init__(self, in_dim: int, out_dim: int):
41
+ super().__init__()
42
+
43
+ if in_dim == out_dim:
44
+ self.downsample = nn.Identity()
45
+ else:
46
+ self.downsample = GConv2d(in_dim, out_dim, kernel_size=1)
47
+
48
+ self.conv1 = GConv2d(in_dim, out_dim, kernel_size=3, padding=1)
49
+ self.conv2 = GConv2d(out_dim, out_dim, kernel_size=3, padding=1)
50
+
51
+ def forward(self, g: torch.Tensor) -> torch.Tensor:
52
+ out_g = self.conv1(F.relu(g))
53
+ out_g = self.conv2(F.relu(out_g))
54
+
55
+ g = self.downsample(g)
56
+
57
+ return out_g + g
58
+
59
+
60
+ class MainToGroupDistributor(nn.Module):
61
+ def __init__(self,
62
+ x_transform: Optional[nn.Module] = None,
63
+ g_transform: Optional[nn.Module] = None,
64
+ method: str = 'cat',
65
+ reverse_order: bool = False):
66
+ super().__init__()
67
+
68
+ self.x_transform = x_transform
69
+ self.g_transform = g_transform
70
+ self.method = method
71
+ self.reverse_order = reverse_order
72
+
73
+ def forward(self, x: torch.Tensor, g: torch.Tensor, skip_expand: bool = False) -> torch.Tensor:
74
+ num_objects = g.shape[1]
75
+
76
+ if self.x_transform is not None:
77
+ x = self.x_transform(x)
78
+
79
+ if self.g_transform is not None:
80
+ g = self.g_transform(g)
81
+
82
+ if not skip_expand:
83
+ x = x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)
84
+ if self.method == 'cat':
85
+ if self.reverse_order:
86
+ g = torch.cat([g, x], 2)
87
+ else:
88
+ g = torch.cat([x, g], 2)
89
+ elif self.method == 'add':
90
+ g = x + g
91
+ elif self.method == 'mulcat':
92
+ g = torch.cat([x * g, g], dim=2)
93
+ elif self.method == 'muladd':
94
+ g = x * g + g
95
+ else:
96
+ raise NotImplementedError
97
+
98
+ return g
99
+
100
+
101
+ class GroupFeatureFusionBlock(nn.Module):
102
+ def __init__(self, x_in_dim: int, g_in_dim: int, out_dim: int):
103
+ super().__init__()
104
+
105
+ x_transform = nn.Conv2d(x_in_dim, out_dim, kernel_size=1)
106
+ g_transform = GConv2d(g_in_dim, out_dim, kernel_size=1)
107
+
108
+ self.distributor = MainToGroupDistributor(x_transform=x_transform,
109
+ g_transform=g_transform,
110
+ method='add')
111
+ self.block1 = CAResBlock(out_dim, out_dim)
112
+ self.block2 = CAResBlock(out_dim, out_dim)
113
+
114
+ def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
115
+ batch_size, num_objects = g.shape[:2]
116
+
117
+ g = self.distributor(x, g)
118
+
119
+ g = g.flatten(start_dim=0, end_dim=1)
120
+
121
+ g = self.block1(g)
122
+ g = self.block2(g)
123
+
124
+ g = g.view(batch_size, num_objects, *g.shape[1:])
125
+
126
+ return g
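All of these group modules share one trick: fold the num_objects dimension into the batch dimension, run a standard 2-D op, then unfold. A small sketch with illustrative shapes:

import torch
from matanyone.model.group_modules import GConv2d, upsample_groups

gconv = GConv2d(16, 32, kernel_size=3, padding=1)
g = torch.randn(2, 3, 16, 24, 24)            # batch x num_objects x C x H x W
out = gconv(g)                               # the conv is applied per object
assert out.shape == (2, 3, 32, 24, 24)

up = upsample_groups(out)                    # bilinear 2x upsample, object dim preserved
assert up.shape == (2, 3, 32, 48, 48)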
hf_space/third_party/matanyone/model/matanyone.py ADDED
@@ -0,0 +1,333 @@
1
+ from typing import List, Dict, Iterable, Tuple
2
+ import logging
3
+ from omegaconf import DictConfig
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from omegaconf import OmegaConf
8
+ from huggingface_hub import PyTorchModelHubMixin
9
+
10
+ from matanyone.model.big_modules import PixelEncoder, UncertPred, KeyProjection, MaskEncoder, PixelFeatureFuser, MaskDecoder
11
+ from matanyone.model.aux_modules import AuxComputer
12
+ from matanyone.model.utils.memory_utils import get_affinity, readout
13
+ from matanyone.model.transformer.object_transformer import QueryTransformer
14
+ from matanyone.model.transformer.object_summarizer import ObjectSummarizer
15
+ from matanyone.utils.tensor_utils import aggregate
16
+
17
+ log = logging.getLogger()
18
+ class MatAnyone(nn.Module,
19
+ PyTorchModelHubMixin,
20
+ library_name="matanyone",
21
+ repo_url="https://github.com/pq-yang/MatAnyone",
22
+ coders={
23
+ DictConfig: (
24
+ lambda x: OmegaConf.to_container(x),
25
+ lambda data: OmegaConf.create(data),
26
+ )
27
+ },
28
+ ):
29
+
30
+ def __init__(self, cfg: DictConfig, *, single_object=False):
31
+ super().__init__()
32
+ self.cfg = cfg
33
+ model_cfg = cfg.model
34
+ self.ms_dims = model_cfg.pixel_encoder.ms_dims
35
+ self.key_dim = model_cfg.key_dim
36
+ self.value_dim = model_cfg.value_dim
37
+ self.sensory_dim = model_cfg.sensory_dim
38
+ self.pixel_dim = model_cfg.pixel_dim
39
+ self.embed_dim = model_cfg.embed_dim
40
+ self.single_object = single_object
41
+
42
+ log.info(f'Single object: {self.single_object}')
43
+
44
+ self.pixel_encoder = PixelEncoder(model_cfg)
45
+ self.pix_feat_proj = nn.Conv2d(self.ms_dims[0], self.pixel_dim, kernel_size=1)
46
+ self.key_proj = KeyProjection(model_cfg)
47
+ self.mask_encoder = MaskEncoder(model_cfg, single_object=single_object)
48
+ self.mask_decoder = MaskDecoder(model_cfg)
49
+ self.pixel_fuser = PixelFeatureFuser(model_cfg, single_object=single_object)
50
+ self.object_transformer = QueryTransformer(model_cfg)
51
+ self.object_summarizer = ObjectSummarizer(model_cfg)
52
+ self.aux_computer = AuxComputer(cfg)
53
+ self.temp_sparity = UncertPred(model_cfg)
54
+
55
+ self.register_buffer("pixel_mean", torch.Tensor(model_cfg.pixel_mean).view(-1, 1, 1), False)
56
+ self.register_buffer("pixel_std", torch.Tensor(model_cfg.pixel_std).view(-1, 1, 1), False)
57
+
58
+ def _get_others(self, masks: torch.Tensor) -> torch.Tensor:
59
+ # for each object, return the sum of masks of all other objects
60
+ if self.single_object:
61
+ return None
62
+
63
+ num_objects = masks.shape[1]
64
+ if num_objects >= 1:
65
+ others = (masks.sum(dim=1, keepdim=True) - masks).clamp(0, 1)
66
+ else:
67
+ others = torch.zeros_like(masks)
68
+ return others
69
+
70
+ def pred_uncertainty(self, last_pix_feat: torch.Tensor, cur_pix_feat: torch.Tensor, last_mask: torch.Tensor, mem_val_diff:torch.Tensor):
71
+ logits = self.temp_sparity(last_frame_feat=last_pix_feat,
72
+ cur_frame_feat=cur_pix_feat,
73
+ last_mask=last_mask,
74
+ mem_val_diff=mem_val_diff)
75
+
76
+ prob = torch.sigmoid(logits)
77
+ mask = (prob > 0) + 0
78
+
79
+ uncert_output = {"logits": logits,
80
+ "prob": prob,
81
+ "mask": mask}
82
+
83
+ return uncert_output
84
+
85
+ def encode_image(self, image: torch.Tensor, seq_length=None, last_feats=None) -> (Iterable[torch.Tensor], torch.Tensor): # type: ignore
86
+ image = (image - self.pixel_mean) / self.pixel_std
87
+ ms_image_feat = self.pixel_encoder(image, seq_length) # f16, f8, f4, f2, f1
88
+ return ms_image_feat, self.pix_feat_proj(ms_image_feat[0])
89
+
90
+ def encode_mask(
91
+ self,
92
+ image: torch.Tensor,
93
+ ms_features: List[torch.Tensor],
94
+ sensory: torch.Tensor,
95
+ masks: torch.Tensor,
96
+ *,
97
+ deep_update: bool = True,
98
+ chunk_size: int = -1,
99
+ need_weights: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
100
+ image = (image - self.pixel_mean) / self.pixel_std
101
+ others = self._get_others(masks)
102
+ mask_value, new_sensory = self.mask_encoder(image,
103
+ ms_features,
104
+ sensory,
105
+ masks,
106
+ others,
107
+ deep_update=deep_update,
108
+ chunk_size=chunk_size)
109
+ object_summaries, object_logits = self.object_summarizer(masks, mask_value, need_weights)
110
+ return mask_value, new_sensory, object_summaries, object_logits
111
+
112
+ def transform_key(self,
113
+ final_pix_feat: torch.Tensor,
114
+ *,
115
+ need_sk: bool = True,
116
+ need_ek: bool = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
117
+ key, shrinkage, selection = self.key_proj(final_pix_feat, need_s=need_sk, need_e=need_ek)
118
+ return key, shrinkage, selection
119
+
120
+ # Used in training only.
121
+ # This step is replaced by MemoryManager in test time
122
+ def read_memory(self, query_key: torch.Tensor, query_selection: torch.Tensor,
123
+ memory_key: torch.Tensor, memory_shrinkage: torch.Tensor,
124
+ msk_value: torch.Tensor, obj_memory: torch.Tensor, pix_feat: torch.Tensor,
125
+ sensory: torch.Tensor, last_mask: torch.Tensor,
126
+ selector: torch.Tensor, uncert_output=None, seg_pass=False,
127
+ last_pix_feat=None, last_pred_mask=None) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
128
+ """
129
+ query_key : B * CK * H * W
130
+ query_selection : B * CK * H * W
131
+ memory_key : B * CK * T * H * W
132
+ memory_shrinkage: B * 1 * T * H * W
133
+ msk_value : B * num_objects * CV * T * H * W
134
+ obj_memory : B * num_objects * T * num_summaries * C
135
+ pixel_feature : B * C * H * W
136
+ """
137
+ batch_size, num_objects = msk_value.shape[:2]
138
+
139
+ uncert_mask = uncert_output["mask"] if uncert_output is not None else None
140
+
141
+ # read using visual attention
142
+ with torch.amp.autocast("cuda",enabled=False):
143
+ affinity = get_affinity(memory_key.float(), memory_shrinkage.float(), query_key.float(),
144
+ query_selection.float(), uncert_mask=uncert_mask)
145
+
146
+ msk_value = msk_value.flatten(start_dim=1, end_dim=2).float()
147
+
148
+ # B * (num_objects*CV) * H * W
149
+ pixel_readout = readout(affinity, msk_value, uncert_mask)
150
+ pixel_readout = pixel_readout.view(batch_size, num_objects, self.value_dim,
151
+ *pixel_readout.shape[-2:])
152
+
153
+ uncert_output = self.pred_uncertainty(last_pix_feat, pix_feat, last_pred_mask, pixel_readout[:,0]-msk_value[:,:,-1])
154
+ uncert_prob = uncert_output["prob"].unsqueeze(1) # b n 1 h w
155
+ pixel_readout = pixel_readout*uncert_prob + msk_value[:,:,-1].unsqueeze(1)*(1-uncert_prob)
156
+
157
+ pixel_readout = self.pixel_fusion(pix_feat, pixel_readout, sensory, last_mask)
158
+
159
+
160
+ # read from query transformer
161
+ mem_readout, aux_features = self.readout_query(pixel_readout, obj_memory, selector=selector, seg_pass=seg_pass)
162
+
163
+ aux_output = {
164
+ 'sensory': sensory,
165
+ 'q_logits': aux_features['logits'] if aux_features else None,
166
+ 'attn_mask': aux_features['attn_mask'] if aux_features else None,
167
+ }
168
+
169
+ return mem_readout, aux_output, uncert_output
170
+
171
+ def read_first_frame_memory(self, pixel_readout,
172
+ obj_memory: torch.Tensor, pix_feat: torch.Tensor,
173
+ sensory: torch.Tensor, last_mask: torch.Tensor,
174
+ selector: torch.Tensor, seg_pass=False) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
175
+ """
176
+ query_key : B * CK * H * W
177
+ query_selection : B * CK * H * W
178
+ memory_key : B * CK * T * H * W
179
+ memory_shrinkage: B * 1 * T * H * W
180
+ msk_value : B * num_objects * CV * T * H * W
181
+ obj_memory : B * num_objects * T * num_summaries * C
182
+ pixel_feature : B * C * H * W
183
+ """
184
+
185
+ pixel_readout = self.pixel_fusion(pix_feat, pixel_readout, sensory, last_mask)
186
+
187
+ # read from query transformer
188
+ mem_readout, aux_features = self.readout_query(pixel_readout, obj_memory, selector=selector, seg_pass=seg_pass)
189
+
190
+ aux_output = {
191
+ 'sensory': sensory,
192
+ 'q_logits': aux_features['logits'] if aux_features else None,
193
+ 'attn_mask': aux_features['attn_mask'] if aux_features else None,
194
+ }
195
+
196
+ return mem_readout, aux_output
197
+
198
+ def pixel_fusion(self,
199
+ pix_feat: torch.Tensor,
200
+ pixel: torch.Tensor,
201
+ sensory: torch.Tensor,
202
+ last_mask: torch.Tensor,
203
+ *,
204
+ chunk_size: int = -1) -> torch.Tensor:
205
+ last_mask = F.interpolate(last_mask, size=sensory.shape[-2:], mode='area')
206
+ last_others = self._get_others(last_mask)
207
+ fused = self.pixel_fuser(pix_feat,
208
+ pixel,
209
+ sensory,
210
+ last_mask,
211
+ last_others,
212
+ chunk_size=chunk_size)
213
+ return fused
214
+
215
+ def readout_query(self,
216
+ pixel_readout,
217
+ obj_memory,
218
+ *,
219
+ selector=None,
220
+ need_weights=False,
221
+ seg_pass=False) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
222
+ return self.object_transformer(pixel_readout,
223
+ obj_memory,
224
+ selector=selector,
225
+ need_weights=need_weights,
226
+ seg_pass=seg_pass)
227
+
228
+ def segment(self,
229
+ ms_image_feat: List[torch.Tensor],
230
+ memory_readout: torch.Tensor,
231
+ sensory: torch.Tensor,
232
+ *,
233
+ selector: bool = None,
234
+ chunk_size: int = -1,
235
+ update_sensory: bool = True,
236
+ seg_pass: bool = False,
237
+ clamp_mat: bool = True,
238
+ last_mask=None,
239
+ sigmoid_residual=False,
240
+ seg_mat=False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
241
+ """
242
+ multi_scale_features is from the key encoder for skip-connection
243
+ memory_readout is from working/long-term memory
244
+ sensory is the sensory memory
245
+ last_mask is the mask from the last frame, supplementing sensory memory
246
+ selector is 1 if an object exists, and 0 otherwise. We use it to filter padded objects
247
+ during training.
248
+ """
249
+ #### use mat head for seg data
250
+ if seg_mat:
251
+ assert seg_pass
252
+ seg_pass = False
253
+ ####
254
+ sensory, logits = self.mask_decoder(ms_image_feat,
255
+ memory_readout,
256
+ sensory,
257
+ chunk_size=chunk_size,
258
+ update_sensory=update_sensory,
259
+ seg_pass = seg_pass,
260
+ last_mask=last_mask,
261
+ sigmoid_residual=sigmoid_residual)
262
+ if seg_pass:
263
+ prob = torch.sigmoid(logits)
264
+ if selector is not None:
265
+ prob = prob * selector
266
+
267
+ # Softmax over all objects
268
+ logits = aggregate(prob, dim=1)
269
+ prob = F.softmax(logits, dim=1)
270
+ else:
271
+ if clamp_mat:
272
+ logits = logits.clamp(0.0, 1.0)
273
+ logits = torch.cat([torch.prod(1 - logits, dim=1, keepdim=True), logits], 1)
274
+ prob = logits
275
+
276
+ return sensory, logits, prob
277
+
278
+ def compute_aux(self, pix_feat: torch.Tensor, aux_inputs: Dict[str, torch.Tensor],
279
+ selector: torch.Tensor, seg_pass=False) -> Dict[str, torch.Tensor]:
280
+ return self.aux_computer(pix_feat, aux_inputs, selector, seg_pass=seg_pass)
281
+
282
+ def forward(self, *args, **kwargs):
283
+ raise NotImplementedError
284
+
285
+ def load_weights(self, src_dict, init_as_zero_if_needed=False) -> None:
286
+ if not self.single_object:
287
+ # Map single-object weight to multi-object weight (4->5 out channels in conv1)
288
+ for k in list(src_dict.keys()):
289
+ if k == 'mask_encoder.conv1.weight':
290
+ if src_dict[k].shape[1] == 4:
291
+ log.info(f'Converting {k} from single object to multiple objects.')
292
+ pads = torch.zeros((64, 1, 7, 7), device=src_dict[k].device)
293
+ if not init_as_zero_if_needed:
294
+ nn.init.orthogonal_(pads)
295
+ log.info(f'Randomly initialized padding for {k}.')
296
+ else:
297
+ log.info(f'Zero-initialized padding for {k}.')
298
+ src_dict[k] = torch.cat([src_dict[k], pads], 1)
299
+ elif k == 'pixel_fuser.sensory_compress.weight':
300
+ if src_dict[k].shape[1] == self.sensory_dim + 1:
301
+ log.info(f'Converting {k} from single object to multiple objects.')
302
+ pads = torch.zeros((self.value_dim, 1, 1, 1), device=src_dict[k].device)
303
+ if not init_as_zero_if_needed:
304
+ nn.init.orthogonal_(pads)
305
+ log.info(f'Randomly initialized padding for {k}.')
306
+ else:
307
+ log.info(f'Zero-initialized padding for {k}.')
308
+ src_dict[k] = torch.cat([src_dict[k], pads], 1)
309
+ elif self.single_object:
310
+ """
311
+ If the model is multiple-object and we are training in single-object,
312
+ we strip the last channel of conv1.
313
+ This is not supposed to happen in standard training except when users are trying to
314
+ finetune a trained model with single object datasets.
315
+ """
316
+ if src_dict['mask_encoder.conv1.weight'].shape[1] == 5:
317
+ log.warning('Converting mask_encoder.conv1.weight from multiple objects to single object.'
318
+ ' This is not supposed to happen in standard training.')
319
+ src_dict['mask_encoder.conv1.weight'] = src_dict['mask_encoder.conv1.weight'][:, :-1]
320
+ src_dict['pixel_fuser.sensory_compress.weight'] = src_dict['pixel_fuser.sensory_compress.weight'][:, :-1]
321
+
322
+ for k in src_dict:
323
+ if k not in self.state_dict():
324
+ log.info(f'Key {k} found in src_dict but not in self.state_dict()!!!')
325
+ for k in self.state_dict():
326
+ if k not in src_dict:
327
+ log.info(f'Key {k} found in self.state_dict() but not in src_dict!!!')
328
+
329
+ self.load_state_dict(src_dict, strict=False)
330
+
331
+ @property
332
+ def device(self) -> torch.device:
333
+ return self.pixel_mean.device
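Because MatAnyone mixes in PyTorchModelHubMixin (with the DictConfig coder above), a checkpoint can typically be loaded straight from the Hugging Face Hub. The repo id below is an assumption used for illustration; substitute the actual checkpoint location:

import torch
from matanyone.model.matanyone import MatAnyone

# NOTE: placeholder repo id -- point this at the real checkpoint.
model = MatAnyone.from_pretrained("PeiqingYang/MatAnyone").eval()
print(model.device)    # the device property defined at the end of the class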
hf_space/third_party/matanyone/model/modules.py ADDED
@@ -0,0 +1,149 @@
1
+ from typing import List, Iterable
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from matanyone.model.group_modules import MainToGroupDistributor, GroupResBlock, upsample_groups, GConv2d, downsample_groups
7
+
8
+
9
+ class UpsampleBlock(nn.Module):
10
+ def __init__(self, in_dim: int, out_dim: int, scale_factor: int = 2):
11
+ super().__init__()
12
+ self.out_conv = ResBlock(in_dim, out_dim)
13
+ self.scale_factor = scale_factor
14
+
15
+ def forward(self, in_g: torch.Tensor, skip_f: torch.Tensor) -> torch.Tensor:
16
+ g = F.interpolate(in_g,
17
+ scale_factor=self.scale_factor,
18
+ mode='bilinear')
19
+ g = self.out_conv(g)
20
+ g = g + skip_f
21
+ return g
22
+
23
+ class MaskUpsampleBlock(nn.Module):
24
+ def __init__(self, in_dim: int, out_dim: int, scale_factor: int = 2):
25
+ super().__init__()
26
+ self.distributor = MainToGroupDistributor(method='add')
27
+ self.out_conv = GroupResBlock(in_dim, out_dim)
28
+ self.scale_factor = scale_factor
29
+
30
+ def forward(self, in_g: torch.Tensor, skip_f: torch.Tensor) -> torch.Tensor:
31
+ g = upsample_groups(in_g, ratio=self.scale_factor)
32
+ g = self.distributor(skip_f, g)
33
+ g = self.out_conv(g)
34
+ return g
35
+
36
+
37
+ class DecoderFeatureProcessor(nn.Module):
38
+ def __init__(self, decoder_dims: List[int], out_dims: List[int]):
39
+ super().__init__()
40
+ self.transforms = nn.ModuleList([
41
+ nn.Conv2d(d_dim, p_dim, kernel_size=1) for d_dim, p_dim in zip(decoder_dims, out_dims)
42
+ ])
43
+
44
+ def forward(self, multi_scale_features: Iterable[torch.Tensor]) -> List[torch.Tensor]:
45
+ outputs = [func(x) for x, func in zip(multi_scale_features, self.transforms)]
46
+ return outputs
47
+
48
+
49
+ # @torch.jit.script
50
+ def _recurrent_update(h: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
51
+ # h: batch_size * num_objects * hidden_dim * h * w
52
+ # values: batch_size * num_objects * (hidden_dim*3) * h * w
53
+ dim = values.shape[2] // 3
54
+ forget_gate = torch.sigmoid(values[:, :, :dim])
55
+ update_gate = torch.sigmoid(values[:, :, dim:dim * 2])
56
+ new_value = torch.tanh(values[:, :, dim * 2:])
57
+ new_h = forget_gate * h * (1 - update_gate) + update_gate * new_value
58
+ return new_h
59
+
60
+
61
+ class SensoryUpdater_fullscale(nn.Module):
62
+ # Used in the decoder, multi-scale feature + GRU
63
+ def __init__(self, g_dims: List[int], mid_dim: int, sensory_dim: int):
64
+ super().__init__()
65
+ self.g16_conv = GConv2d(g_dims[0], mid_dim, kernel_size=1)
66
+ self.g8_conv = GConv2d(g_dims[1], mid_dim, kernel_size=1)
67
+ self.g4_conv = GConv2d(g_dims[2], mid_dim, kernel_size=1)
68
+ self.g2_conv = GConv2d(g_dims[3], mid_dim, kernel_size=1)
69
+ self.g1_conv = GConv2d(g_dims[4], mid_dim, kernel_size=1)
70
+
71
+ self.transform = GConv2d(mid_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1)
72
+
73
+ nn.init.xavier_normal_(self.transform.weight)
74
+
75
+ def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
76
+ g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \
77
+ self.g4_conv(downsample_groups(g[2], ratio=1/4)) + \
78
+ self.g2_conv(downsample_groups(g[3], ratio=1/8)) + \
79
+ self.g1_conv(downsample_groups(g[4], ratio=1/16))
80
+
81
+ with torch.amp.autocast("cuda",enabled=False):
82
+ g = g.float()
83
+ h = h.float()
84
+ values = self.transform(torch.cat([g, h], dim=2))
85
+ new_h = _recurrent_update(h, values)
86
+
87
+ return new_h
88
+
89
+ class SensoryUpdater(nn.Module):
90
+ # Used in the decoder, multi-scale feature + GRU
91
+ def __init__(self, g_dims: List[int], mid_dim: int, sensory_dim: int):
92
+ super().__init__()
93
+ self.g16_conv = GConv2d(g_dims[0], mid_dim, kernel_size=1)
94
+ self.g8_conv = GConv2d(g_dims[1], mid_dim, kernel_size=1)
95
+ self.g4_conv = GConv2d(g_dims[2], mid_dim, kernel_size=1)
96
+
97
+ self.transform = GConv2d(mid_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1)
98
+
99
+ nn.init.xavier_normal_(self.transform.weight)
100
+
101
+ def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
102
+ g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \
103
+ self.g4_conv(downsample_groups(g[2], ratio=1/4))
104
+
105
+ with torch.amp.autocast("cuda",enabled=False):
106
+ g = g.float()
107
+ h = h.float()
108
+ values = self.transform(torch.cat([g, h], dim=2))
109
+ new_h = _recurrent_update(h, values)
110
+
111
+ return new_h
112
+
113
+
114
+ class SensoryDeepUpdater(nn.Module):
115
+ def __init__(self, f_dim: int, sensory_dim: int):
116
+ super().__init__()
117
+ self.transform = GConv2d(f_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1)
118
+
119
+ nn.init.xavier_normal_(self.transform.weight)
120
+
121
+ def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
122
+ with torch.amp.autocast("cuda",enabled=False):
123
+ g = g.float()
124
+ h = h.float()
125
+ values = self.transform(torch.cat([g, h], dim=2))
126
+ new_h = _recurrent_update(h, values)
127
+
128
+ return new_h
129
+
130
+
131
+ class ResBlock(nn.Module):
132
+ def __init__(self, in_dim: int, out_dim: int):
133
+ super().__init__()
134
+
135
+ if in_dim == out_dim:
136
+ self.downsample = nn.Identity()
137
+ else:
138
+ self.downsample = nn.Conv2d(in_dim, out_dim, kernel_size=1)
139
+
140
+ self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1)
141
+ self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1)
142
+
143
+ def forward(self, g: torch.Tensor) -> torch.Tensor:
144
+ out_g = self.conv1(F.relu(g))
145
+ out_g = self.conv2(F.relu(out_g))
146
+
147
+ g = self.downsample(g)
148
+
149
+ return out_g + g
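
A quick shape check for the _recurrent_update gate layout above (a toy example, not part of the upload; it assumes the matanyone package is importable, as in the file's own imports):

# values packs forget gate, update gate and candidate state along the channel
# dimension, so it must carry 3x the sensory (hidden) dimension.
import torch
from matanyone.model.modules import _recurrent_update

B, N, C, H, W = 1, 2, 16, 8, 8           # batch, objects, hidden dim, spatial
h = torch.zeros(B, N, C, H, W)           # previous sensory memory
values = torch.randn(B, N, C * 3, H, W)  # output of SensoryUpdater.transform
new_h = _recurrent_update(h, values)
print(new_h.shape)                       # torch.Size([1, 2, 16, 8, 8])
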
hf_space/third_party/matanyone/model/transformer/__init__.py ADDED
File without changes
hf_space/third_party/matanyone/model/transformer/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (207 Bytes). View file
 
hf_space/third_party/matanyone/model/transformer/__pycache__/object_summarizer.cpython-313.pyc ADDED
Binary file (5.11 kB). View file
 
hf_space/third_party/matanyone/model/transformer/__pycache__/object_transformer.cpython-313.pyc ADDED
Binary file (12.1 kB). View file
 
hf_space/third_party/matanyone/model/transformer/__pycache__/positional_encoding.cpython-313.pyc ADDED
Binary file (5.97 kB). View file
 
hf_space/third_party/matanyone/model/transformer/__pycache__/transformer_layers.cpython-313.pyc ADDED
Binary file (8.94 kB). View file
 
hf_space/third_party/matanyone/model/transformer/object_summarizer.py ADDED
@@ -0,0 +1,89 @@
1
+ from typing import Optional
2
+ from omegaconf import DictConfig
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from matanyone.model.transformer.positional_encoding import PositionalEncoding
8
+
9
+
10
+ # @torch.jit.script
11
+ def _weighted_pooling(masks: torch.Tensor, value: torch.Tensor,
12
+ logits: torch.Tensor) -> (torch.Tensor, torch.Tensor):
13
+ # value: B*num_objects*H*W*value_dim
14
+ # logits: B*num_objects*H*W*num_summaries
15
+ # masks: B*num_objects*H*W*num_summaries: 1 if allowed
16
+ weights = logits.sigmoid() * masks
17
+ # B*num_objects*num_summaries*value_dim
18
+ sums = torch.einsum('bkhwq,bkhwc->bkqc', weights, value)
19
+ # B*num_objects*H*W*num_summaries -> B*num_objects*num_summaries*1
20
+ area = weights.flatten(start_dim=2, end_dim=3).sum(2).unsqueeze(-1)
21
+
22
+ # B*num_objects*num_summaries*value_dim
23
+ return sums, area
24
+
25
+
26
+ class ObjectSummarizer(nn.Module):
27
+ def __init__(self, model_cfg: DictConfig):
28
+ super().__init__()
29
+
30
+ this_cfg = model_cfg.object_summarizer
31
+ self.value_dim = model_cfg.value_dim
32
+ self.embed_dim = this_cfg.embed_dim
33
+ self.num_summaries = this_cfg.num_summaries
34
+ self.add_pe = this_cfg.add_pe
35
+ self.pixel_pe_scale = model_cfg.pixel_pe_scale
36
+ self.pixel_pe_temperature = model_cfg.pixel_pe_temperature
37
+
38
+ if self.add_pe:
39
+ self.pos_enc = PositionalEncoding(self.embed_dim,
40
+ scale=self.pixel_pe_scale,
41
+ temperature=self.pixel_pe_temperature)
42
+
43
+ self.input_proj = nn.Linear(self.value_dim, self.embed_dim)
44
+ self.feature_pred = nn.Sequential(
45
+ nn.Linear(self.embed_dim, self.embed_dim),
46
+ nn.ReLU(inplace=True),
47
+ nn.Linear(self.embed_dim, self.embed_dim),
48
+ )
49
+ self.weights_pred = nn.Sequential(
50
+ nn.Linear(self.embed_dim, self.embed_dim),
51
+ nn.ReLU(inplace=True),
52
+ nn.Linear(self.embed_dim, self.num_summaries),
53
+ )
54
+
55
+ def forward(self,
56
+ masks: torch.Tensor,
57
+ value: torch.Tensor,
58
+ need_weights: bool = False) -> (torch.Tensor, Optional[torch.Tensor]):
59
+ # masks: B*num_objects*(H0)*(W0)
60
+ # value: B*num_objects*value_dim*H*W
61
+ # -> B*num_objects*H*W*value_dim
62
+ h, w = value.shape[-2:]
63
+ masks = F.interpolate(masks, size=(h, w), mode='area')
64
+ masks = masks.unsqueeze(-1)
65
+ inv_masks = 1 - masks
66
+ repeated_masks = torch.cat([
67
+ masks.expand(-1, -1, -1, -1, self.num_summaries // 2),
68
+ inv_masks.expand(-1, -1, -1, -1, self.num_summaries // 2),
69
+ ],
70
+ dim=-1)
71
+
72
+ value = value.permute(0, 1, 3, 4, 2)
73
+ value = self.input_proj(value)
74
+ if self.add_pe:
75
+ pe = self.pos_enc(value)
76
+ value = value + pe
77
+
78
+ with torch.amp.autocast("cuda",enabled=False):
79
+ value = value.float()
80
+ feature = self.feature_pred(value)
81
+ logits = self.weights_pred(value)
82
+ sums, area = _weighted_pooling(repeated_masks, feature, logits)
83
+
84
+ summaries = torch.cat([sums, area], dim=-1)
85
+
86
+ if need_weights:
87
+ return summaries, logits
88
+ else:
89
+ return summaries, None
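
For reference, the pooling in _weighted_pooling reduces to the following toy computation (dummy tensors, shapes chosen for illustration; ObjectSummarizer.forward then concatenates sums and area along the last dimension):

# Masked, sigmoid-weighted sums over the spatial dimensions, plus the total
# weight ("area") per summary.
import torch

B, N, H, W, C, Q = 1, 2, 8, 8, 32, 16    # batch, objects, spatial, value dim, summaries
masks = torch.ones(B, N, H, W, Q)        # 1 where pooling is allowed
value = torch.randn(B, N, H, W, C)
logits = torch.randn(B, N, H, W, Q)

weights = logits.sigmoid() * masks
sums = torch.einsum('bkhwq,bkhwc->bkqc', weights, value)  # B*N*Q*C
area = weights.flatten(2, 3).sum(2).unsqueeze(-1)         # B*N*Q*1
print(sums.shape, area.shape)  # torch.Size([1, 2, 16, 32]) torch.Size([1, 2, 16, 1])
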
hf_space/third_party/matanyone/model/transformer/object_transformer.py ADDED
@@ -0,0 +1,206 @@
1
+ from typing import Dict, Optional
2
+ from omegaconf import DictConfig
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from matanyone.model.group_modules import GConv2d
7
+ from matanyone.utils.tensor_utils import aggregate
8
+ from matanyone.model.transformer.positional_encoding import PositionalEncoding
9
+ from matanyone.model.transformer.transformer_layers import CrossAttention, SelfAttention, FFN, PixelFFN
10
+
11
+
12
+ class QueryTransformerBlock(nn.Module):
13
+ def __init__(self, model_cfg: DictConfig):
14
+ super().__init__()
15
+
16
+ this_cfg = model_cfg.object_transformer
17
+ self.embed_dim = this_cfg.embed_dim
18
+ self.num_heads = this_cfg.num_heads
19
+ self.num_queries = this_cfg.num_queries
20
+ self.ff_dim = this_cfg.ff_dim
21
+
22
+ self.read_from_pixel = CrossAttention(self.embed_dim,
23
+ self.num_heads,
24
+ add_pe_to_qkv=this_cfg.read_from_pixel.add_pe_to_qkv)
25
+ self.self_attn = SelfAttention(self.embed_dim,
26
+ self.num_heads,
27
+ add_pe_to_qkv=this_cfg.query_self_attention.add_pe_to_qkv)
28
+ self.ffn = FFN(self.embed_dim, self.ff_dim)
29
+ self.read_from_query = CrossAttention(self.embed_dim,
30
+ self.num_heads,
31
+ add_pe_to_qkv=this_cfg.read_from_query.add_pe_to_qkv,
32
+ norm=this_cfg.read_from_query.output_norm)
33
+ self.pixel_ffn = PixelFFN(self.embed_dim)
34
+
35
+ def forward(
36
+ self,
37
+ x: torch.Tensor,
38
+ pixel: torch.Tensor,
39
+ query_pe: torch.Tensor,
40
+ pixel_pe: torch.Tensor,
41
+ attn_mask: torch.Tensor,
42
+ need_weights: bool = False) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
43
+ # x: (bs*num_objects)*num_queries*embed_dim
44
+ # pixel: bs*num_objects*C*H*W
45
+ # query_pe: (bs*num_objects)*num_queries*embed_dim
46
+ # pixel_pe: (bs*num_objects)*(H*W)*C
47
+ # attn_mask: (bs*num_objects*num_heads)*num_queries*(H*W)
48
+
49
+ # bs*num_objects*C*H*W -> (bs*num_objects)*(H*W)*C
50
+ pixel_flat = pixel.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous()
51
+ x, q_weights = self.read_from_pixel(x,
52
+ pixel_flat,
53
+ query_pe,
54
+ pixel_pe,
55
+ attn_mask=attn_mask,
56
+ need_weights=need_weights)
57
+ x = self.self_attn(x, query_pe)
58
+ x = self.ffn(x)
59
+
60
+ pixel_flat, p_weights = self.read_from_query(pixel_flat,
61
+ x,
62
+ pixel_pe,
63
+ query_pe,
64
+ need_weights=need_weights)
65
+ pixel = self.pixel_ffn(pixel, pixel_flat)
66
+
67
+ if need_weights:
68
+ bs, num_objects, _, h, w = pixel.shape
69
+ q_weights = q_weights.view(bs, num_objects, self.num_heads, self.num_queries, h, w)
70
+ p_weights = p_weights.transpose(2, 3).view(bs, num_objects, self.num_heads,
71
+ self.num_queries, h, w)
72
+
73
+ return x, pixel, q_weights, p_weights
74
+
75
+
76
+ class QueryTransformer(nn.Module):
77
+ def __init__(self, model_cfg: DictConfig):
78
+ super().__init__()
79
+
80
+ this_cfg = model_cfg.object_transformer
81
+ self.value_dim = model_cfg.value_dim
82
+ self.embed_dim = this_cfg.embed_dim
83
+ self.num_heads = this_cfg.num_heads
84
+ self.num_queries = this_cfg.num_queries
85
+
86
+ # query initialization and embedding
87
+ self.query_init = nn.Embedding(self.num_queries, self.embed_dim)
88
+ self.query_emb = nn.Embedding(self.num_queries, self.embed_dim)
89
+
90
+ # projection from object summaries to query initialization and embedding
91
+ self.summary_to_query_init = nn.Linear(self.embed_dim, self.embed_dim)
92
+ self.summary_to_query_emb = nn.Linear(self.embed_dim, self.embed_dim)
93
+
94
+ self.pixel_pe_scale = model_cfg.pixel_pe_scale
95
+ self.pixel_pe_temperature = model_cfg.pixel_pe_temperature
96
+ self.pixel_init_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1)
97
+ self.pixel_emb_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1)
98
+ self.spatial_pe = PositionalEncoding(self.embed_dim,
99
+ scale=self.pixel_pe_scale,
100
+ temperature=self.pixel_pe_temperature,
101
+ channel_last=False,
102
+ transpose_output=True)
103
+
104
+ # transformer blocks
105
+ self.num_blocks = this_cfg.num_blocks
106
+ self.blocks = nn.ModuleList(
107
+ QueryTransformerBlock(model_cfg) for _ in range(self.num_blocks))
108
+ self.mask_pred = nn.ModuleList(
109
+ nn.Sequential(nn.ReLU(), GConv2d(self.embed_dim, 1, kernel_size=1))
110
+ for _ in range(self.num_blocks + 1))
111
+
112
+ self.act = nn.ReLU(inplace=True)
113
+
114
+ def forward(self,
115
+ pixel: torch.Tensor,
116
+ obj_summaries: torch.Tensor,
117
+ selector: Optional[torch.Tensor] = None,
118
+ need_weights: bool = False,
119
+ seg_pass=False) -> (torch.Tensor, Dict[str, torch.Tensor]):
120
+ # pixel: B*num_objects*embed_dim*H*W
121
+ # obj_summaries: B*num_objects*T*num_queries*embed_dim
122
+ T = obj_summaries.shape[2]
123
+ bs, num_objects, _, H, W = pixel.shape
124
+
125
+ # normalize object values
126
+ # the last channel is the cumulative area of the object
127
+ obj_summaries = obj_summaries.view(bs * num_objects, T, self.num_queries,
128
+ self.embed_dim + 1)
129
+ # sum over time
130
+ # during inference, T=1 as we already did streaming average in memory_manager
131
+ obj_sums = obj_summaries[:, :, :, :-1].sum(dim=1)
132
+ obj_area = obj_summaries[:, :, :, -1:].sum(dim=1)
133
+ obj_values = obj_sums / (obj_area + 1e-4)
134
+ obj_init = self.summary_to_query_init(obj_values)
135
+ obj_emb = self.summary_to_query_emb(obj_values)
136
+
137
+ # positional embeddings for object queries
138
+ query = self.query_init.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_init
139
+ query_emb = self.query_emb.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_emb
140
+
141
+ # positional embeddings for pixel features
142
+ pixel_init = self.pixel_init_proj(pixel)
143
+ pixel_emb = self.pixel_emb_proj(pixel)
144
+ pixel_pe = self.spatial_pe(pixel.flatten(0, 1))
145
+ pixel_emb = pixel_emb.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous()
146
+ pixel_pe = pixel_pe.flatten(1, 2) + pixel_emb
147
+
148
+ pixel = pixel_init
149
+
150
+ # run the transformer
151
+ aux_features = {'logits': []}
152
+
153
+ # first aux output
154
+ aux_logits = self.mask_pred[0](pixel).squeeze(2)
155
+ attn_mask = self._get_aux_mask(aux_logits, selector, seg_pass=seg_pass)
156
+ aux_features['logits'].append(aux_logits)
157
+ for i in range(self.num_blocks):
158
+ query, pixel, q_weights, p_weights = self.blocks[i](query,
159
+ pixel,
160
+ query_emb,
161
+ pixel_pe,
162
+ attn_mask,
163
+ need_weights=need_weights)
164
+
165
+ if self.training or i <= self.num_blocks - 1 or need_weights:
166
+ aux_logits = self.mask_pred[i + 1](pixel).squeeze(2)
167
+ attn_mask = self._get_aux_mask(aux_logits, selector, seg_pass=seg_pass)
168
+ aux_features['logits'].append(aux_logits)
169
+
170
+ aux_features['q_weights'] = q_weights # last layer only
171
+ aux_features['p_weights'] = p_weights # last layer only
172
+
173
+ if self.training:
174
+ # no need to save all heads
175
+ aux_features['attn_mask'] = attn_mask.view(bs, num_objects, self.num_heads,
176
+ self.num_queries, H, W)[:, :, 0]
177
+
178
+ return pixel, aux_features
179
+
180
+ def _get_aux_mask(self, logits: torch.Tensor, selector: torch.Tensor, seg_pass=False) -> torch.Tensor:
181
+ # logits: batch_size*num_objects*H*W
182
+ # selector: batch_size*num_objects*1*1
183
+ # returns a mask of shape (batch_size*num_objects*num_heads)*num_queries*(H*W)
184
+ # where True means the attention is blocked
185
+
186
+ if selector is None:
187
+ prob = logits.sigmoid()
188
+ else:
189
+ prob = logits.sigmoid() * selector
190
+ logits = aggregate(prob, dim=1)
191
+
192
+ is_foreground = (logits[:, 1:] >= logits.max(dim=1, keepdim=True)[0])
193
+ foreground_mask = is_foreground.bool().flatten(start_dim=2)
194
+ inv_foreground_mask = ~foreground_mask
195
+ inv_background_mask = foreground_mask
196
+
197
+ aux_foreground_mask = inv_foreground_mask.unsqueeze(2).unsqueeze(2).repeat(
198
+ 1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2)
199
+ aux_background_mask = inv_background_mask.unsqueeze(2).unsqueeze(2).repeat(
200
+ 1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2)
201
+
202
+ aux_mask = torch.cat([aux_foreground_mask, aux_background_mask], dim=1)
203
+
204
+ aux_mask[torch.where(aux_mask.sum(-1) == aux_mask.shape[-1])] = False
205
+
206
+ return aux_mask
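
A simplified sketch of the _get_aux_mask construction (dummy tensors; the aggregate call and the selector are replaced here by a plain sigmoid threshold, so this only illustrates the masking pattern, not the exact decision rule):

# Half of the queries may only attend to predicted-foreground pixels, the
# other half only to background; rows that would block everything are fully
# unblocked so attention stays well-defined.
import torch

bs, num_objects, num_heads, num_queries, H, W = 1, 2, 2, 8, 4, 4
logits = torch.randn(bs, num_objects, H, W)

foreground = (logits.sigmoid() > 0.5).flatten(start_dim=2)  # bs*objects*(H*W)
block_fg_rows = ~foreground                                  # blocked where not foreground
block_bg_rows = foreground                                   # blocked where foreground

def expand(m):
    # tile to (bs, objects, heads, queries//2, H*W), then merge the leading dims
    return m.unsqueeze(2).unsqueeze(2).repeat(
        1, 1, num_heads, num_queries // 2, 1).flatten(0, 2)

aux_mask = torch.cat([expand(block_fg_rows), expand(block_bg_rows)], dim=1)
aux_mask[aux_mask.sum(-1) == aux_mask.shape[-1]] = False     # never block a full row
print(aux_mask.shape)  # (bs*num_objects*num_heads, num_queries, H*W) = (4, 8, 16)
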
hf_space/third_party/matanyone/model/transformer/positional_encoding.py ADDED
@@ -0,0 +1,108 @@
1
+ # Reference:
2
+ # https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/transformer_decoder/position_encoding.py
3
+ # https://github.com/tatp22/multidim-positional-encoding/blob/master/positional_encodings/torch_encodings.py
4
+
5
+ import math
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch import nn
10
+
11
+
12
+ def get_emb(sin_inp: torch.Tensor) -> torch.Tensor:
13
+ """
14
+ Gets a base embedding for one dimension with sin and cos intertwined
15
+ """
16
+ emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)
17
+ return torch.flatten(emb, -2, -1)
18
+
19
+
20
+ class PositionalEncoding(nn.Module):
21
+ def __init__(self,
22
+ dim: int,
23
+ scale: float = math.pi * 2,
24
+ temperature: float = 10000,
25
+ normalize: bool = True,
26
+ channel_last: bool = True,
27
+ transpose_output: bool = False):
28
+ super().__init__()
29
+ dim = int(np.ceil(dim / 4) * 2)
30
+ self.dim = dim
31
+ inv_freq = 1.0 / (temperature**(torch.arange(0, dim, 2).float() / dim))
32
+ self.register_buffer("inv_freq", inv_freq)
33
+ self.normalize = normalize
34
+ self.scale = scale
35
+ self.eps = 1e-6
36
+ self.channel_last = channel_last
37
+ self.transpose_output = transpose_output
38
+
39
+ self.cached_penc = None # the cache is irrespective of the number of objects
40
+
41
+ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
42
+ """
43
+ :param tensor: A 4/5d tensor of size
44
+ channel_last=True: (batch_size, h, w, c) or (batch_size, k, h, w, c)
45
+ channel_last=False: (batch_size, c, h, w) or (batch_size, k, c, h, w)
46
+ :return: positional encoding tensor that has the same shape as the input if the input is 4d
47
+ if the input is 5d, the output is broadcastable along the k-dimension
48
+ """
49
+ if len(tensor.shape) != 4 and len(tensor.shape) != 5:
50
+ raise RuntimeError(f'The input tensor has to be 4/5d, got {tensor.shape}!')
51
+
52
+ if len(tensor.shape) == 5:
53
+ # take a sample from the k dimension
54
+ num_objects = tensor.shape[1]
55
+ tensor = tensor[:, 0]
56
+ else:
57
+ num_objects = None
58
+
59
+ if self.channel_last:
60
+ batch_size, h, w, c = tensor.shape
61
+ else:
62
+ batch_size, c, h, w = tensor.shape
63
+
64
+ if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
65
+ if num_objects is None:
66
+ return self.cached_penc
67
+ else:
68
+ return self.cached_penc.unsqueeze(1)
69
+
70
+ self.cached_penc = None
71
+
72
+ pos_y = torch.arange(h, device=tensor.device, dtype=self.inv_freq.dtype)
73
+ pos_x = torch.arange(w, device=tensor.device, dtype=self.inv_freq.dtype)
74
+ if self.normalize:
75
+ pos_y = pos_y / (pos_y[-1] + self.eps) * self.scale
76
+ pos_x = pos_x / (pos_x[-1] + self.eps) * self.scale
77
+
78
+ sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
79
+ sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
80
+ emb_y = get_emb(sin_inp_y).unsqueeze(1)
81
+ emb_x = get_emb(sin_inp_x)
82
+
83
+ emb = torch.zeros((h, w, self.dim * 2), device=tensor.device, dtype=tensor.dtype)
84
+ emb[:, :, :self.dim] = emb_x
85
+ emb[:, :, self.dim:] = emb_y
86
+
87
+ if not self.channel_last and self.transpose_output:
88
+ # cancelled out
89
+ pass
90
+ elif (not self.channel_last) or (self.transpose_output):
91
+ emb = emb.permute(2, 0, 1)
92
+
93
+ self.cached_penc = emb.unsqueeze(0).repeat(batch_size, 1, 1, 1)
94
+ if num_objects is None:
95
+ return self.cached_penc
96
+ else:
97
+ return self.cached_penc.unsqueeze(1)
98
+
99
+
100
+ if __name__ == '__main__':
101
+ pe = PositionalEncoding(8).cuda()
102
+ input = torch.ones((1, 8, 8, 8)).cuda()
103
+ output = pe(input)
104
+ # print(output)
105
+ print(output[0, :, 0, 0])
106
+ print(output[0, :, 0, 5])
107
+ print(output[0, 0, :, 0])
108
+ print(output[0, 0, 0, :])
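
For orientation, this is roughly how QueryTransformer.spatial_pe (see object_transformer.py above) drives this module; the shapes are illustrative:

# channel_last=False input of shape (B, C, H, W) with transpose_output=True
# returns a channel-last encoding (B, H, W, C), ready to be flattened to
# (B, H*W, C) for attention.
import torch
from matanyone.model.transformer.positional_encoding import PositionalEncoding

pe = PositionalEncoding(256, channel_last=False, transpose_output=True)
feat = torch.zeros(2, 256, 30, 54)  # (batch, embed_dim, H, W)
enc = pe(feat)
print(enc.shape)                    # torch.Size([2, 30, 54, 256])
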
hf_space/third_party/matanyone/model/transformer/transformer_layers.py ADDED
@@ -0,0 +1,161 @@
1
+ # Modified from PyTorch nn.Transformer
2
+
3
+ from typing import List, Callable
4
+
5
+ import torch
6
+ from torch import Tensor
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from matanyone.model.channel_attn import CAResBlock
10
+
11
+
12
+ class SelfAttention(nn.Module):
13
+ def __init__(self,
14
+ dim: int,
15
+ nhead: int,
16
+ dropout: float = 0.0,
17
+ batch_first: bool = True,
18
+ add_pe_to_qkv: List[bool] = [True, True, False]):
19
+ super().__init__()
20
+ self.self_attn = nn.MultiheadAttention(dim, nhead, dropout=dropout, batch_first=batch_first)
21
+ self.norm = nn.LayerNorm(dim)
22
+ self.dropout = nn.Dropout(dropout)
23
+ self.add_pe_to_qkv = add_pe_to_qkv
24
+
25
+ def forward(self,
26
+ x: torch.Tensor,
27
+ pe: torch.Tensor,
28
+ attn_mask: bool = None,
29
+ key_padding_mask: bool = None) -> torch.Tensor:
30
+ x = self.norm(x)
31
+ if any(self.add_pe_to_qkv):
32
+ x_with_pe = x + pe
33
+ q = x_with_pe if self.add_pe_to_qkv[0] else x
34
+ k = x_with_pe if self.add_pe_to_qkv[1] else x
35
+ v = x_with_pe if self.add_pe_to_qkv[2] else x
36
+ else:
37
+ q = k = v = x
38
+
39
+ r = x
40
+ x = self.self_attn(q, k, v, attn_mask=attn_mask, key_padding_mask=key_padding_mask)[0]
41
+ return r + self.dropout(x)
42
+
43
+
44
+ # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
45
+ class CrossAttention(nn.Module):
46
+ def __init__(self,
47
+ dim: int,
48
+ nhead: int,
49
+ dropout: float = 0.0,
50
+ batch_first: bool = True,
51
+ add_pe_to_qkv: List[bool] = [True, True, False],
52
+ residual: bool = True,
53
+ norm: bool = True):
54
+ super().__init__()
55
+ self.cross_attn = nn.MultiheadAttention(dim,
56
+ nhead,
57
+ dropout=dropout,
58
+ batch_first=batch_first)
59
+ if norm:
60
+ self.norm = nn.LayerNorm(dim)
61
+ else:
62
+ self.norm = nn.Identity()
63
+ self.dropout = nn.Dropout(dropout)
64
+ self.add_pe_to_qkv = add_pe_to_qkv
65
+ self.residual = residual
66
+
67
+ def forward(self,
68
+ x: torch.Tensor,
69
+ mem: torch.Tensor,
70
+ x_pe: torch.Tensor,
71
+ mem_pe: torch.Tensor,
72
+ attn_mask: bool = None,
73
+ *,
74
+ need_weights: bool = False) -> (torch.Tensor, torch.Tensor):
75
+ x = self.norm(x)
76
+ if self.add_pe_to_qkv[0]:
77
+ q = x + x_pe
78
+ else:
79
+ q = x
80
+
81
+ if any(self.add_pe_to_qkv[1:]):
82
+ mem_with_pe = mem + mem_pe
83
+ k = mem_with_pe if self.add_pe_to_qkv[1] else mem
84
+ v = mem_with_pe if self.add_pe_to_qkv[2] else mem
85
+ else:
86
+ k = v = mem
87
+ r = x
88
+ x, weights = self.cross_attn(q,
89
+ k,
90
+ v,
91
+ attn_mask=attn_mask,
92
+ need_weights=need_weights,
93
+ average_attn_weights=False)
94
+
95
+ if self.residual:
96
+ return r + self.dropout(x), weights
97
+ else:
98
+ return self.dropout(x), weights
99
+
100
+
101
+ class FFN(nn.Module):
102
+ def __init__(self, dim_in: int, dim_ff: int, activation=F.relu):
103
+ super().__init__()
104
+ self.linear1 = nn.Linear(dim_in, dim_ff)
105
+ self.linear2 = nn.Linear(dim_ff, dim_in)
106
+ self.norm = nn.LayerNorm(dim_in)
107
+
108
+ if isinstance(activation, str):
109
+ self.activation = _get_activation_fn(activation)
110
+ else:
111
+ self.activation = activation
112
+
113
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
114
+ r = x
115
+ x = self.norm(x)
116
+ x = self.linear2(self.activation(self.linear1(x)))
117
+ x = r + x
118
+ return x
119
+
120
+
121
+ class PixelFFN(nn.Module):
122
+ def __init__(self, dim: int):
123
+ super().__init__()
124
+ self.dim = dim
125
+ self.conv = CAResBlock(dim, dim)
126
+
127
+ def forward(self, pixel: torch.Tensor, pixel_flat: torch.Tensor) -> torch.Tensor:
128
+ # pixel: batch_size * num_objects * dim * H * W
129
+ # pixel_flat: (batch_size*num_objects) * (H*W) * dim
130
+ bs, num_objects, _, h, w = pixel.shape
131
+ pixel_flat = pixel_flat.view(bs * num_objects, h, w, self.dim)
132
+ pixel_flat = pixel_flat.permute(0, 3, 1, 2).contiguous()
133
+
134
+ x = self.conv(pixel_flat)
135
+ x = x.view(bs, num_objects, self.dim, h, w)
136
+ return x
137
+
138
+
139
+ class OutputFFN(nn.Module):
140
+ def __init__(self, dim_in: int, dim_out: int, activation=F.relu):
141
+ super().__init__()
142
+ self.linear1 = nn.Linear(dim_in, dim_out)
143
+ self.linear2 = nn.Linear(dim_out, dim_out)
144
+
145
+ if isinstance(activation, str):
146
+ self.activation = _get_activation_fn(activation)
147
+ else:
148
+ self.activation = activation
149
+
150
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
151
+ x = self.linear2(self.activation(self.linear1(x)))
152
+ return x
153
+
154
+
155
+ def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
156
+ if activation == "relu":
157
+ return F.relu
158
+ elif activation == "gelu":
159
+ return F.gelu
160
+
161
+ raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
hf_space/third_party/matanyone/model/utils/__init__.py ADDED
File without changes
hf_space/third_party/matanyone/model/utils/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (201 Bytes). View file
 
hf_space/third_party/matanyone/model/utils/__pycache__/memory_utils.cpython-313.pyc ADDED
Binary file (4.66 kB). View file