diff --git a/hf_space/third_party/matanyone/__init__.py b/hf_space/third_party/matanyone/__init__.py index 400360ae84a31c4f18fece39d591e2f71ba2a2ab..439c961b22bf2373edb3d5d7cc5fb0570c1fd715 100644 --- a/hf_space/third_party/matanyone/__init__.py +++ b/hf_space/third_party/matanyone/__init__.py @@ -1,2 +1,2 @@ -# Placeholder for vendored MatAnyone package. Replace with the real 'matanyone' package from https://github.com/pq-yang/MatAnyone. -__all__ = [] +from matanyone.inference.inference_core import InferenceCore +from matanyone.model.matanyone import MatAnyone diff --git a/hf_space/third_party/matanyone/__pycache__/__init__.cpython-313.pyc b/hf_space/third_party/matanyone/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6731a8b27e986433bbe82c37fe265614661de5bf Binary files /dev/null and b/hf_space/third_party/matanyone/__pycache__/__init__.cpython-313.pyc differ diff --git a/hf_space/third_party/matanyone/config/__init__.py b/hf_space/third_party/matanyone/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hf_space/third_party/matanyone/config/eval_matanyone_config.yaml b/hf_space/third_party/matanyone/config/eval_matanyone_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3bce31614048c539d739c18f63dbf39b1daf0b58 --- /dev/null +++ b/hf_space/third_party/matanyone/config/eval_matanyone_config.yaml @@ -0,0 +1,47 @@ +defaults: + - _self_ + - model: base + - override hydra/job_logging: custom-no-rank.yaml + +hydra: + run: + dir: ../output/${exp_id}/${dataset} + output_subdir: ${now:%Y-%m-%d_%H-%M-%S}-hydra + +amp: False +weights: pretrained_models/matanyone.pth # default (can be modified from outside) +output_dir: null # defaults to run_dir; specify this to override +flip_aug: False + + +# maximum shortest side of the input; -1 means no resizing +# With eval_vos.py, we usually just use the dataset's size (resizing done in dataloader) +# this parameter is added for the sole purpose for the GUI in the current codebase +# InferenceCore will downsize the input and restore the output to the original size if needed +# if you are using this code for some other project, you can also utilize this parameter +max_internal_size: -1 + +# these parameters, when set, override the dataset's default; useful for debugging +save_all: True +use_all_masks: False +use_long_term: False +mem_every: 5 + +# only relevant when long_term is not enabled +max_mem_frames: 5 + +# only relevant when long_term is enabled +long_term: + count_usage: True + max_mem_frames: 10 + min_mem_frames: 5 + num_prototypes: 128 + max_num_tokens: 10000 + buffer_tokens: 2000 + +top_k: 30 +stagger_updates: 5 +chunk_size: -1 # number of objects to process in parallel; -1 means unlimited +save_scores: False +save_aux: False +visualize: False diff --git a/hf_space/third_party/matanyone/config/hydra/job_logging/custom-no-rank.yaml b/hf_space/third_party/matanyone/config/hydra/job_logging/custom-no-rank.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2b74d48989959861d9a5d06fa76c4f6070bf5936 --- /dev/null +++ b/hf_space/third_party/matanyone/config/hydra/job_logging/custom-no-rank.yaml @@ -0,0 +1,22 @@ +# python logging configuration for tasks +version: 1 +formatters: + simple: + format: '[%(asctime)s][%(levelname)s] - %(message)s' + datefmt: '%Y-%m-%d %H:%M:%S' +handlers: + console: + class: logging.StreamHandler + formatter: simple + stream: 
ext://sys.stdout + file: + class: logging.FileHandler + formatter: simple + # absolute file path + filename: ${hydra.runtime.output_dir}/${now:%Y-%m-%d_%H-%M-%S}-eval.log + mode: w +root: + level: INFO + handlers: [console, file] + +disable_existing_loggers: false \ No newline at end of file diff --git a/hf_space/third_party/matanyone/config/hydra/job_logging/custom.yaml b/hf_space/third_party/matanyone/config/hydra/job_logging/custom.yaml new file mode 100644 index 0000000000000000000000000000000000000000..86fc06ac25870776acfa4acd03feed3e99157a24 --- /dev/null +++ b/hf_space/third_party/matanyone/config/hydra/job_logging/custom.yaml @@ -0,0 +1,22 @@ +# python logging configuration for tasks +version: 1 +formatters: + simple: + format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s' + datefmt: '%Y-%m-%d %H:%M:%S' +handlers: + console: + class: logging.StreamHandler + formatter: simple + stream: ext://sys.stdout + file: + class: logging.FileHandler + formatter: simple + # absolute file path + filename: ${hydra.runtime.output_dir}/${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log + mode: w +root: + level: INFO + handlers: [console, file] + +disable_existing_loggers: false \ No newline at end of file diff --git a/hf_space/third_party/matanyone/config/model/base.yaml b/hf_space/third_party/matanyone/config/model/base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4e8de055bc1bbdb94a3e46770995e4ab73ccbc3b --- /dev/null +++ b/hf_space/third_party/matanyone/config/model/base.yaml @@ -0,0 +1,58 @@ +pixel_mean: [0.485, 0.456, 0.406] +pixel_std: [0.229, 0.224, 0.225] + +pixel_dim: 256 +key_dim: 64 +value_dim: 256 +sensory_dim: 256 +embed_dim: 256 + +pixel_encoder: + type: resnet50 + ms_dims: [1024, 512, 256, 64, 3] # f16, f8, f4, f2, f1 + +mask_encoder: + type: resnet18 + final_dim: 256 + +pixel_pe_scale: 32 +pixel_pe_temperature: 128 + +object_transformer: + embed_dim: ${model.embed_dim} + ff_dim: 2048 + num_heads: 8 + num_blocks: 3 + num_queries: 16 + read_from_pixel: + input_norm: False + input_add_pe: False + add_pe_to_qkv: [True, True, False] + read_from_past: + add_pe_to_qkv: [True, True, False] + read_from_memory: + add_pe_to_qkv: [True, True, False] + read_from_query: + add_pe_to_qkv: [True, True, False] + output_norm: False + query_self_attention: + add_pe_to_qkv: [True, True, False] + pixel_self_attention: + add_pe_to_qkv: [True, True, False] + +object_summarizer: + embed_dim: ${model.object_transformer.embed_dim} + num_summaries: ${model.object_transformer.num_queries} + add_pe: True + +aux_loss: + sensory: + enabled: True + weight: 0.01 + query: + enabled: True + weight: 0.01 + +mask_decoder: + # first value must equal embed_dim + up_dims: [256, 128, 128, 64, 16] diff --git a/hf_space/third_party/matanyone/inference/__init__.py b/hf_space/third_party/matanyone/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hf_space/third_party/matanyone/inference/__pycache__/__init__.cpython-313.pyc b/hf_space/third_party/matanyone/inference/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87f3229809a3b5f4313a62e1c1c9eb74c81591a4 Binary files /dev/null and b/hf_space/third_party/matanyone/inference/__pycache__/__init__.cpython-313.pyc differ diff --git a/hf_space/third_party/matanyone/inference/__pycache__/image_feature_store.cpython-313.pyc 
b/hf_space/third_party/matanyone/inference/__pycache__/image_feature_store.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4844c865dd1e548ed36620b0d0cfad08621a57b Binary files /dev/null and b/hf_space/third_party/matanyone/inference/__pycache__/image_feature_store.cpython-313.pyc differ diff --git a/hf_space/third_party/matanyone/inference/__pycache__/inference_core.cpython-313.pyc b/hf_space/third_party/matanyone/inference/__pycache__/inference_core.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e1885aa5af7d3a4cd2cfe131f23e98393f76541 Binary files /dev/null and b/hf_space/third_party/matanyone/inference/__pycache__/inference_core.cpython-313.pyc differ diff --git a/hf_space/third_party/matanyone/inference/__pycache__/kv_memory_store.cpython-313.pyc b/hf_space/third_party/matanyone/inference/__pycache__/kv_memory_store.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..623cdc031037f1a67ef70f7f683d31ada759f8d3 Binary files /dev/null and b/hf_space/third_party/matanyone/inference/__pycache__/kv_memory_store.cpython-313.pyc differ diff --git a/hf_space/third_party/matanyone/inference/__pycache__/memory_manager.cpython-313.pyc b/hf_space/third_party/matanyone/inference/__pycache__/memory_manager.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07cb59db582af710fbea08d730aef3038ef74458 Binary files /dev/null and b/hf_space/third_party/matanyone/inference/__pycache__/memory_manager.cpython-313.pyc differ diff --git a/hf_space/third_party/matanyone/inference/__pycache__/object_info.cpython-313.pyc b/hf_space/third_party/matanyone/inference/__pycache__/object_info.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e71b7a88d25dac6c4aa26157643f25c965c3f892 Binary files /dev/null and b/hf_space/third_party/matanyone/inference/__pycache__/object_info.cpython-313.pyc differ diff --git a/hf_space/third_party/matanyone/inference/__pycache__/object_manager.cpython-313.pyc b/hf_space/third_party/matanyone/inference/__pycache__/object_manager.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4e82f1f6eed1b25dfea2c2ee655b54a3b258f33 Binary files /dev/null and b/hf_space/third_party/matanyone/inference/__pycache__/object_manager.cpython-313.pyc differ diff --git a/hf_space/third_party/matanyone/inference/image_feature_store.py b/hf_space/third_party/matanyone/inference/image_feature_store.py new file mode 100644 index 0000000000000000000000000000000000000000..590456f4877670da9bba09f8695467d4b97cf69d --- /dev/null +++ b/hf_space/third_party/matanyone/inference/image_feature_store.py @@ -0,0 +1,56 @@ +import warnings +from typing import Iterable +import torch +from matanyone.model.matanyone import MatAnyone + + +class ImageFeatureStore: + """ + A cache for image features. + These features might be reused at different parts of the inference pipeline. + This class provide an interface for reusing these features. + It is the user's responsibility to delete redundant features. + + Feature of a frame should be associated with a unique index -- typically the frame id. 
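+
+    Illustrative usage (a sketch, assuming `net` is an already-loaded MatAnyone model and
+    `image` is a padded 1*3*H*W float tensor on the same device; the frame id 0 is arbitrary):
+
+        store = ImageFeatureStore(net)
+        ms_feat, pix_feat = store.get_features(0, image)     # encodes and caches frame 0
+        key, shrinkage, selection = store.get_key(0, image)  # served from the same cache entry
+        store.delete(0)                                       # caller frees the entry when done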
+ """ + def __init__(self, network: MatAnyone, no_warning: bool = False): + self.network = network + self._store = {} + self.no_warning = no_warning + + def _encode_feature(self, index: int, image: torch.Tensor, last_feats=None) -> None: + ms_features, pix_feat = self.network.encode_image(image, last_feats=last_feats) + key, shrinkage, selection = self.network.transform_key(ms_features[0]) + self._store[index] = (ms_features, pix_feat, key, shrinkage, selection) + + def get_all_features(self, images: torch.Tensor) -> (Iterable[torch.Tensor], torch.Tensor): + seq_length = images.shape[0] + ms_features, pix_feat = self.network.encode_image(images, seq_length) + key, shrinkage, selection = self.network.transform_key(ms_features[0]) + for index in range(seq_length): + self._store[index] = ([f[index].unsqueeze(0) for f in ms_features], pix_feat[index].unsqueeze(0), key[index].unsqueeze(0), shrinkage[index].unsqueeze(0), selection[index].unsqueeze(0)) + + def get_features(self, index: int, + image: torch.Tensor, last_feats=None) -> (Iterable[torch.Tensor], torch.Tensor): + if index not in self._store: + self._encode_feature(index, image, last_feats) + + return self._store[index][:2] + + def get_key(self, index: int, + image: torch.Tensor, last_feats=None) -> (torch.Tensor, torch.Tensor, torch.Tensor): + if index not in self._store: + self._encode_feature(index, image, last_feats) + + return self._store[index][2:] + + def delete(self, index: int) -> None: + if index in self._store: + del self._store[index] + + def __len__(self): + return len(self._store) + + def __del__(self): + if len(self._store) > 0 and not self.no_warning: + warnings.warn(f'Leaking {self._store.keys()} in the image feature store') diff --git a/hf_space/third_party/matanyone/inference/inference_core.py b/hf_space/third_party/matanyone/inference/inference_core.py new file mode 100644 index 0000000000000000000000000000000000000000..88368614cf673d2b325038ec7b4f2d9e5f7c44f4 --- /dev/null +++ b/hf_space/third_party/matanyone/inference/inference_core.py @@ -0,0 +1,545 @@ +import logging +from omegaconf import DictConfig +from typing import List, Optional, Iterable, Union,Tuple + +import os +import cv2 +import torch +import imageio +import tempfile +import numpy as np +from tqdm import tqdm +from PIL import Image +import torch.nn.functional as F + +from matanyone.inference.memory_manager import MemoryManager +from matanyone.inference.object_manager import ObjectManager +from matanyone.inference.image_feature_store import ImageFeatureStore +from matanyone.model.matanyone import MatAnyone +from matanyone.utils.tensor_utils import pad_divide_by, unpad, aggregate +from matanyone.utils.inference_utils import gen_dilate, gen_erosion, read_frame_from_videos + +log = logging.getLogger() + + +class InferenceCore: + + def __init__(self, + network: Union[MatAnyone,str], + cfg: DictConfig = None, + *, + image_feature_store: ImageFeatureStore = None, + device: Union[str, torch.device] = torch.device("cuda" if torch.cuda.is_available() else "cpu") + ): + if isinstance(network, str): + network = MatAnyone.from_pretrained(network) + network.to(device) + network.eval() + self.network = network + cfg = cfg if cfg is not None else network.cfg + self.cfg = cfg + self.mem_every = cfg.mem_every + stagger_updates = cfg.stagger_updates + self.chunk_size = cfg.chunk_size + self.save_aux = cfg.save_aux + self.max_internal_size = cfg.max_internal_size + self.flip_aug = cfg.flip_aug + + self.curr_ti = -1 + self.last_mem_ti = 0 + # at which time indices should 
we update the sensory memory + if stagger_updates >= self.mem_every: + self.stagger_ti = set(range(1, self.mem_every + 1)) + else: + self.stagger_ti = set( + np.round(np.linspace(1, self.mem_every, stagger_updates)).astype(int)) + self.object_manager = ObjectManager() + self.memory = MemoryManager(cfg=cfg, object_manager=self.object_manager) + + if image_feature_store is None: + self.image_feature_store = ImageFeatureStore(self.network) + else: + self.image_feature_store = image_feature_store + + self.last_mask = None + self.last_pix_feat = None + self.last_msk_value = None + + def clear_memory(self): + self.curr_ti = -1 + self.last_mem_ti = 0 + self.memory = MemoryManager(cfg=self.cfg, object_manager=self.object_manager) + + def clear_non_permanent_memory(self): + self.curr_ti = -1 + self.last_mem_ti = 0 + self.memory.clear_non_permanent_memory() + + def clear_sensory_memory(self): + self.curr_ti = -1 + self.last_mem_ti = 0 + self.memory.clear_sensory_memory() + + def update_config(self, cfg): + self.mem_every = cfg['mem_every'] + self.memory.update_config(cfg) + + def clear_temp_mem(self): + self.memory.clear_work_mem() + # self.object_manager = ObjectManager() + self.memory.clear_obj_mem() + # self.memory.clear_sensory_memory() + + def _add_memory(self, + image: torch.Tensor, + pix_feat: torch.Tensor, + prob: torch.Tensor, + key: torch.Tensor, + shrinkage: torch.Tensor, + selection: torch.Tensor, + *, + is_deep_update: bool = True, + force_permanent: bool = False) -> None: + """ + Memorize the given segmentation in all memory stores. + + The batch dimension is 1 if flip augmentation is not used. + image: RGB image, (1/2)*3*H*W + pix_feat: from the key encoder, (1/2)*_*H*W + prob: (1/2)*num_objects*H*W, in [0, 1] + key/shrinkage/selection: for anisotropic l2, (1/2)*_*H*W + selection can be None if not using long-term memory + is_deep_update: whether to use deep update (e.g. with the mask encoder) + force_permanent: whether to force the memory to be permanent + """ + if prob.shape[1] == 0: + # nothing to add + log.warn('Trying to add an empty object mask to memory!') + return + + if force_permanent: + as_permanent = 'all' + else: + as_permanent = 'first' + + self.memory.initialize_sensory_if_needed(key, self.object_manager.all_obj_ids) + msk_value, sensory, obj_value, _ = self.network.encode_mask( + image, + pix_feat, + self.memory.get_sensory(self.object_manager.all_obj_ids), + prob, + deep_update=is_deep_update, + chunk_size=self.chunk_size, + need_weights=self.save_aux) + self.memory.add_memory(key, + shrinkage, + msk_value, + obj_value, + self.object_manager.all_obj_ids, + selection=selection, + as_permanent=as_permanent) + self.last_mem_ti = self.curr_ti + if is_deep_update: + self.memory.update_sensory(sensory, self.object_manager.all_obj_ids) + self.last_msk_value = msk_value + + def _segment(self, + key: torch.Tensor, + selection: torch.Tensor, + pix_feat: torch.Tensor, + ms_features: Iterable[torch.Tensor], + update_sensory: bool = True) -> torch.Tensor: + """ + Produce a segmentation using the given features and the memory + + The batch dimension is 1 if flip augmentation is not used. 
+ key/selection: for anisotropic l2: (1/2) * _ * H * W + pix_feat: from the key encoder, (1/2) * _ * H * W + ms_features: an iterable of multiscale features from the encoder, each is (1/2)*_*H*W + with strides 16, 8, and 4 respectively + update_sensory: whether to update the sensory memory + + Returns: (num_objects+1)*H*W normalized probability; the first channel is the background + """ + bs = key.shape[0] + if self.flip_aug: + assert bs == 2 + else: + assert bs == 1 + + if not self.memory.engaged: + log.warn('Trying to segment without any memory!') + return torch.zeros((1, key.shape[-2] * 16, key.shape[-1] * 16), + device=key.device, + dtype=key.dtype) + + uncert_output = None + + if self.curr_ti == 0: # ONLY for the first frame for prediction + memory_readout = self.memory.read_first_frame(self.last_msk_value, pix_feat, self.last_mask, self.network, uncert_output=uncert_output) + else: + memory_readout = self.memory.read(pix_feat, key, selection, self.last_mask, self.network, uncert_output=uncert_output, last_msk_value=self.last_msk_value, ti=self.curr_ti, + last_pix_feat=self.last_pix_feat, last_pred_mask=self.last_mask) + memory_readout = self.object_manager.realize_dict(memory_readout) + + sensory, _, pred_prob_with_bg = self.network.segment(ms_features, + memory_readout, + self.memory.get_sensory( + self.object_manager.all_obj_ids), + chunk_size=self.chunk_size, + update_sensory=update_sensory) + # remove batch dim + if self.flip_aug: + # average predictions of the non-flipped and flipped version + pred_prob_with_bg = (pred_prob_with_bg[0] + + torch.flip(pred_prob_with_bg[1], dims=[-1])) / 2 + else: + pred_prob_with_bg = pred_prob_with_bg[0] + if update_sensory: + self.memory.update_sensory(sensory, self.object_manager.all_obj_ids) + return pred_prob_with_bg + + def pred_all_flow(self, images): + self.total_len = images.shape[0] + images, self.pad = pad_divide_by(images, 16) + images = images.unsqueeze(0) # add the batch dimension: (1,t,c,h,w) + + self.flows_forward, self.flows_backward = self.network.pred_forward_backward_flow(images) + + def encode_all_images(self, images): + images, self.pad = pad_divide_by(images, 16) + self.image_feature_store.get_all_features(images) # t c h w + return images + + def step(self, + image: torch.Tensor, + mask: Optional[torch.Tensor] = None, + objects: Optional[List[int]] = None, + *, + idx_mask: bool = False, + end: bool = False, + delete_buffer: bool = True, + force_permanent: bool = False, + matting: bool = True, + first_frame_pred: bool = False) -> torch.Tensor: + """ + Take a step with a new incoming image. + If there is an incoming mask with new objects, we will memorize them. + If there is no incoming mask, we will segment the image using the memory. + In both cases, we will update the memory and return a segmentation. + + image: 3*H*W + mask: H*W (if idx mask) or len(objects)*H*W or None + objects: list of object ids that are valid in the mask Tensor. + The ids themselves do not need to be consecutive/in order, but they need to be + in the same position in the list as the corresponding mask + in the tensor in non-idx-mask mode. + objects is ignored if the mask is None. + If idx_mask is False and objects is None, we sequentially infer the object ids. + idx_mask: if True, mask is expected to contain an object id at every pixel. + If False, mask should have multiple channels with each channel representing one object. 
+ end: if we are at the end of the sequence, we do not need to update memory + if unsure just set it to False + delete_buffer: whether to delete the image feature buffer after this step + force_permanent: the memory recorded this frame will be added to the permanent memory + """ + if objects is None and mask is not None: + assert not idx_mask + objects = list(range(1, mask.shape[0] + 1)) + + # resize input if needed -- currently only used for the GUI + resize_needed = False + if self.max_internal_size > 0: + h, w = image.shape[-2:] + min_side = min(h, w) + if min_side > self.max_internal_size: + resize_needed = True + new_h = int(h / min_side * self.max_internal_size) + new_w = int(w / min_side * self.max_internal_size) + image = F.interpolate(image.unsqueeze(0), + size=(new_h, new_w), + mode='bilinear', + align_corners=False)[0] + if mask is not None: + if idx_mask: + mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).float(), + size=(new_h, new_w), + mode='nearest-exact', + align_corners=False)[0, 0].round().long() + else: + mask = F.interpolate(mask.unsqueeze(0), + size=(new_h, new_w), + mode='bilinear', + align_corners=False)[0] + + self.curr_ti += 1 + + image, self.pad = pad_divide_by(image, 16) # DONE alreay for 3DCNN!! + image = image.unsqueeze(0) # add the batch dimension + if self.flip_aug: + image = torch.cat([image, torch.flip(image, dims=[-1])], dim=0) + + # whether to update the working memory + is_mem_frame = ((self.curr_ti - self.last_mem_ti >= self.mem_every) or + (mask is not None)) and (not end) + # segment when there is no input mask or when the input mask is incomplete + need_segment = (mask is None) or (self.object_manager.num_obj > 0 + and not self.object_manager.has_all(objects)) + update_sensory = ((self.curr_ti - self.last_mem_ti) in self.stagger_ti) and (not end) + + # reinit if it is the first frame for prediction + if first_frame_pred: + self.curr_ti = 0 + self.last_mem_ti = 0 + is_mem_frame = True + need_segment = True + update_sensory = True + + # encoding the image + ms_feat, pix_feat = self.image_feature_store.get_features(self.curr_ti, image) + key, shrinkage, selection = self.image_feature_store.get_key(self.curr_ti, image) + + # segmentation from memory if needed + if need_segment: + pred_prob_with_bg = self._segment(key, + selection, + pix_feat, + ms_feat, + update_sensory=update_sensory) + + # use the input mask if provided + if mask is not None: + # inform the manager of the new objects, and get a list of temporary id + # temporary ids -- indicates the position of objects in the tensor + # (starts with 1 due to the background channel) + corresponding_tmp_ids, _ = self.object_manager.add_new_objects(objects) + + mask, _ = pad_divide_by(mask, 16) + if need_segment: + # merge predicted mask with the incomplete input mask + pred_prob_no_bg = pred_prob_with_bg[1:] + # use the mutual exclusivity of segmentation + if idx_mask: + pred_prob_no_bg[:, mask > 0] = 0 + else: + pred_prob_no_bg[:, mask.max(0) > 0.5] = 0 + + new_masks = [] + for mask_id, tmp_id in enumerate(corresponding_tmp_ids): + if idx_mask: + this_mask = (mask == objects[mask_id]).type_as(pred_prob_no_bg) + else: + this_mask = mask[tmp_id] + if tmp_id > pred_prob_no_bg.shape[0]: + new_masks.append(this_mask.unsqueeze(0)) + else: + # +1 for padding the background channel + pred_prob_no_bg[tmp_id - 1] = this_mask + # new_masks are always in the order of tmp_id + mask = torch.cat([pred_prob_no_bg, *new_masks], dim=0) + elif idx_mask: + # simply convert cls to one-hot representation + if 
len(objects) == 0: + if delete_buffer: + self.image_feature_store.delete(self.curr_ti) + log.warn('Trying to insert an empty mask as memory!') + return torch.zeros((1, key.shape[-2] * 16, key.shape[-1] * 16), + device=key.device, + dtype=key.dtype) + mask = torch.stack( + [mask == objects[mask_id] for mask_id, _ in enumerate(corresponding_tmp_ids)], + dim=0) + if matting: + mask = mask.unsqueeze(0).float() / 255. + pred_prob_with_bg = torch.cat([1-mask, mask], 0) + else: + pred_prob_with_bg = aggregate(mask, dim=0) + pred_prob_with_bg = torch.softmax(pred_prob_with_bg, dim=0) + + self.last_mask = pred_prob_with_bg[1:].unsqueeze(0) + if self.flip_aug: + self.last_mask = torch.cat( + [self.last_mask, torch.flip(self.last_mask, dims=[-1])], dim=0) + self.last_pix_feat = pix_feat + + # save as memory if needed + if is_mem_frame or force_permanent: + # clear the memory for given mask and add the first predicted mask + if first_frame_pred: + self.clear_temp_mem() + self._add_memory(image, + pix_feat, + self.last_mask, + key, + shrinkage, + selection, + force_permanent=force_permanent, + is_deep_update=True) + else: # compute self.last_msk_value for non-memory frame + msk_value, _, _, _ = self.network.encode_mask( + image, + pix_feat, + self.memory.get_sensory(self.object_manager.all_obj_ids), + self.last_mask, + deep_update=False, + chunk_size=self.chunk_size, + need_weights=self.save_aux) + self.last_msk_value = msk_value + + if delete_buffer: + self.image_feature_store.delete(self.curr_ti) + + output_prob = unpad(pred_prob_with_bg, self.pad) + if resize_needed: + # restore output to the original size + output_prob = F.interpolate(output_prob.unsqueeze(0), + size=(h, w), + mode='bilinear', + align_corners=False)[0] + + return output_prob + + def delete_objects(self, objects: List[int]) -> None: + """ + Delete the given objects from the memory. + """ + self.object_manager.delete_objects(objects) + self.memory.purge_except(self.object_manager.all_obj_ids) + + def output_prob_to_mask(self, output_prob: torch.Tensor, matting: bool = True) -> torch.Tensor: + if matting: + new_mask = output_prob[1:].squeeze(0) + else: + mask = torch.argmax(output_prob, dim=0) + + # index in tensor != object id -- remap the ids here + new_mask = torch.zeros_like(mask) + for tmp_id, obj in self.object_manager.tmp_id_to_obj.items(): + new_mask[mask == tmp_id] = obj.id + + return new_mask + + @torch.inference_mode() + @torch.amp.autocast("cuda") + def process_video( + self, + input_path: str, + mask_path: str, + output_path: str = None, + n_warmup: int = 10, + r_erode: int = 10, + r_dilate: int = 10, + suffix: str = "", + save_image: bool = False, + max_size: int = -1, + ) -> Tuple: + """ + Process a video for object segmentation and matting. + This method processes a video file by performing object segmentation and matting on each frame. + It supports warmup frames, mask erosion/dilation, and various output options. + Args: + input_path (str): Path to the input video file + mask_path (str): Path to the mask image file used for initial segmentation + output_path (str, optional): Directory path where output files will be saved. Defaults to a temporary directory + n_warmup (int, optional): Number of warmup frames to use. Defaults to 10 + r_erode (int, optional): Erosion radius for mask processing. Defaults to 10 + r_dilate (int, optional): Dilation radius for mask processing. Defaults to 10 + suffix (str, optional): Suffix to append to output filename. 
Defaults to "" + save_image (bool, optional): Whether to save individual frames. Defaults to False + max_size (int, optional): Maximum size for frame dimension. Use -1 for no limit. Defaults to -1 + Returns: + Tuple[str, str]: A tuple containing: + - Path to the output foreground video file (str) + - Path to the output alpha matte video file (str) + Output: + - Saves processed video files with foreground (_fgr) and alpha matte (_pha) + - If save_image=True, saves individual frames in separate directories + """ + output_path = output_path if output_path is not None else tempfile.TemporaryDirectory().name + r_erode = int(r_erode) + r_dilate = int(r_dilate) + n_warmup = int(n_warmup) + max_size = int(max_size) + + vframes, fps, length, video_name = read_frame_from_videos(input_path) + repeated_frames = vframes[0].unsqueeze(0).repeat(n_warmup, 1, 1, 1) + vframes = torch.cat([repeated_frames, vframes], dim=0).float() + length += n_warmup + + new_h, new_w = vframes.shape[-2:] + if max_size > 0: + h, w = new_h, new_w + min_side = min(h, w) + if min_side > max_size: + new_h = int(h / min_side * max_size) + new_w = int(w / min_side * max_size) + vframes = F.interpolate(vframes, size=(new_h, new_w), mode="area") + + os.makedirs(output_path, exist_ok=True) + if suffix: + video_name = f"{video_name}_{suffix}" + if save_image: + os.makedirs(f"{output_path}/{video_name}", exist_ok=True) + os.makedirs(f"{output_path}/{video_name}/pha", exist_ok=True) + os.makedirs(f"{output_path}/{video_name}/fgr", exist_ok=True) + + mask = np.array(Image.open(mask_path).convert("L")) + if r_dilate > 0: + mask = gen_dilate(mask, r_dilate, r_dilate) + if r_erode > 0: + mask = gen_erosion(mask, r_erode, r_erode) + + mask = torch.from_numpy(mask).cuda() + if max_size > 0: + mask = F.interpolate( + mask.unsqueeze(0).unsqueeze(0), size=(new_h, new_w), mode="nearest" + )[0, 0] + + bgr = (np.array([120, 255, 155], dtype=np.float32) / 255).reshape((1, 1, 3)) + objects = [1] + + phas = [] + fgrs = [] + for ti in tqdm(range(length)): + image = vframes[ti] + image_np = np.array(image.permute(1, 2, 0)) + image = (image / 255.0).cuda().float() + + if ti == 0: + output_prob = self.step(image, mask, objects=objects) + output_prob = self.step(image, first_frame_pred=True) + else: + if ti <= n_warmup: + output_prob = self.step(image, first_frame_pred=True) + else: + output_prob = self.step(image) + + mask = self.output_prob_to_mask(output_prob) + pha = mask.unsqueeze(2).cpu().numpy() + com_np = image_np / 255.0 * pha + bgr * (1 - pha) + + if ti > (n_warmup - 1): + com_np = (com_np * 255).astype(np.uint8) + pha = (pha * 255).astype(np.uint8) + fgrs.append(com_np) + phas.append(pha) + if save_image: + cv2.imwrite( + f"{output_path}/{video_name}/pha/{str(ti - n_warmup).zfill(5)}.png", + pha, + ) + cv2.imwrite( + f"{output_path}/{video_name}/fgr/{str(ti - n_warmup).zfill(5)}.png", + com_np[..., [2, 1, 0]], + ) + + fgrs = np.array(fgrs) + phas = np.array(phas) + + fgr_filename = f"{output_path}/{video_name}_fgr.mp4" + alpha_filename = f"{output_path}/{video_name}_pha.mp4" + + imageio.mimwrite(fgr_filename, fgrs, fps=fps, quality=7) + imageio.mimwrite(alpha_filename, phas, fps=fps, quality=7) + + return (fgr_filename,alpha_filename) diff --git a/hf_space/third_party/matanyone/inference/kv_memory_store.py b/hf_space/third_party/matanyone/inference/kv_memory_store.py new file mode 100644 index 0000000000000000000000000000000000000000..67dba3df9d1c2b1fd0b07619bfc6d9b90a8182d2 --- /dev/null +++ 
b/hf_space/third_party/matanyone/inference/kv_memory_store.py @@ -0,0 +1,348 @@ +from typing import Dict, List, Optional, Literal +from collections import defaultdict +import torch + + +def _add_last_dim(dictionary, key, new_value, prepend=False): + # append/prepend a new value to the last dimension of a tensor in a dictionary + # if the key does not exist, put the new value in + # append by default + if key in dictionary: + dictionary[key] = torch.cat([dictionary[key], new_value], -1) + else: + dictionary[key] = new_value + + +class KeyValueMemoryStore: + """ + Works for key/value pairs type storage + e.g., working and long-term memory + """ + def __init__(self, save_selection: bool = False, save_usage: bool = False): + """ + We store keys and values of objects that first appear in the same frame in a bucket. + Each bucket contains a set of object ids. + Each bucket is associated with a single key tensor + and a dictionary of value tensors indexed by object id. + + The keys and values are stored as the concatenation of a permanent part and a temporary part. + """ + self.save_selection = save_selection + self.save_usage = save_usage + + self.global_bucket_id = 0 # does not reduce even if buckets are removed + self.buckets: Dict[int, List[int]] = {} # indexed by bucket id + self.k: Dict[int, torch.Tensor] = {} # indexed by bucket id + self.v: Dict[int, torch.Tensor] = {} # indexed by object id + + # indexed by bucket id; the end point of permanent memory + self.perm_end_pt: Dict[int, int] = defaultdict(int) + + # shrinkage and selection are just like the keys + self.s = {} + if self.save_selection: + self.e = {} # does not contain the permanent memory part + + # usage + if self.save_usage: + self.use_cnt = {} # indexed by bucket id, does not contain the permanent memory part + self.life_cnt = {} # indexed by bucket id, does not contain the permanent memory part + + def add(self, + key: torch.Tensor, + values: Dict[int, torch.Tensor], + shrinkage: torch.Tensor, + selection: torch.Tensor, + supposed_bucket_id: int = -1, + as_permanent: Literal['no', 'first', 'all'] = 'no') -> None: + """ + key: (1/2)*C*N + values: dict of values ((1/2)*C*N), object ids are used as keys + shrinkage: (1/2)*1*N + selection: (1/2)*C*N + + supposed_bucket_id: used to sync the bucket id between working and long-term memory + if provided, the input should all be in a single bucket indexed by this id + as_permanent: whether to store the input as permanent memory + 'no': don't + 'first': only store it as permanent memory if the bucket is empty + 'all': always store it as permanent memory + """ + bs = key.shape[0] + ne = key.shape[-1] + assert len(key.shape) == 3 + assert len(shrinkage.shape) == 3 + assert not self.save_selection or len(selection.shape) == 3 + assert as_permanent in ['no', 'first', 'all'] + + # add the value and create new buckets if necessary + if supposed_bucket_id >= 0: + enabled_buckets = [supposed_bucket_id] + bucket_exist = supposed_bucket_id in self.buckets + for obj, value in values.items(): + if bucket_exist: + assert obj in self.v + assert obj in self.buckets[supposed_bucket_id] + _add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all')) + else: + assert obj not in self.v + self.v[obj] = value + self.buckets[supposed_bucket_id] = list(values.keys()) + else: + new_bucket_id = None + enabled_buckets = set() + for obj, value in values.items(): + assert len(value.shape) == 3 + if obj in self.v: + _add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all')) + bucket_used = [ + 
bucket_id for bucket_id, object_ids in self.buckets.items() + if obj in object_ids + ] + assert len(bucket_used) == 1 # each object should only be in one bucket + enabled_buckets.add(bucket_used[0]) + else: + self.v[obj] = value + if new_bucket_id is None: + # create new bucket + new_bucket_id = self.global_bucket_id + self.global_bucket_id += 1 + self.buckets[new_bucket_id] = [] + # put the new object into the corresponding bucket + self.buckets[new_bucket_id].append(obj) + enabled_buckets.add(new_bucket_id) + + # increment the permanent size if necessary + add_as_permanent = {} # indexed by bucket id + for bucket_id in enabled_buckets: + add_as_permanent[bucket_id] = False + if as_permanent == 'all': + self.perm_end_pt[bucket_id] += ne + add_as_permanent[bucket_id] = True + elif as_permanent == 'first': + if self.perm_end_pt[bucket_id] == 0: + self.perm_end_pt[bucket_id] = ne + add_as_permanent[bucket_id] = True + + # create new counters for usage if necessary + if self.save_usage and as_permanent != 'all': + new_count = torch.zeros((bs, ne), device=key.device, dtype=torch.float32) + new_life = torch.zeros((bs, ne), device=key.device, dtype=torch.float32) + 1e-7 + + # add the key to every bucket + for bucket_id in self.buckets: + if bucket_id not in enabled_buckets: + # if we are not adding new values to a bucket, we should skip it + continue + + _add_last_dim(self.k, bucket_id, key, prepend=add_as_permanent[bucket_id]) + _add_last_dim(self.s, bucket_id, shrinkage, prepend=add_as_permanent[bucket_id]) + if not add_as_permanent[bucket_id]: + if self.save_selection: + _add_last_dim(self.e, bucket_id, selection) + if self.save_usage: + _add_last_dim(self.use_cnt, bucket_id, new_count) + _add_last_dim(self.life_cnt, bucket_id, new_life) + + def update_bucket_usage(self, bucket_id: int, usage: torch.Tensor) -> None: + # increase all life count by 1 + # increase use of indexed elements + if not self.save_usage: + return + + usage = usage[:, self.perm_end_pt[bucket_id]:] + if usage.shape[-1] == 0: + # if there is no temporary memory, we don't need to update + return + self.use_cnt[bucket_id] += usage.view_as(self.use_cnt[bucket_id]) + self.life_cnt[bucket_id] += 1 + + def sieve_by_range(self, bucket_id: int, start: int, end: int, min_size: int) -> None: + # keep only the temporary elements *outside* of this range (with some boundary conditions) + # the permanent elements are ignored in this computation + # i.e., concat (a[:start], a[end:]) + # bucket with size <= min_size are not modified + + assert start >= 0 + assert end <= 0 + + object_ids = self.buckets[bucket_id] + bucket_num_elements = self.k[bucket_id].shape[-1] - self.perm_end_pt[bucket_id] + if bucket_num_elements <= min_size: + return + + if end == 0: + # negative 0 would not work as the end index! 
+ # effectively make the second part an empty slice + end = self.k[bucket_id].shape[-1] + 1 + + p_size = self.perm_end_pt[bucket_id] + start = start + p_size + + k = self.k[bucket_id] + s = self.s[bucket_id] + if self.save_selection: + e = self.e[bucket_id] + if self.save_usage: + use_cnt = self.use_cnt[bucket_id] + life_cnt = self.life_cnt[bucket_id] + + self.k[bucket_id] = torch.cat([k[:, :, :start], k[:, :, end:]], -1) + self.s[bucket_id] = torch.cat([s[:, :, :start], s[:, :, end:]], -1) + if self.save_selection: + self.e[bucket_id] = torch.cat([e[:, :, :start - p_size], e[:, :, end:]], -1) + if self.save_usage: + self.use_cnt[bucket_id] = torch.cat([use_cnt[:, :start - p_size], use_cnt[:, end:]], -1) + self.life_cnt[bucket_id] = torch.cat([life_cnt[:, :start - p_size], life_cnt[:, end:]], + -1) + for obj_id in object_ids: + v = self.v[obj_id] + self.v[obj_id] = torch.cat([v[:, :, :start], v[:, :, end:]], -1) + + def remove_old_memory(self, bucket_id: int, max_len: int) -> None: + self.sieve_by_range(bucket_id, 0, -max_len, max_len) + + def remove_obsolete_features(self, bucket_id: int, max_size: int) -> None: + # for long-term memory only + object_ids = self.buckets[bucket_id] + + assert self.perm_end_pt[bucket_id] == 0 # permanent memory should be empty in LT memory + + # normalize with life duration + usage = self.get_usage(bucket_id) + bs = usage.shape[0] + + survivals = [] + + for bi in range(bs): + _, survived = torch.topk(usage[bi], k=max_size) + survivals.append(survived.flatten()) + assert survived.shape[-1] == survivals[0].shape[-1] + + self.k[bucket_id] = torch.stack( + [self.k[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0) + self.s[bucket_id] = torch.stack( + [self.s[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0) + + if self.save_selection: + # Long-term memory does not store selection so this should not be needed + self.e[bucket_id] = torch.stack( + [self.e[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0) + for obj_id in object_ids: + self.v[obj_id] = torch.stack( + [self.v[obj_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0) + + self.use_cnt[bucket_id] = torch.stack( + [self.use_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0) + self.life_cnt[bucket_id] = torch.stack( + [self.life_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0) + + def get_usage(self, bucket_id: int) -> torch.Tensor: + # return normalized usage + if not self.save_usage: + raise RuntimeError('I did not count usage!') + else: + usage = self.use_cnt[bucket_id] / self.life_cnt[bucket_id] + return usage + + def get_all_sliced( + self, bucket_id: int, start: int, end: int + ) -> (torch.Tensor, torch.Tensor, torch.Tensor, Dict[int, torch.Tensor], torch.Tensor): + # return k, sk, ek, value, normalized usage in order, sliced by start and end + # this only queries the temporary memory + + assert start >= 0 + assert end <= 0 + + p_size = self.perm_end_pt[bucket_id] + start = start + p_size + + if end == 0: + # negative 0 would not work as the end index! 
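+                # so slice through to the end of the tensor instead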
+ k = self.k[bucket_id][:, :, start:] + sk = self.s[bucket_id][:, :, start:] + ek = self.e[bucket_id][:, :, start - p_size:] if self.save_selection else None + value = {obj_id: self.v[obj_id][:, :, start:] for obj_id in self.buckets[bucket_id]} + usage = self.get_usage(bucket_id)[:, start - p_size:] if self.save_usage else None + else: + k = self.k[bucket_id][:, :, start:end] + sk = self.s[bucket_id][:, :, start:end] + ek = self.e[bucket_id][:, :, start - p_size:end] if self.save_selection else None + value = {obj_id: self.v[obj_id][:, :, start:end] for obj_id in self.buckets[bucket_id]} + usage = self.get_usage(bucket_id)[:, start - p_size:end] if self.save_usage else None + + return k, sk, ek, value, usage + + def purge_except(self, obj_keep_idx: List[int]): + # purge certain objects from the memory except the one listed + obj_keep_idx = set(obj_keep_idx) + + # remove objects that are not in the keep list from the buckets + buckets_to_remove = [] + for bucket_id, object_ids in self.buckets.items(): + self.buckets[bucket_id] = [obj_id for obj_id in object_ids if obj_id in obj_keep_idx] + if len(self.buckets[bucket_id]) == 0: + buckets_to_remove.append(bucket_id) + + # remove object values that are not in the keep list + self.v = {k: v for k, v in self.v.items() if k in obj_keep_idx} + + # remove buckets that are empty + for bucket_id in buckets_to_remove: + del self.buckets[bucket_id] + del self.k[bucket_id] + del self.s[bucket_id] + if self.save_selection: + del self.e[bucket_id] + if self.save_usage: + del self.use_cnt[bucket_id] + del self.life_cnt[bucket_id] + + def clear_non_permanent_memory(self): + # clear all non-permanent memory + for bucket_id in self.buckets: + self.sieve_by_range(bucket_id, 0, 0, 0) + + def get_v_size(self, obj_id: int) -> int: + return self.v[obj_id].shape[-1] + + def size(self, bucket_id: int) -> int: + if bucket_id not in self.k: + return 0 + else: + return self.k[bucket_id].shape[-1] + + def perm_size(self, bucket_id: int) -> int: + return self.perm_end_pt[bucket_id] + + def non_perm_size(self, bucket_id: int) -> int: + return self.size(bucket_id) - self.perm_size(bucket_id) + + def engaged(self, bucket_id: Optional[int] = None) -> bool: + if bucket_id is None: + return len(self.buckets) > 0 + else: + return bucket_id in self.buckets + + @property + def num_objects(self) -> int: + return len(self.v) + + @property + def key(self) -> Dict[int, torch.Tensor]: + return self.k + + @property + def value(self) -> Dict[int, torch.Tensor]: + return self.v + + @property + def shrinkage(self) -> Dict[int, torch.Tensor]: + return self.s + + @property + def selection(self) -> Dict[int, torch.Tensor]: + return self.e + + def __contains__(self, key): + return key in self.v diff --git a/hf_space/third_party/matanyone/inference/memory_manager.py b/hf_space/third_party/matanyone/inference/memory_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..3504074f129df7b2fed4802f40c9fd9d4a94b686 --- /dev/null +++ b/hf_space/third_party/matanyone/inference/memory_manager.py @@ -0,0 +1,453 @@ +import logging +from omegaconf import DictConfig +from typing import List, Dict +import torch + +from matanyone.inference.object_manager import ObjectManager +from matanyone.inference.kv_memory_store import KeyValueMemoryStore +from matanyone.model.matanyone import MatAnyone +from matanyone.model.utils.memory_utils import get_similarity, do_softmax + +log = logging.getLogger() + + +class MemoryManager: + """ + Manages all three memory stores and the transition between 
working/long-term memory + """ + def __init__(self, cfg: DictConfig, object_manager: ObjectManager): + self.object_manager = object_manager + self.sensory_dim = cfg.model.sensory_dim + self.top_k = cfg.top_k + self.chunk_size = cfg.chunk_size + + self.save_aux = cfg.save_aux + + self.use_long_term = cfg.use_long_term + self.count_long_term_usage = cfg.long_term.count_usage + # subtract 1 because the first-frame is now counted as "permanent memory" + # and is not counted towards max_mem_frames + # but we want to keep the hyperparameters consistent as before for the same behavior + if self.use_long_term: + self.max_mem_frames = cfg.long_term.max_mem_frames - 1 + self.min_mem_frames = cfg.long_term.min_mem_frames - 1 + self.num_prototypes = cfg.long_term.num_prototypes + self.max_long_tokens = cfg.long_term.max_num_tokens + self.buffer_tokens = cfg.long_term.buffer_tokens + else: + self.max_mem_frames = cfg.max_mem_frames - 1 + + # dimensions will be inferred from input later + self.CK = self.CV = None + self.H = self.W = None + + # The sensory memory is stored as a dictionary indexed by object ids + # each of shape bs * C^h * H * W + self.sensory = {} + + # a dictionary indexed by object ids, each of shape bs * T * Q * C + self.obj_v = {} + + self.work_mem = KeyValueMemoryStore(save_selection=self.use_long_term, + save_usage=self.use_long_term) + if self.use_long_term: + self.long_mem = KeyValueMemoryStore(save_usage=self.count_long_term_usage) + + self.config_stale = True + self.engaged = False + + def update_config(self, cfg: DictConfig) -> None: + self.config_stale = True + self.top_k = cfg['top_k'] + + assert self.use_long_term == cfg.use_long_term, 'cannot update this' + assert self.count_long_term_usage == cfg.long_term.count_usage, 'cannot update this' + + self.use_long_term = cfg.use_long_term + self.count_long_term_usage = cfg.long_term.count_usage + if self.use_long_term: + self.max_mem_frames = cfg.long_term.max_mem_frames - 1 + self.min_mem_frames = cfg.long_term.min_mem_frames - 1 + self.num_prototypes = cfg.long_term.num_prototypes + self.max_long_tokens = cfg.long_term.max_num_tokens + self.buffer_tokens = cfg.long_term.buffer_tokens + else: + self.max_mem_frames = cfg.max_mem_frames - 1 + + def _readout(self, affinity, v, uncert_mask=None) -> torch.Tensor: + # affinity: bs*N*HW + # v: bs*C*N or bs*num_objects*C*N + # returns bs*C*HW or bs*num_objects*C*HW + if len(v.shape) == 3: + # single object + if uncert_mask is not None: + return v @ affinity * uncert_mask + else: + return v @ affinity + else: + bs, num_objects, C, N = v.shape + v = v.view(bs, num_objects * C, N) + out = v @ affinity + if uncert_mask is not None: + uncert_mask = uncert_mask.flatten(start_dim=2).expand(-1, C, -1) + out = out * uncert_mask + return out.view(bs, num_objects, C, -1) + + def _get_mask_by_ids(self, mask: torch.Tensor, obj_ids: List[int]) -> torch.Tensor: + # -1 because the mask does not contain the background channel + return mask[:, [self.object_manager.find_tmp_by_id(obj) - 1 for obj in obj_ids]] + + def _get_sensory_by_ids(self, obj_ids: List[int]) -> torch.Tensor: + return torch.stack([self.sensory[obj] for obj in obj_ids], dim=1) + + def _get_object_mem_by_ids(self, obj_ids: List[int]) -> torch.Tensor: + return torch.stack([self.obj_v[obj] for obj in obj_ids], dim=1) + + def _get_visual_values_by_ids(self, obj_ids: List[int]) -> torch.Tensor: + # All the values that the object ids refer to should have the same shape + value = torch.stack([self.work_mem.value[obj] for obj in obj_ids], 
dim=1) + if self.use_long_term and obj_ids[0] in self.long_mem.value: + lt_value = torch.stack([self.long_mem.value[obj] for obj in obj_ids], dim=1) + value = torch.cat([lt_value, value], dim=-1) + + return value + + def read_first_frame(self, last_msk_value, pix_feat: torch.Tensor, + last_mask: torch.Tensor, network: MatAnyone, uncert_output=None) -> Dict[int, torch.Tensor]: + """ + Read from all memory stores and returns a single memory readout tensor for each object + + pix_feat: (1/2) x C x H x W + query_key: (1/2) x C^k x H x W + selection: (1/2) x C^k x H x W + last_mask: (1/2) x num_objects x H x W (at stride 16) + return a dict of memory readouts, indexed by object indices. Each readout is C*H*W + """ + h, w = pix_feat.shape[-2:] + bs = pix_feat.shape[0] + assert last_mask.shape[0] == bs + + """ + Compute affinity and perform readout + """ + all_readout_mem = {} + buckets = self.work_mem.buckets + for bucket_id, bucket in buckets.items(): + + if self.chunk_size < 1: + object_chunks = [bucket] + else: + object_chunks = [ + bucket[i:i + self.chunk_size] for i in range(0, len(bucket), self.chunk_size) + ] + + for objects in object_chunks: + this_sensory = self._get_sensory_by_ids(objects) + this_last_mask = self._get_mask_by_ids(last_mask, objects) + this_msk_value = self._get_visual_values_by_ids(objects) # (1/2)*num_objects*C*N + pixel_readout = network.pixel_fusion(pix_feat, last_msk_value, this_sensory, + this_last_mask) + this_obj_mem = self._get_object_mem_by_ids(objects).unsqueeze(2) + readout_memory, aux_features = network.readout_query(pixel_readout, this_obj_mem) + for i, obj in enumerate(objects): + all_readout_mem[obj] = readout_memory[:, i] + + if self.save_aux: + aux_output = { + # 'sensory': this_sensory, + # 'pixel_readout': pixel_readout, + 'q_logits': aux_features['logits'] if aux_features else None, + # 'q_weights': aux_features['q_weights'] if aux_features else None, + # 'p_weights': aux_features['p_weights'] if aux_features else None, + # 'attn_mask': aux_features['attn_mask'].float() if aux_features else None, + } + self.aux = aux_output + + return all_readout_mem + + def read(self, pix_feat: torch.Tensor, query_key: torch.Tensor, selection: torch.Tensor, + last_mask: torch.Tensor, network: MatAnyone, uncert_output=None, last_msk_value=None, ti=None, + last_pix_feat=None, last_pred_mask=None) -> Dict[int, torch.Tensor]: + """ + Read from all memory stores and returns a single memory readout tensor for each object + + pix_feat: (1/2) x C x H x W + query_key: (1/2) x C^k x H x W + selection: (1/2) x C^k x H x W + last_mask: (1/2) x num_objects x H x W (at stride 16) + return a dict of memory readouts, indexed by object indices. 
Each readout is C*H*W + """ + h, w = pix_feat.shape[-2:] + bs = pix_feat.shape[0] + assert query_key.shape[0] == bs + assert selection.shape[0] == bs + assert last_mask.shape[0] == bs + + uncert_mask = uncert_output["mask"] if uncert_output is not None else None + + query_key = query_key.flatten(start_dim=2) # bs*C^k*HW + selection = selection.flatten(start_dim=2) # bs*C^k*HW + """ + Compute affinity and perform readout + """ + all_readout_mem = {} + buckets = self.work_mem.buckets + for bucket_id, bucket in buckets.items(): + if self.use_long_term and self.long_mem.engaged(bucket_id): + # Use long-term memory + long_mem_size = self.long_mem.size(bucket_id) + memory_key = torch.cat([self.long_mem.key[bucket_id], self.work_mem.key[bucket_id]], + -1) + shrinkage = torch.cat( + [self.long_mem.shrinkage[bucket_id], self.work_mem.shrinkage[bucket_id]], -1) + + similarity = get_similarity(memory_key, shrinkage, query_key, selection) + affinity, usage = do_softmax(similarity, + top_k=self.top_k, + inplace=True, + return_usage=True) + """ + Record memory usage for working and long-term memory + """ + # ignore the index return for long-term memory + work_usage = usage[:, long_mem_size:] + self.work_mem.update_bucket_usage(bucket_id, work_usage) + + if self.count_long_term_usage: + # ignore the index return for working memory + long_usage = usage[:, :long_mem_size] + self.long_mem.update_bucket_usage(bucket_id, long_usage) + else: + # no long-term memory + memory_key = self.work_mem.key[bucket_id] + shrinkage = self.work_mem.shrinkage[bucket_id] + similarity = get_similarity(memory_key, shrinkage, query_key, selection, uncert_mask=uncert_mask) + + if self.use_long_term: + affinity, usage = do_softmax(similarity, + top_k=self.top_k, + inplace=True, + return_usage=True) + self.work_mem.update_bucket_usage(bucket_id, usage) + else: + affinity = do_softmax(similarity, top_k=self.top_k, inplace=True) + + if self.chunk_size < 1: + object_chunks = [bucket] + else: + object_chunks = [ + bucket[i:i + self.chunk_size] for i in range(0, len(bucket), self.chunk_size) + ] + + for objects in object_chunks: + this_sensory = self._get_sensory_by_ids(objects) + this_last_mask = self._get_mask_by_ids(last_mask, objects) + this_msk_value = self._get_visual_values_by_ids(objects) # (1/2)*num_objects*C*N + visual_readout = self._readout(affinity, + this_msk_value, uncert_mask).view(bs, len(objects), self.CV, h, w) + + uncert_output = network.pred_uncertainty(last_pix_feat, pix_feat, last_pred_mask, visual_readout[:,0]-last_msk_value[:,0]) + + if uncert_output is not None: + uncert_prob = uncert_output["prob"].unsqueeze(1) # b n 1 h w + visual_readout = visual_readout*uncert_prob + last_msk_value*(1-uncert_prob) + + pixel_readout = network.pixel_fusion(pix_feat, visual_readout, this_sensory, + this_last_mask) + this_obj_mem = self._get_object_mem_by_ids(objects).unsqueeze(2) + readout_memory, aux_features = network.readout_query(pixel_readout, this_obj_mem) + for i, obj in enumerate(objects): + all_readout_mem[obj] = readout_memory[:, i] + + if self.save_aux: + aux_output = { + # 'sensory': this_sensory, + # 'pixel_readout': pixel_readout, + 'q_logits': aux_features['logits'] if aux_features else None, + # 'q_weights': aux_features['q_weights'] if aux_features else None, + # 'p_weights': aux_features['p_weights'] if aux_features else None, + # 'attn_mask': aux_features['attn_mask'].float() if aux_features else None, + } + self.aux = aux_output + + return all_readout_mem + + def add_memory(self, + key: torch.Tensor, + 
shrinkage: torch.Tensor, + msk_value: torch.Tensor, + obj_value: torch.Tensor, + objects: List[int], + selection: torch.Tensor = None, + *, + as_permanent: bool = False) -> None: + # key: (1/2)*C*H*W + # msk_value: (1/2)*num_objects*C*H*W + # obj_value: (1/2)*num_objects*Q*C + # objects contains a list of object ids corresponding to the objects in msk_value/obj_value + bs = key.shape[0] + assert shrinkage.shape[0] == bs + assert msk_value.shape[0] == bs + assert obj_value.shape[0] == bs + + self.engaged = True + if self.H is None or self.config_stale: + self.config_stale = False + self.H, self.W = msk_value.shape[-2:] + self.HW = self.H * self.W + # convert from num. frames to num. tokens + self.max_work_tokens = self.max_mem_frames * self.HW + if self.use_long_term: + self.min_work_tokens = self.min_mem_frames * self.HW + + # key: bs*C*N + # value: bs*num_objects*C*N + key = key.flatten(start_dim=2) + shrinkage = shrinkage.flatten(start_dim=2) + self.CK = key.shape[1] + + msk_value = msk_value.flatten(start_dim=3) + self.CV = msk_value.shape[2] + + if selection is not None: + # not used in non-long-term mode + selection = selection.flatten(start_dim=2) + + # insert object values into object memory + for obj_id, obj in enumerate(objects): + if obj in self.obj_v: + """streaming average + each self.obj_v[obj] is (1/2)*num_summaries*(embed_dim+1) + first embed_dim keeps track of the sum of embeddings + the last dim keeps the total count + averaging in done inside the object transformer + + incoming obj_value is (1/2)*num_objects*num_summaries*(embed_dim+1) + self.obj_v[obj] = torch.cat([self.obj_v[obj], obj_value[:, obj_id]], dim=0) + """ + last_acc = self.obj_v[obj][:, :, -1] + new_acc = last_acc + obj_value[:, obj_id, :, -1] + + self.obj_v[obj][:, :, :-1] = (self.obj_v[obj][:, :, :-1] + + obj_value[:, obj_id, :, :-1]) + self.obj_v[obj][:, :, -1] = new_acc + else: + self.obj_v[obj] = obj_value[:, obj_id] + + # convert mask value tensor into a dict for insertion + msk_values = {obj: msk_value[:, obj_id] for obj_id, obj in enumerate(objects)} + self.work_mem.add(key, + msk_values, + shrinkage, + selection=selection, + as_permanent=as_permanent) + + for bucket_id in self.work_mem.buckets.keys(): + # long-term memory cleanup + if self.use_long_term: + # Do memory compressed if needed + if self.work_mem.non_perm_size(bucket_id) >= self.max_work_tokens: + # Remove obsolete features if needed + if self.long_mem.non_perm_size(bucket_id) >= (self.max_long_tokens - + self.num_prototypes): + self.long_mem.remove_obsolete_features( + bucket_id, + self.max_long_tokens - self.num_prototypes - self.buffer_tokens) + + self.compress_features(bucket_id) + else: + # FIFO + self.work_mem.remove_old_memory(bucket_id, self.max_work_tokens) + + def purge_except(self, obj_keep_idx: List[int]) -> None: + # purge certain objects from the memory except the one listed + self.work_mem.purge_except(obj_keep_idx) + if self.use_long_term and self.long_mem.engaged(): + self.long_mem.purge_except(obj_keep_idx) + self.sensory = {k: v for k, v in self.sensory.items() if k in obj_keep_idx} + + if not self.work_mem.engaged(): + # everything is removed! 
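+            # flag the manager as disengaged; InferenceCore._segment() checks this
+            # before attempting to read from memory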
+ self.engaged = False + + def compress_features(self, bucket_id: int) -> None: + + # perform memory consolidation + prototype_key, prototype_value, prototype_shrinkage = self.consolidation( + *self.work_mem.get_all_sliced(bucket_id, 0, -self.min_work_tokens)) + + # remove consolidated working memory + self.work_mem.sieve_by_range(bucket_id, + 0, + -self.min_work_tokens, + min_size=self.min_work_tokens) + + # add to long-term memory + self.long_mem.add(prototype_key, + prototype_value, + prototype_shrinkage, + selection=None, + supposed_bucket_id=bucket_id) + + def consolidation(self, candidate_key: torch.Tensor, candidate_shrinkage: torch.Tensor, + candidate_selection: torch.Tensor, candidate_value: Dict[int, torch.Tensor], + usage: torch.Tensor) -> (torch.Tensor, Dict[int, torch.Tensor], torch.Tensor): + # find the indices with max usage + bs = candidate_key.shape[0] + assert bs in [1, 2] + + prototype_key = [] + prototype_selection = [] + for bi in range(bs): + _, max_usage_indices = torch.topk(usage[bi], k=self.num_prototypes, dim=-1, sorted=True) + prototype_indices = max_usage_indices.flatten() + prototype_key.append(candidate_key[bi, :, prototype_indices]) + prototype_selection.append(candidate_selection[bi, :, prototype_indices]) + prototype_key = torch.stack(prototype_key, dim=0) + prototype_selection = torch.stack(prototype_selection, dim=0) + """ + Potentiation step + """ + similarity = get_similarity(candidate_key, candidate_shrinkage, prototype_key, + prototype_selection) + affinity = do_softmax(similarity) + + # readout the values + prototype_value = {k: self._readout(affinity, v) for k, v in candidate_value.items()} + + # readout the shrinkage term + prototype_shrinkage = self._readout(affinity, candidate_shrinkage) + + return prototype_key, prototype_value, prototype_shrinkage + + def initialize_sensory_if_needed(self, sample_key: torch.Tensor, ids: List[int]): + for obj in ids: + if obj not in self.sensory: + # also initializes the sensory memory + bs, _, h, w = sample_key.shape + self.sensory[obj] = torch.zeros((bs, self.sensory_dim, h, w), + device=sample_key.device) + + def update_sensory(self, sensory: torch.Tensor, ids: List[int]): + # sensory: 1*num_objects*C*H*W + for obj_id, obj in enumerate(ids): + self.sensory[obj] = sensory[:, obj_id] + + def get_sensory(self, ids: List[int]): + # returns (1/2)*num_objects*C*H*W + return self._get_sensory_by_ids(ids) + + def clear_non_permanent_memory(self): + self.work_mem.clear_non_permanent_memory() + if self.use_long_term: + self.long_mem.clear_non_permanent_memory() + + def clear_sensory_memory(self): + self.sensory = {} + + def clear_work_mem(self): + self.work_mem = KeyValueMemoryStore(save_selection=self.use_long_term, + save_usage=self.use_long_term) + + def clear_obj_mem(self): + self.obj_v = {} diff --git a/hf_space/third_party/matanyone/inference/object_info.py b/hf_space/third_party/matanyone/inference/object_info.py new file mode 100644 index 0000000000000000000000000000000000000000..be4f0b97d2d5e5e5c3b8d9e06f1865096a813528 --- /dev/null +++ b/hf_space/third_party/matanyone/inference/object_info.py @@ -0,0 +1,24 @@ +class ObjectInfo: + """ + Store meta information for an object + """ + def __init__(self, id: int): + self.id = id + self.poke_count = 0 # count number of detections missed + + def poke(self) -> None: + self.poke_count += 1 + + def unpoke(self) -> None: + self.poke_count = 0 + + def __hash__(self): + return hash(self.id) + + def __eq__(self, other): + if type(other) == int: + return self.id == other + 
return self.id == other.id + + def __repr__(self): + return f'(ID: {self.id})' diff --git a/hf_space/third_party/matanyone/inference/object_manager.py b/hf_space/third_party/matanyone/inference/object_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..1a677f5f2f0d68a06933477e5696ed3cbd0bdc9e --- /dev/null +++ b/hf_space/third_party/matanyone/inference/object_manager.py @@ -0,0 +1,149 @@ +from typing import Union, List, Dict + +import torch +from matanyone.inference.object_info import ObjectInfo + + +class ObjectManager: + """ + Object IDs are immutable. The same ID always represent the same object. + Temporary IDs are the positions of each object in the tensor. It changes as objects get removed. + Temporary IDs start from 1. + """ + + def __init__(self): + self.obj_to_tmp_id: Dict[ObjectInfo, int] = {} + self.tmp_id_to_obj: Dict[int, ObjectInfo] = {} + self.obj_id_to_obj: Dict[int, ObjectInfo] = {} + + self.all_historical_object_ids: List[int] = [] + + def _recompute_obj_id_to_obj_mapping(self) -> None: + self.obj_id_to_obj = {obj.id: obj for obj in self.obj_to_tmp_id} + + def add_new_objects( + self, objects: Union[List[ObjectInfo], ObjectInfo, + List[int]]) -> (List[int], List[int]): + if not isinstance(objects, list): + objects = [objects] + + corresponding_tmp_ids = [] + corresponding_obj_ids = [] + for obj in objects: + if isinstance(obj, int): + obj = ObjectInfo(id=obj) + + if obj in self.obj_to_tmp_id: + # old object + corresponding_tmp_ids.append(self.obj_to_tmp_id[obj]) + corresponding_obj_ids.append(obj.id) + else: + # new object + new_obj = ObjectInfo(id=obj.id) + + # new object + new_tmp_id = len(self.obj_to_tmp_id) + 1 + self.obj_to_tmp_id[new_obj] = new_tmp_id + self.tmp_id_to_obj[new_tmp_id] = new_obj + self.all_historical_object_ids.append(new_obj.id) + corresponding_tmp_ids.append(new_tmp_id) + corresponding_obj_ids.append(new_obj.id) + + self._recompute_obj_id_to_obj_mapping() + assert corresponding_tmp_ids == sorted(corresponding_tmp_ids) + return corresponding_tmp_ids, corresponding_obj_ids + + def delete_objects(self, obj_ids_to_remove: Union[int, List[int]]) -> None: + # delete an object or a list of objects + # re-sort the tmp ids + if isinstance(obj_ids_to_remove, int): + obj_ids_to_remove = [obj_ids_to_remove] + + new_tmp_id = 1 + total_num_id = len(self.obj_to_tmp_id) + + local_obj_to_tmp_id = {} + local_tmp_to_obj_id = {} + + for tmp_iter in range(1, total_num_id + 1): + obj = self.tmp_id_to_obj[tmp_iter] + if obj.id not in obj_ids_to_remove: + local_obj_to_tmp_id[obj] = new_tmp_id + local_tmp_to_obj_id[new_tmp_id] = obj + new_tmp_id += 1 + + self.obj_to_tmp_id = local_obj_to_tmp_id + self.tmp_id_to_obj = local_tmp_to_obj_id + self._recompute_obj_id_to_obj_mapping() + + def purge_inactive_objects(self, + max_missed_detection_count: int) -> (bool, List[int], List[int]): + # remove tmp ids of objects that are removed + obj_id_to_be_deleted = [] + tmp_id_to_be_deleted = [] + tmp_id_to_keep = [] + obj_id_to_keep = [] + + for obj in self.obj_to_tmp_id: + if obj.poke_count > max_missed_detection_count: + obj_id_to_be_deleted.append(obj.id) + tmp_id_to_be_deleted.append(self.obj_to_tmp_id[obj]) + else: + tmp_id_to_keep.append(self.obj_to_tmp_id[obj]) + obj_id_to_keep.append(obj.id) + + purge_activated = len(obj_id_to_be_deleted) > 0 + if purge_activated: + self.delete_objects(obj_id_to_be_deleted) + return purge_activated, tmp_id_to_keep, obj_id_to_keep + + def tmp_to_obj_cls(self, mask) -> torch.Tensor: + # remap tmp id cls representation 
to the true object id representation + new_mask = torch.zeros_like(mask) + for tmp_id, obj in self.tmp_id_to_obj.items(): + new_mask[mask == tmp_id] = obj.id + return new_mask + + def get_tmp_to_obj_mapping(self) -> Dict[int, ObjectInfo]: + # returns the mapping in a dict format for saving it with pickle + return {obj.id: tmp_id for obj, tmp_id in self.tmp_id_to_obj.items()} + + def realize_dict(self, obj_dict, dim=1) -> torch.Tensor: + # turns a dict indexed by obj id into a tensor, ordered by tmp IDs + output = [] + for _, obj in self.tmp_id_to_obj.items(): + if obj.id not in obj_dict: + raise NotImplementedError + output.append(obj_dict[obj.id]) + output = torch.stack(output, dim=dim) + return output + + def make_one_hot(self, cls_mask) -> torch.Tensor: + output = [] + for _, obj in self.tmp_id_to_obj.items(): + output.append(cls_mask == obj.id) + if len(output) == 0: + output = torch.zeros((0, *cls_mask.shape), dtype=torch.bool, device=cls_mask.device) + else: + output = torch.stack(output, dim=0) + return output + + @property + def all_obj_ids(self) -> List[int]: + return [k.id for k in self.obj_to_tmp_id] + + @property + def num_obj(self) -> int: + return len(self.obj_to_tmp_id) + + def has_all(self, objects: List[int]) -> bool: + for obj in objects: + if obj not in self.obj_to_tmp_id: + return False + return True + + def find_object_by_id(self, obj_id) -> ObjectInfo: + return self.obj_id_to_obj[obj_id] + + def find_tmp_by_id(self, obj_id) -> int: + return self.obj_to_tmp_id[self.obj_id_to_obj[obj_id]] diff --git a/hf_space/third_party/matanyone/inference/utils/__init__.py b/hf_space/third_party/matanyone/inference/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hf_space/third_party/matanyone/inference/utils/args_utils.py b/hf_space/third_party/matanyone/inference/utils/args_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3a3f44cd50ea2a15083e6c5637d6b11c8bdd02ce --- /dev/null +++ b/hf_space/third_party/matanyone/inference/utils/args_utils.py @@ -0,0 +1,30 @@ +import logging +from omegaconf import DictConfig + +log = logging.getLogger() + + +def get_dataset_cfg(cfg: DictConfig): + dataset_name = cfg.dataset + data_cfg = cfg.datasets[dataset_name] + + potential_overrides = [ + 'image_directory', + 'mask_directory', + 'json_directory', + 'size', + 'save_all', + 'use_all_masks', + 'use_long_term', + 'mem_every', + ] + + for override in potential_overrides: + if cfg[override] is not None: + log.info(f'Overriding config {override} from {data_cfg[override]} to {cfg[override]}') + data_cfg[override] = cfg[override] + # escalte all potential overrides to the top-level config + if override in data_cfg: + cfg[override] = data_cfg[override] + + return data_cfg diff --git a/hf_space/third_party/matanyone/model/__init__.py b/hf_space/third_party/matanyone/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hf_space/third_party/matanyone/model/__pycache__/__init__.cpython-313.pyc b/hf_space/third_party/matanyone/model/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f91717303cb36bd217c34657602e4b7a2c3bda00 Binary files /dev/null and b/hf_space/third_party/matanyone/model/__pycache__/__init__.cpython-313.pyc differ diff --git a/hf_space/third_party/matanyone/model/__pycache__/aux_modules.cpython-313.pyc 
b/hf_space/third_party/matanyone/model/__pycache__/aux_modules.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a39c23215b38d7537f9f329dc359a333db081b06 Binary files /dev/null and b/hf_space/third_party/matanyone/model/__pycache__/aux_modules.cpython-313.pyc differ diff --git a/hf_space/third_party/matanyone/model/__pycache__/big_modules.cpython-313.pyc b/hf_space/third_party/matanyone/model/__pycache__/big_modules.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad1cac7c652364580b5fcffdc64a292c82fd94ec Binary files /dev/null and b/hf_space/third_party/matanyone/model/__pycache__/big_modules.cpython-313.pyc differ diff --git a/hf_space/third_party/matanyone/model/__pycache__/channel_attn.cpython-313.pyc b/hf_space/third_party/matanyone/model/__pycache__/channel_attn.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5697f124337ecfc35c3d9cd61429978adf0caae0 Binary files /dev/null and b/hf_space/third_party/matanyone/model/__pycache__/channel_attn.cpython-313.pyc differ diff --git a/hf_space/third_party/matanyone/model/__pycache__/group_modules.cpython-313.pyc b/hf_space/third_party/matanyone/model/__pycache__/group_modules.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b0b798807db0c31034631a245769ad6da41d215 Binary files /dev/null and b/hf_space/third_party/matanyone/model/__pycache__/group_modules.cpython-313.pyc differ diff --git a/hf_space/third_party/matanyone/model/__pycache__/matanyone.cpython-313.pyc b/hf_space/third_party/matanyone/model/__pycache__/matanyone.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29e6f400721c66ce9d31f4f00231b29f1cbea8d7 Binary files /dev/null and b/hf_space/third_party/matanyone/model/__pycache__/matanyone.cpython-313.pyc differ diff --git a/hf_space/third_party/matanyone/model/__pycache__/modules.cpython-313.pyc b/hf_space/third_party/matanyone/model/__pycache__/modules.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e2a529beb056f6714773e8f27684bc8071c382e Binary files /dev/null and b/hf_space/third_party/matanyone/model/__pycache__/modules.cpython-313.pyc differ diff --git a/hf_space/third_party/matanyone/model/aux_modules.py b/hf_space/third_party/matanyone/model/aux_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..e88b08e68e47861c4781d3cfd0e37eb4d320f49d --- /dev/null +++ b/hf_space/third_party/matanyone/model/aux_modules.py @@ -0,0 +1,93 @@ +""" +For computing auxiliary outputs for auxiliary losses +""" +from typing import Dict +from omegaconf import DictConfig +import torch +import torch.nn as nn + +from matanyone.model.group_modules import GConv2d +from matanyone.utils.tensor_utils import aggregate + + +class LinearPredictor(nn.Module): + def __init__(self, x_dim: int, pix_dim: int): + super().__init__() + self.projection = GConv2d(x_dim, pix_dim + 1, kernel_size=1) + + def forward(self, pix_feat: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + # pixel_feat: B*pix_dim*H*W + # x: B*num_objects*x_dim*H*W + num_objects = x.shape[1] + x = self.projection(x) + + pix_feat = pix_feat.unsqueeze(1).expand(-1, num_objects, -1, -1, -1) + logits = (pix_feat * x[:, :, :-1]).sum(dim=2) + x[:, :, -1] + return logits + + +class DirectPredictor(nn.Module): + def __init__(self, x_dim: int): + super().__init__() + self.projection = GConv2d(x_dim, 1, kernel_size=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: 
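        # Maps each object's grouped feature straight to a single-channel logit map via a
        # shared 1x1 GConv2d (applied independently per object), unlike LinearPredictor above,
        # which predicts a pixel-conditioned linear classifier from x and applies it to pix_feat.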
+ # x: B*num_objects*x_dim*H*W + logits = self.projection(x).squeeze(2) + return logits + + +class AuxComputer(nn.Module): + def __init__(self, cfg: DictConfig): + super().__init__() + + use_sensory_aux = cfg.model.aux_loss.sensory.enabled + self.use_query_aux = cfg.model.aux_loss.query.enabled + self.use_sensory_aux = use_sensory_aux + + sensory_dim = cfg.model.sensory_dim + embed_dim = cfg.model.embed_dim + + if use_sensory_aux: + self.sensory_aux = LinearPredictor(sensory_dim, embed_dim) + + def _aggregate_with_selector(self, logits: torch.Tensor, selector: torch.Tensor) -> torch.Tensor: + prob = torch.sigmoid(logits) + if selector is not None: + prob = prob * selector + logits = aggregate(prob, dim=1) + return logits + + def forward(self, pix_feat: torch.Tensor, aux_input: Dict[str, torch.Tensor], + selector: torch.Tensor, seg_pass=False) -> Dict[str, torch.Tensor]: + sensory = aux_input['sensory'] + q_logits = aux_input['q_logits'] + + aux_output = {} + aux_output['attn_mask'] = aux_input['attn_mask'] + + if self.use_sensory_aux: + # B*num_objects*H*W + logits = self.sensory_aux(pix_feat, sensory) + aux_output['sensory_logits'] = self._aggregate_with_selector(logits, selector) + if self.use_query_aux: + # B*num_objects*num_levels*H*W + aux_output['q_logits'] = self._aggregate_with_selector( + torch.stack(q_logits, dim=2), + selector.unsqueeze(2) if selector is not None else None) + + return aux_output + + def compute_mask(self, aux_input: Dict[str, torch.Tensor], + selector: torch.Tensor) -> Dict[str, torch.Tensor]: + # sensory = aux_input['sensory'] + q_logits = aux_input['q_logits'] + + aux_output = {} + + # B*num_objects*num_levels*H*W + aux_output['q_logits'] = self._aggregate_with_selector( + torch.stack(q_logits, dim=2), + selector.unsqueeze(2) if selector is not None else None) + + return aux_output diff --git a/hf_space/third_party/matanyone/model/big_modules.py b/hf_space/third_party/matanyone/model/big_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..8cffb22ac04eb77f8d169dcaf7c087be19b53a00 --- /dev/null +++ b/hf_space/third_party/matanyone/model/big_modules.py @@ -0,0 +1,365 @@ +""" +big_modules.py - This file stores higher-level network blocks. + +x - usually denotes features that are shared between objects. +g - usually denotes features that are not shared between objects + with an extra "num_objects" dimension (batch_size * num_objects * num_channels * H * W). 
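For example, at the same spatial resolution x has shape batch_size * num_channels * H * W,
while the corresponding g has shape batch_size * num_objects * num_channels * H * W.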
+ +The trailing number of a variable usually denotes the stride +""" + +from typing import Iterable +from omegaconf import DictConfig +import torch +import torch.nn as nn +import torch.nn.functional as F + +from matanyone.model.group_modules import MainToGroupDistributor, GroupFeatureFusionBlock, GConv2d +from matanyone.model.utils import resnet +from matanyone.model.modules import SensoryDeepUpdater, SensoryUpdater_fullscale, DecoderFeatureProcessor, MaskUpsampleBlock + +class UncertPred(nn.Module): + def __init__(self, model_cfg: DictConfig): + super().__init__() + self.conv1x1_v2 = nn.Conv2d(model_cfg.pixel_dim*2 + 1 + model_cfg.value_dim, 64, kernel_size=1, stride=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.conv3x3 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1, groups=1, bias=False, dilation=1) + self.bn2 = nn.BatchNorm2d(32) + self.conv3x3_out = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1, groups=1, bias=False, dilation=1) + + def forward(self, last_frame_feat: torch.Tensor, cur_frame_feat: torch.Tensor, last_mask: torch.Tensor, mem_val_diff:torch.Tensor): + last_mask = F.interpolate(last_mask, size=last_frame_feat.shape[-2:], mode='area') + x = torch.cat([last_frame_feat, cur_frame_feat, last_mask, mem_val_diff], dim=1) + x = self.conv1x1_v2(x) + x = self.bn1(x) + x = self.relu(x) + x = self.conv3x3(x) + x = self.bn2(x) + x = self.relu(x) + x = self.conv3x3_out(x) + return x + + # override the default train() to freeze BN statistics + def train(self, mode=True): + self.training = False + for module in self.children(): + module.train(False) + return self + +class PixelEncoder(nn.Module): + def __init__(self, model_cfg: DictConfig): + super().__init__() + + self.is_resnet = 'resnet' in model_cfg.pixel_encoder.type + # if model_cfg.pretrained_resnet is set in the model_cfg we get the value + # else default to True + is_pretrained_resnet = getattr(model_cfg,"pretrained_resnet",True) + if self.is_resnet: + if model_cfg.pixel_encoder.type == 'resnet18': + network = resnet.resnet18(pretrained=is_pretrained_resnet) + elif model_cfg.pixel_encoder.type == 'resnet50': + network = resnet.resnet50(pretrained=is_pretrained_resnet) + else: + raise NotImplementedError + self.conv1 = network.conv1 + self.bn1 = network.bn1 + self.relu = network.relu + self.maxpool = network.maxpool + + self.res2 = network.layer1 + self.layer2 = network.layer2 + self.layer3 = network.layer3 + else: + raise NotImplementedError + + def forward(self, x: torch.Tensor, seq_length=None) -> (torch.Tensor, torch.Tensor, torch.Tensor): + f1 = x + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + f2 = x + x = self.maxpool(x) + f4 = self.res2(x) + f8 = self.layer2(f4) + f16 = self.layer3(f8) + + return f16, f8, f4, f2, f1 + + # override the default train() to freeze BN statistics + def train(self, mode=True): + self.training = False + for module in self.children(): + module.train(False) + return self + + +class KeyProjection(nn.Module): + def __init__(self, model_cfg: DictConfig): + super().__init__() + in_dim = model_cfg.pixel_encoder.ms_dims[0] + mid_dim = model_cfg.pixel_dim + key_dim = model_cfg.key_dim + + self.pix_feat_proj = nn.Conv2d(in_dim, mid_dim, kernel_size=1) + self.key_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1) + # shrinkage + self.d_proj = nn.Conv2d(mid_dim, 1, kernel_size=3, padding=1) + # selection + self.e_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1) + + nn.init.orthogonal_(self.key_proj.weight.data) + 
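        # d_proj / e_proj defined above produce the shrinkage and selection terms of the
        # XMem-style anisotropic L2 similarity: in forward(), shrinkage is squared and offset
        # so it stays >= 1, and selection is squashed through a sigmoid.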
nn.init.zeros_(self.key_proj.bias.data) + + def forward(self, x: torch.Tensor, *, need_s: bool, + need_e: bool) -> (torch.Tensor, torch.Tensor, torch.Tensor): + x = self.pix_feat_proj(x) + shrinkage = self.d_proj(x)**2 + 1 if (need_s) else None + selection = torch.sigmoid(self.e_proj(x)) if (need_e) else None + + return self.key_proj(x), shrinkage, selection + + +class MaskEncoder(nn.Module): + def __init__(self, model_cfg: DictConfig, single_object=False): + super().__init__() + pixel_dim = model_cfg.pixel_dim + value_dim = model_cfg.value_dim + sensory_dim = model_cfg.sensory_dim + final_dim = model_cfg.mask_encoder.final_dim + + self.single_object = single_object + extra_dim = 1 if single_object else 2 + + # if model_cfg.pretrained_resnet is set in the model_cfg we get the value + # else default to True + is_pretrained_resnet = getattr(model_cfg,"pretrained_resnet",True) + if model_cfg.mask_encoder.type == 'resnet18': + network = resnet.resnet18(pretrained=is_pretrained_resnet, extra_dim=extra_dim) + elif model_cfg.mask_encoder.type == 'resnet50': + network = resnet.resnet50(pretrained=is_pretrained_resnet, extra_dim=extra_dim) + else: + raise NotImplementedError + self.conv1 = network.conv1 + self.bn1 = network.bn1 + self.relu = network.relu + self.maxpool = network.maxpool + + self.layer1 = network.layer1 + self.layer2 = network.layer2 + self.layer3 = network.layer3 + + self.distributor = MainToGroupDistributor() + self.fuser = GroupFeatureFusionBlock(pixel_dim, final_dim, value_dim) + + self.sensory_update = SensoryDeepUpdater(value_dim, sensory_dim) + + def forward(self, + image: torch.Tensor, + pix_feat: torch.Tensor, + sensory: torch.Tensor, + masks: torch.Tensor, + others: torch.Tensor, + *, + deep_update: bool = True, + chunk_size: int = -1) -> (torch.Tensor, torch.Tensor): + # ms_features are from the key encoder + # we only use the first one (lowest resolution), following XMem + if self.single_object: + g = masks.unsqueeze(2) + else: + g = torch.stack([masks, others], dim=2) + + g = self.distributor(image, g) + + batch_size, num_objects = g.shape[:2] + if chunk_size < 1 or chunk_size >= num_objects: + chunk_size = num_objects + fast_path = True + new_sensory = sensory + else: + if deep_update: + new_sensory = torch.empty_like(sensory) + else: + new_sensory = sensory + fast_path = False + + # chunk-by-chunk inference + all_g = [] + for i in range(0, num_objects, chunk_size): + if fast_path: + g_chunk = g + else: + g_chunk = g[:, i:i + chunk_size] + actual_chunk_size = g_chunk.shape[1] + g_chunk = g_chunk.flatten(start_dim=0, end_dim=1) + + g_chunk = self.conv1(g_chunk) + g_chunk = self.bn1(g_chunk) # 1/2, 64 + g_chunk = self.maxpool(g_chunk) # 1/4, 64 + g_chunk = self.relu(g_chunk) + + g_chunk = self.layer1(g_chunk) # 1/4 + g_chunk = self.layer2(g_chunk) # 1/8 + g_chunk = self.layer3(g_chunk) # 1/16 + + g_chunk = g_chunk.view(batch_size, actual_chunk_size, *g_chunk.shape[1:]) + g_chunk = self.fuser(pix_feat, g_chunk) + all_g.append(g_chunk) + if deep_update: + if fast_path: + new_sensory = self.sensory_update(g_chunk, sensory) + else: + new_sensory[:, i:i + chunk_size] = self.sensory_update( + g_chunk, sensory[:, i:i + chunk_size]) + g = torch.cat(all_g, dim=1) + + return g, new_sensory + + # override the default train() to freeze BN statistics + def train(self, mode=True): + self.training = False + for module in self.children(): + module.train(False) + return self + + +class PixelFeatureFuser(nn.Module): + def __init__(self, model_cfg: DictConfig, single_object=False): + 
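        # Fuses the per-pixel memory readout with the sensory (hidden-state) memory and the
        # previous-frame mask before it is handed to the object transformer; sensory_compress
        # is a 1x1 grouped conv that absorbs the extra mask channel(s) (1 for single-object,
        # 2 for mask + others).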
super().__init__() + value_dim = model_cfg.value_dim + sensory_dim = model_cfg.sensory_dim + pixel_dim = model_cfg.pixel_dim + embed_dim = model_cfg.embed_dim + self.single_object = single_object + + self.fuser = GroupFeatureFusionBlock(pixel_dim, value_dim, embed_dim) + if self.single_object: + self.sensory_compress = GConv2d(sensory_dim + 1, value_dim, kernel_size=1) + else: + self.sensory_compress = GConv2d(sensory_dim + 2, value_dim, kernel_size=1) + + def forward(self, + pix_feat: torch.Tensor, + pixel_memory: torch.Tensor, + sensory_memory: torch.Tensor, + last_mask: torch.Tensor, + last_others: torch.Tensor, + *, + chunk_size: int = -1) -> torch.Tensor: + batch_size, num_objects = pixel_memory.shape[:2] + + if self.single_object: + last_mask = last_mask.unsqueeze(2) + else: + last_mask = torch.stack([last_mask, last_others], dim=2) + + if chunk_size < 1: + chunk_size = num_objects + + # chunk-by-chunk inference + all_p16 = [] + for i in range(0, num_objects, chunk_size): + sensory_readout = self.sensory_compress( + torch.cat([sensory_memory[:, i:i + chunk_size], last_mask[:, i:i + chunk_size]], 2)) + p16 = pixel_memory[:, i:i + chunk_size] + sensory_readout + p16 = self.fuser(pix_feat, p16) + all_p16.append(p16) + p16 = torch.cat(all_p16, dim=1) + + return p16 + + +class MaskDecoder(nn.Module): + def __init__(self, model_cfg: DictConfig): + super().__init__() + embed_dim = model_cfg.embed_dim + sensory_dim = model_cfg.sensory_dim + ms_image_dims = model_cfg.pixel_encoder.ms_dims + up_dims = model_cfg.mask_decoder.up_dims + + assert embed_dim == up_dims[0] + + self.sensory_update = SensoryUpdater_fullscale([up_dims[0], up_dims[1], up_dims[2], up_dims[3], up_dims[4] + 1], sensory_dim, + sensory_dim) + + self.decoder_feat_proc = DecoderFeatureProcessor(ms_image_dims[1:], up_dims[:-1]) + self.up_16_8 = MaskUpsampleBlock(up_dims[0], up_dims[1]) + self.up_8_4 = MaskUpsampleBlock(up_dims[1], up_dims[2]) + # newly add for alpha matte + self.up_4_2 = MaskUpsampleBlock(up_dims[2], up_dims[3]) + self.up_2_1 = MaskUpsampleBlock(up_dims[3], up_dims[4]) + + self.pred_seg = nn.Conv2d(up_dims[-1], 1, kernel_size=3, padding=1) + self.pred_mat = nn.Conv2d(up_dims[-1], 1, kernel_size=3, padding=1) + + def forward(self, + ms_image_feat: Iterable[torch.Tensor], + memory_readout: torch.Tensor, + sensory: torch.Tensor, + *, + chunk_size: int = -1, + update_sensory: bool = True, + seg_pass: bool = False, + last_mask=None, + sigmoid_residual=False) -> (torch.Tensor, torch.Tensor): + + batch_size, num_objects = memory_readout.shape[:2] + f8, f4, f2, f1 = self.decoder_feat_proc(ms_image_feat[1:]) + if chunk_size < 1 or chunk_size >= num_objects: + chunk_size = num_objects + fast_path = True + new_sensory = sensory + else: + if update_sensory: + new_sensory = torch.empty_like(sensory) + else: + new_sensory = sensory + fast_path = False + + # chunk-by-chunk inference + all_logits = [] + for i in range(0, num_objects, chunk_size): + if fast_path: + p16 = memory_readout + else: + p16 = memory_readout[:, i:i + chunk_size] + actual_chunk_size = p16.shape[1] + + p8 = self.up_16_8(p16, f8) + p4 = self.up_8_4(p8, f4) + p2 = self.up_4_2(p4, f2) + p1 = self.up_2_1(p2, f1) + with torch.amp.autocast("cuda",enabled=False): + if seg_pass: + if last_mask is not None: + res = self.pred_seg(F.relu(p1.flatten(start_dim=0, end_dim=1).float())) + if sigmoid_residual: + res = (torch.sigmoid(res) - 0.5) * 2 # regularization: (-1, 1) change on last mask + logits = last_mask + res + else: + logits = 
self.pred_seg(F.relu(p1.flatten(start_dim=0, end_dim=1).float())) + else: + if last_mask is not None: + res = self.pred_mat(F.relu(p1.flatten(start_dim=0, end_dim=1).float())) + if sigmoid_residual: + res = (torch.sigmoid(res) - 0.5) * 2 # regularization: (-1, 1) change on last mask + logits = last_mask + res + else: + logits = self.pred_mat(F.relu(p1.flatten(start_dim=0, end_dim=1).float())) + ## SensoryUpdater_fullscale + if update_sensory: + p1 = torch.cat( + [p1, logits.view(batch_size, actual_chunk_size, 1, *logits.shape[-2:])], 2) + if fast_path: + new_sensory = self.sensory_update([p16, p8, p4, p2, p1], sensory) + else: + new_sensory[:, + i:i + chunk_size] = self.sensory_update([p16, p8, p4, p2, p1], + sensory[:, + i:i + chunk_size]) + all_logits.append(logits) + logits = torch.cat(all_logits, dim=0) + logits = logits.view(batch_size, num_objects, *logits.shape[-2:]) + + return new_sensory, logits diff --git a/hf_space/third_party/matanyone/model/channel_attn.py b/hf_space/third_party/matanyone/model/channel_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..d30f74746e50c50c053aa2fc600eb0125419c54f --- /dev/null +++ b/hf_space/third_party/matanyone/model/channel_attn.py @@ -0,0 +1,39 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class CAResBlock(nn.Module): + def __init__(self, in_dim: int, out_dim: int, residual: bool = True): + super().__init__() + self.residual = residual + self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1) + + t = int((abs(math.log2(out_dim)) + 1) // 2) + k = t if t % 2 else t + 1 + self.pool = nn.AdaptiveAvgPool2d(1) + self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False) + + if self.residual: + if in_dim == out_dim: + self.downsample = nn.Identity() + else: + self.downsample = nn.Conv2d(in_dim, out_dim, kernel_size=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r = x + x = self.conv1(F.relu(x)) + x = self.conv2(F.relu(x)) + + b, c = x.shape[:2] + w = self.pool(x).view(b, 1, c) + w = self.conv(w).transpose(-1, -2).unsqueeze(-1).sigmoid() # B*C*1*1 + + if self.residual: + x = x * w + self.downsample(r) + else: + x = x * w + + return x diff --git a/hf_space/third_party/matanyone/model/group_modules.py b/hf_space/third_party/matanyone/model/group_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..b877b5b98aeab8be980c29965cd08f85fc4dfa9f --- /dev/null +++ b/hf_space/third_party/matanyone/model/group_modules.py @@ -0,0 +1,126 @@ +from typing import Optional +import torch +import torch.nn as nn +import torch.nn.functional as F +from matanyone.model.channel_attn import CAResBlock + +def interpolate_groups(g: torch.Tensor, ratio: float, mode: str, + align_corners: bool) -> torch.Tensor: + batch_size, num_objects = g.shape[:2] + g = F.interpolate(g.flatten(start_dim=0, end_dim=1), + scale_factor=ratio, + mode=mode, + align_corners=align_corners) + g = g.view(batch_size, num_objects, *g.shape[1:]) + return g + + +def upsample_groups(g: torch.Tensor, + ratio: float = 2, + mode: str = 'bilinear', + align_corners: bool = False) -> torch.Tensor: + return interpolate_groups(g, ratio, mode, align_corners) + + +def downsample_groups(g: torch.Tensor, + ratio: float = 1 / 2, + mode: str = 'area', + align_corners: bool = None) -> torch.Tensor: + return interpolate_groups(g, ratio, mode, align_corners) + + +class GConv2d(nn.Conv2d): + def forward(self, g: 
torch.Tensor) -> torch.Tensor: + batch_size, num_objects = g.shape[:2] + g = super().forward(g.flatten(start_dim=0, end_dim=1)) + return g.view(batch_size, num_objects, *g.shape[1:]) + + +class GroupResBlock(nn.Module): + def __init__(self, in_dim: int, out_dim: int): + super().__init__() + + if in_dim == out_dim: + self.downsample = nn.Identity() + else: + self.downsample = GConv2d(in_dim, out_dim, kernel_size=1) + + self.conv1 = GConv2d(in_dim, out_dim, kernel_size=3, padding=1) + self.conv2 = GConv2d(out_dim, out_dim, kernel_size=3, padding=1) + + def forward(self, g: torch.Tensor) -> torch.Tensor: + out_g = self.conv1(F.relu(g)) + out_g = self.conv2(F.relu(out_g)) + + g = self.downsample(g) + + return out_g + g + + +class MainToGroupDistributor(nn.Module): + def __init__(self, + x_transform: Optional[nn.Module] = None, + g_transform: Optional[nn.Module] = None, + method: str = 'cat', + reverse_order: bool = False): + super().__init__() + + self.x_transform = x_transform + self.g_transform = g_transform + self.method = method + self.reverse_order = reverse_order + + def forward(self, x: torch.Tensor, g: torch.Tensor, skip_expand: bool = False) -> torch.Tensor: + num_objects = g.shape[1] + + if self.x_transform is not None: + x = self.x_transform(x) + + if self.g_transform is not None: + g = self.g_transform(g) + + if not skip_expand: + x = x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1) + if self.method == 'cat': + if self.reverse_order: + g = torch.cat([g, x], 2) + else: + g = torch.cat([x, g], 2) + elif self.method == 'add': + g = x + g + elif self.method == 'mulcat': + g = torch.cat([x * g, g], dim=2) + elif self.method == 'muladd': + g = x * g + g + else: + raise NotImplementedError + + return g + + +class GroupFeatureFusionBlock(nn.Module): + def __init__(self, x_in_dim: int, g_in_dim: int, out_dim: int): + super().__init__() + + x_transform = nn.Conv2d(x_in_dim, out_dim, kernel_size=1) + g_transform = GConv2d(g_in_dim, out_dim, kernel_size=1) + + self.distributor = MainToGroupDistributor(x_transform=x_transform, + g_transform=g_transform, + method='add') + self.block1 = CAResBlock(out_dim, out_dim) + self.block2 = CAResBlock(out_dim, out_dim) + + def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor: + batch_size, num_objects = g.shape[:2] + + g = self.distributor(x, g) + + g = g.flatten(start_dim=0, end_dim=1) + + g = self.block1(g) + g = self.block2(g) + + g = g.view(batch_size, num_objects, *g.shape[1:]) + + return g diff --git a/hf_space/third_party/matanyone/model/matanyone.py b/hf_space/third_party/matanyone/model/matanyone.py new file mode 100644 index 0000000000000000000000000000000000000000..e271f0f87a425dbe2cdce8d434aab476e0cafa68 --- /dev/null +++ b/hf_space/third_party/matanyone/model/matanyone.py @@ -0,0 +1,333 @@ +from typing import List, Dict, Iterable, Tuple +import logging +from omegaconf import DictConfig +import torch +import torch.nn as nn +import torch.nn.functional as F +from omegaconf import OmegaConf +from huggingface_hub import PyTorchModelHubMixin + +from matanyone.model.big_modules import PixelEncoder, UncertPred, KeyProjection, MaskEncoder, PixelFeatureFuser, MaskDecoder +from matanyone.model.aux_modules import AuxComputer +from matanyone.model.utils.memory_utils import get_affinity, readout +from matanyone.model.transformer.object_transformer import QueryTransformer +from matanyone.model.transformer.object_summarizer import ObjectSummarizer +from matanyone.utils.tensor_utils import aggregate + +log = logging.getLogger() +class 
MatAnyone(nn.Module, + PyTorchModelHubMixin, + library_name="matanyone", + repo_url="https://github.com/pq-yang/MatAnyone", + coders={ + DictConfig: ( + lambda x: OmegaConf.to_container(x), + lambda data: OmegaConf.create(data), + ) + }, + ): + + def __init__(self, cfg: DictConfig, *, single_object=False): + super().__init__() + self.cfg = cfg + model_cfg = cfg.model + self.ms_dims = model_cfg.pixel_encoder.ms_dims + self.key_dim = model_cfg.key_dim + self.value_dim = model_cfg.value_dim + self.sensory_dim = model_cfg.sensory_dim + self.pixel_dim = model_cfg.pixel_dim + self.embed_dim = model_cfg.embed_dim + self.single_object = single_object + + log.info(f'Single object: {self.single_object}') + + self.pixel_encoder = PixelEncoder(model_cfg) + self.pix_feat_proj = nn.Conv2d(self.ms_dims[0], self.pixel_dim, kernel_size=1) + self.key_proj = KeyProjection(model_cfg) + self.mask_encoder = MaskEncoder(model_cfg, single_object=single_object) + self.mask_decoder = MaskDecoder(model_cfg) + self.pixel_fuser = PixelFeatureFuser(model_cfg, single_object=single_object) + self.object_transformer = QueryTransformer(model_cfg) + self.object_summarizer = ObjectSummarizer(model_cfg) + self.aux_computer = AuxComputer(cfg) + self.temp_sparity = UncertPred(model_cfg) + + self.register_buffer("pixel_mean", torch.Tensor(model_cfg.pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(model_cfg.pixel_std).view(-1, 1, 1), False) + + def _get_others(self, masks: torch.Tensor) -> torch.Tensor: + # for each object, return the sum of masks of all other objects + if self.single_object: + return None + + num_objects = masks.shape[1] + if num_objects >= 1: + others = (masks.sum(dim=1, keepdim=True) - masks).clamp(0, 1) + else: + others = torch.zeros_like(masks) + return others + + def pred_uncertainty(self, last_pix_feat: torch.Tensor, cur_pix_feat: torch.Tensor, last_mask: torch.Tensor, mem_val_diff:torch.Tensor): + logits = self.temp_sparity(last_frame_feat=last_pix_feat, + cur_frame_feat=cur_pix_feat, + last_mask=last_mask, + mem_val_diff=mem_val_diff) + + prob = torch.sigmoid(logits) + mask = (prob > 0) + 0 + + uncert_output = {"logits": logits, + "prob": prob, + "mask": mask} + + return uncert_output + + def encode_image(self, image: torch.Tensor, seq_length=None, last_feats=None) -> (Iterable[torch.Tensor], torch.Tensor): # type: ignore + image = (image - self.pixel_mean) / self.pixel_std + ms_image_feat = self.pixel_encoder(image, seq_length) # f16, f8, f4, f2, f1 + return ms_image_feat, self.pix_feat_proj(ms_image_feat[0]) + + def encode_mask( + self, + image: torch.Tensor, + ms_features: List[torch.Tensor], + sensory: torch.Tensor, + masks: torch.Tensor, + *, + deep_update: bool = True, + chunk_size: int = -1, + need_weights: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + image = (image - self.pixel_mean) / self.pixel_std + others = self._get_others(masks) + mask_value, new_sensory = self.mask_encoder(image, + ms_features, + sensory, + masks, + others, + deep_update=deep_update, + chunk_size=chunk_size) + object_summaries, object_logits = self.object_summarizer(masks, mask_value, need_weights) + return mask_value, new_sensory, object_summaries, object_logits + + def transform_key(self, + final_pix_feat: torch.Tensor, + *, + need_sk: bool = True, + need_ek: bool = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + key, shrinkage, selection = self.key_proj(final_pix_feat, need_s=need_sk, need_e=need_ek) + return key, shrinkage, selection + 
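    # Illustrative usage sketch (the Hub repo id below is an assumption, not taken from this
    # file): since MatAnyone mixes in PyTorchModelHubMixin, a config + checkpoint can be loaded
    # with from_pretrained and fed through the encode/transform steps defined above, roughly:
    #
    #   model = MatAnyone.from_pretrained("PeiqingYang/matanyone").eval()
    #   with torch.no_grad():
    #       ms_feat, pix_feat = model.encode_image(image)             # image: B*3*H*W, RGB in [0, 1]
    #       key, shrinkage, selection = model.transform_key(pix_feat)
    #
    # key / shrinkage / selection are the stride-16 terms consumed by the memory readout below.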
+ # Used in training only. + # This step is replaced by MemoryManager in test time + def read_memory(self, query_key: torch.Tensor, query_selection: torch.Tensor, + memory_key: torch.Tensor, memory_shrinkage: torch.Tensor, + msk_value: torch.Tensor, obj_memory: torch.Tensor, pix_feat: torch.Tensor, + sensory: torch.Tensor, last_mask: torch.Tensor, + selector: torch.Tensor, uncert_output=None, seg_pass=False, + last_pix_feat=None, last_pred_mask=None) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + """ + query_key : B * CK * H * W + query_selection : B * CK * H * W + memory_key : B * CK * T * H * W + memory_shrinkage: B * 1 * T * H * W + msk_value : B * num_objects * CV * T * H * W + obj_memory : B * num_objects * T * num_summaries * C + pixel_feature : B * C * H * W + """ + batch_size, num_objects = msk_value.shape[:2] + + uncert_mask = uncert_output["mask"] if uncert_output is not None else None + + # read using visual attention + with torch.amp.autocast("cuda",enabled=False): + affinity = get_affinity(memory_key.float(), memory_shrinkage.float(), query_key.float(), + query_selection.float(), uncert_mask=uncert_mask) + + msk_value = msk_value.flatten(start_dim=1, end_dim=2).float() + + # B * (num_objects*CV) * H * W + pixel_readout = readout(affinity, msk_value, uncert_mask) + pixel_readout = pixel_readout.view(batch_size, num_objects, self.value_dim, + *pixel_readout.shape[-2:]) + + uncert_output = self.pred_uncertainty(last_pix_feat, pix_feat, last_pred_mask, pixel_readout[:,0]-msk_value[:,:,-1]) + uncert_prob = uncert_output["prob"].unsqueeze(1) # b n 1 h w + pixel_readout = pixel_readout*uncert_prob + msk_value[:,:,-1].unsqueeze(1)*(1-uncert_prob) + + pixel_readout = self.pixel_fusion(pix_feat, pixel_readout, sensory, last_mask) + + + # read from query transformer + mem_readout, aux_features = self.readout_query(pixel_readout, obj_memory, selector=selector, seg_pass=seg_pass) + + aux_output = { + 'sensory': sensory, + 'q_logits': aux_features['logits'] if aux_features else None, + 'attn_mask': aux_features['attn_mask'] if aux_features else None, + } + + return mem_readout, aux_output, uncert_output + + def read_first_frame_memory(self, pixel_readout, + obj_memory: torch.Tensor, pix_feat: torch.Tensor, + sensory: torch.Tensor, last_mask: torch.Tensor, + selector: torch.Tensor, seg_pass=False) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + """ + query_key : B * CK * H * W + query_selection : B * CK * H * W + memory_key : B * CK * T * H * W + memory_shrinkage: B * 1 * T * H * W + msk_value : B * num_objects * CV * T * H * W + obj_memory : B * num_objects * T * num_summaries * C + pixel_feature : B * C * H * W + """ + + pixel_readout = self.pixel_fusion(pix_feat, pixel_readout, sensory, last_mask) + + # read from query transformer + mem_readout, aux_features = self.readout_query(pixel_readout, obj_memory, selector=selector, seg_pass=seg_pass) + + aux_output = { + 'sensory': sensory, + 'q_logits': aux_features['logits'] if aux_features else None, + 'attn_mask': aux_features['attn_mask'] if aux_features else None, + } + + return mem_readout, aux_output + + def pixel_fusion(self, + pix_feat: torch.Tensor, + pixel: torch.Tensor, + sensory: torch.Tensor, + last_mask: torch.Tensor, + *, + chunk_size: int = -1) -> torch.Tensor: + last_mask = F.interpolate(last_mask, size=sensory.shape[-2:], mode='area') + last_others = self._get_others(last_mask) + fused = self.pixel_fuser(pix_feat, + pixel, + sensory, + last_mask, + last_others, + chunk_size=chunk_size) + return fused + + def 
readout_query(self, + pixel_readout, + obj_memory, + *, + selector=None, + need_weights=False, + seg_pass=False) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + return self.object_transformer(pixel_readout, + obj_memory, + selector=selector, + need_weights=need_weights, + seg_pass=seg_pass) + + def segment(self, + ms_image_feat: List[torch.Tensor], + memory_readout: torch.Tensor, + sensory: torch.Tensor, + *, + selector: bool = None, + chunk_size: int = -1, + update_sensory: bool = True, + seg_pass: bool = False, + clamp_mat: bool = True, + last_mask=None, + sigmoid_residual=False, + seg_mat=False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + multi_scale_features is from the key encoder for skip-connection + memory_readout is from working/long-term memory + sensory is the sensory memory + last_mask is the mask from the last frame, supplementing sensory memory + selector is 1 if an object exists, and 0 otherwise. We use it to filter padded objects + during training. + """ + #### use mat head for seg data + if seg_mat: + assert seg_pass + seg_pass = False + #### + sensory, logits = self.mask_decoder(ms_image_feat, + memory_readout, + sensory, + chunk_size=chunk_size, + update_sensory=update_sensory, + seg_pass = seg_pass, + last_mask=last_mask, + sigmoid_residual=sigmoid_residual) + if seg_pass: + prob = torch.sigmoid(logits) + if selector is not None: + prob = prob * selector + + # Softmax over all objects[] + logits = aggregate(prob, dim=1) + prob = F.softmax(logits, dim=1) + else: + if clamp_mat: + logits = logits.clamp(0.0, 1.0) + logits = torch.cat([torch.prod(1 - logits, dim=1, keepdim=True), logits], 1) + prob = logits + + return sensory, logits, prob + + def compute_aux(self, pix_feat: torch.Tensor, aux_inputs: Dict[str, torch.Tensor], + selector: torch.Tensor, seg_pass=False) -> Dict[str, torch.Tensor]: + return self.aux_computer(pix_feat, aux_inputs, selector, seg_pass=seg_pass) + + def forward(self, *args, **kwargs): + raise NotImplementedError + + def load_weights(self, src_dict, init_as_zero_if_needed=False) -> None: + if not self.single_object: + # Map single-object weight to multi-object weight (4->5 out channels in conv1) + for k in list(src_dict.keys()): + if k == 'mask_encoder.conv1.weight': + if src_dict[k].shape[1] == 4: + log.info(f'Converting {k} from single object to multiple objects.') + pads = torch.zeros((64, 1, 7, 7), device=src_dict[k].device) + if not init_as_zero_if_needed: + nn.init.orthogonal_(pads) + log.info(f'Randomly initialized padding for {k}.') + else: + log.info(f'Zero-initialized padding for {k}.') + src_dict[k] = torch.cat([src_dict[k], pads], 1) + elif k == 'pixel_fuser.sensory_compress.weight': + if src_dict[k].shape[1] == self.sensory_dim + 1: + log.info(f'Converting {k} from single object to multiple objects.') + pads = torch.zeros((self.value_dim, 1, 1, 1), device=src_dict[k].device) + if not init_as_zero_if_needed: + nn.init.orthogonal_(pads) + log.info(f'Randomly initialized padding for {k}.') + else: + log.info(f'Zero-initialized padding for {k}.') + src_dict[k] = torch.cat([src_dict[k], pads], 1) + elif self.single_object: + """ + If the model is multiple-object and we are training in single-object, + we strip the last channel of conv1. + This is not supposed to happen in standard training except when users are trying to + finetune a trained model with single object datasets. 
+ """ + if src_dict['mask_encoder.conv1.weight'].shape[1] == 5: + log.warning('Converting mask_encoder.conv1.weight from multiple objects to single object.' + 'This is not supposed to happen in standard training.') + src_dict['mask_encoder.conv1.weight'] = src_dict['mask_encoder.conv1.weight'][:, :-1] + src_dict['pixel_fuser.sensory_compress.weight'] = src_dict['pixel_fuser.sensory_compress.weight'][:, :-1] + + for k in src_dict: + if k not in self.state_dict(): + log.info(f'Key {k} found in src_dict but not in self.state_dict()!!!') + for k in self.state_dict(): + if k not in src_dict: + log.info(f'Key {k} found in self.state_dict() but not in src_dict!!!') + + self.load_state_dict(src_dict, strict=False) + + @property + def device(self) -> torch.device: + return self.pixel_mean.device diff --git a/hf_space/third_party/matanyone/model/modules.py b/hf_space/third_party/matanyone/model/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..d2bd5b66a62500111e4f428380cea457f6316dbb --- /dev/null +++ b/hf_space/third_party/matanyone/model/modules.py @@ -0,0 +1,149 @@ +from typing import List, Iterable +import torch +import torch.nn as nn +import torch.nn.functional as F + +from matanyone.model.group_modules import MainToGroupDistributor, GroupResBlock, upsample_groups, GConv2d, downsample_groups + + +class UpsampleBlock(nn.Module): + def __init__(self, in_dim: int, out_dim: int, scale_factor: int = 2): + super().__init__() + self.out_conv = ResBlock(in_dim, out_dim) + self.scale_factor = scale_factor + + def forward(self, in_g: torch.Tensor, skip_f: torch.Tensor) -> torch.Tensor: + g = F.interpolate(in_g, + scale_factor=self.scale_factor, + mode='bilinear') + g = self.out_conv(g) + g = g + skip_f + return g + +class MaskUpsampleBlock(nn.Module): + def __init__(self, in_dim: int, out_dim: int, scale_factor: int = 2): + super().__init__() + self.distributor = MainToGroupDistributor(method='add') + self.out_conv = GroupResBlock(in_dim, out_dim) + self.scale_factor = scale_factor + + def forward(self, in_g: torch.Tensor, skip_f: torch.Tensor) -> torch.Tensor: + g = upsample_groups(in_g, ratio=self.scale_factor) + g = self.distributor(skip_f, g) + g = self.out_conv(g) + return g + + +class DecoderFeatureProcessor(nn.Module): + def __init__(self, decoder_dims: List[int], out_dims: List[int]): + super().__init__() + self.transforms = nn.ModuleList([ + nn.Conv2d(d_dim, p_dim, kernel_size=1) for d_dim, p_dim in zip(decoder_dims, out_dims) + ]) + + def forward(self, multi_scale_features: Iterable[torch.Tensor]) -> List[torch.Tensor]: + outputs = [func(x) for x, func in zip(multi_scale_features, self.transforms)] + return outputs + + +# @torch.jit.script +def _recurrent_update(h: torch.Tensor, values: torch.Tensor) -> torch.Tensor: + # h: batch_size * num_objects * hidden_dim * h * w + # values: batch_size * num_objects * (hidden_dim*3) * h * w + dim = values.shape[2] // 3 + forget_gate = torch.sigmoid(values[:, :, :dim]) + update_gate = torch.sigmoid(values[:, :, dim:dim * 2]) + new_value = torch.tanh(values[:, :, dim * 2:]) + new_h = forget_gate * h * (1 - update_gate) + update_gate * new_value + return new_h + + +class SensoryUpdater_fullscale(nn.Module): + # Used in the decoder, multi-scale feature + GRU + def __init__(self, g_dims: List[int], mid_dim: int, sensory_dim: int): + super().__init__() + self.g16_conv = GConv2d(g_dims[0], mid_dim, kernel_size=1) + self.g8_conv = GConv2d(g_dims[1], mid_dim, kernel_size=1) + self.g4_conv = GConv2d(g_dims[2], mid_dim, kernel_size=1) + 
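        # g_dims correspond to decoder features at strides 16/8/4/2/1; the stride-1 entry has
        # one extra channel because MaskDecoder concatenates the predicted logits onto p1
        # (up_dims[4] + 1 where this updater is constructed) before the sensory update.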
self.g2_conv = GConv2d(g_dims[3], mid_dim, kernel_size=1) + self.g1_conv = GConv2d(g_dims[4], mid_dim, kernel_size=1) + + self.transform = GConv2d(mid_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1) + + nn.init.xavier_normal_(self.transform.weight) + + def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor: + g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \ + self.g4_conv(downsample_groups(g[2], ratio=1/4)) + \ + self.g2_conv(downsample_groups(g[3], ratio=1/8)) + \ + self.g1_conv(downsample_groups(g[4], ratio=1/16)) + + with torch.amp.autocast("cuda",enabled=False): + g = g.float() + h = h.float() + values = self.transform(torch.cat([g, h], dim=2)) + new_h = _recurrent_update(h, values) + + return new_h + +class SensoryUpdater(nn.Module): + # Used in the decoder, multi-scale feature + GRU + def __init__(self, g_dims: List[int], mid_dim: int, sensory_dim: int): + super().__init__() + self.g16_conv = GConv2d(g_dims[0], mid_dim, kernel_size=1) + self.g8_conv = GConv2d(g_dims[1], mid_dim, kernel_size=1) + self.g4_conv = GConv2d(g_dims[2], mid_dim, kernel_size=1) + + self.transform = GConv2d(mid_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1) + + nn.init.xavier_normal_(self.transform.weight) + + def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor: + g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \ + self.g4_conv(downsample_groups(g[2], ratio=1/4)) + + with torch.amp.autocast("cuda",enabled=False): + g = g.float() + h = h.float() + values = self.transform(torch.cat([g, h], dim=2)) + new_h = _recurrent_update(h, values) + + return new_h + + +class SensoryDeepUpdater(nn.Module): + def __init__(self, f_dim: int, sensory_dim: int): + super().__init__() + self.transform = GConv2d(f_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1) + + nn.init.xavier_normal_(self.transform.weight) + + def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor: + with torch.amp.autocast("cuda",enabled=False): + g = g.float() + h = h.float() + values = self.transform(torch.cat([g, h], dim=2)) + new_h = _recurrent_update(h, values) + + return new_h + + +class ResBlock(nn.Module): + def __init__(self, in_dim: int, out_dim: int): + super().__init__() + + if in_dim == out_dim: + self.downsample = nn.Identity() + else: + self.downsample = nn.Conv2d(in_dim, out_dim, kernel_size=1) + + self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1) + + def forward(self, g: torch.Tensor) -> torch.Tensor: + out_g = self.conv1(F.relu(g)) + out_g = self.conv2(F.relu(out_g)) + + g = self.downsample(g) + + return out_g + g diff --git a/hf_space/third_party/matanyone/model/transformer/__init__.py b/hf_space/third_party/matanyone/model/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hf_space/third_party/matanyone/model/transformer/__pycache__/__init__.cpython-313.pyc b/hf_space/third_party/matanyone/model/transformer/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7612db2903eec9c90e7906997e84ff313dca210 Binary files /dev/null and b/hf_space/third_party/matanyone/model/transformer/__pycache__/__init__.cpython-313.pyc differ diff --git a/hf_space/third_party/matanyone/model/transformer/__pycache__/object_summarizer.cpython-313.pyc 
b/hf_space/third_party/matanyone/model/transformer/__pycache__/object_summarizer.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e58fc6e17060bdce02858e38b88d0dc41f26fd5e Binary files /dev/null and b/hf_space/third_party/matanyone/model/transformer/__pycache__/object_summarizer.cpython-313.pyc differ diff --git a/hf_space/third_party/matanyone/model/transformer/__pycache__/object_transformer.cpython-313.pyc b/hf_space/third_party/matanyone/model/transformer/__pycache__/object_transformer.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9e660d914ab3160b0c4eac350f25c20ad600af1 Binary files /dev/null and b/hf_space/third_party/matanyone/model/transformer/__pycache__/object_transformer.cpython-313.pyc differ diff --git a/hf_space/third_party/matanyone/model/transformer/__pycache__/positional_encoding.cpython-313.pyc b/hf_space/third_party/matanyone/model/transformer/__pycache__/positional_encoding.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..562d7a0fa2597688f18962a736457dd9f020cc08 Binary files /dev/null and b/hf_space/third_party/matanyone/model/transformer/__pycache__/positional_encoding.cpython-313.pyc differ diff --git a/hf_space/third_party/matanyone/model/transformer/__pycache__/transformer_layers.cpython-313.pyc b/hf_space/third_party/matanyone/model/transformer/__pycache__/transformer_layers.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8be5cf70ddfb56ca3b93ce81eb5e178de28fa8c6 Binary files /dev/null and b/hf_space/third_party/matanyone/model/transformer/__pycache__/transformer_layers.cpython-313.pyc differ diff --git a/hf_space/third_party/matanyone/model/transformer/object_summarizer.py b/hf_space/third_party/matanyone/model/transformer/object_summarizer.py new file mode 100644 index 0000000000000000000000000000000000000000..92a85832714501996f2c743d034a0756199f3023 --- /dev/null +++ b/hf_space/third_party/matanyone/model/transformer/object_summarizer.py @@ -0,0 +1,89 @@ +from typing import Optional +from omegaconf import DictConfig + +import torch +import torch.nn as nn +import torch.nn.functional as F +from matanyone.model.transformer.positional_encoding import PositionalEncoding + + +# @torch.jit.script +def _weighted_pooling(masks: torch.Tensor, value: torch.Tensor, + logits: torch.Tensor) -> (torch.Tensor, torch.Tensor): + # value: B*num_objects*H*W*value_dim + # logits: B*num_objects*H*W*num_summaries + # masks: B*num_objects*H*W*num_summaries: 1 if allowed + weights = logits.sigmoid() * masks + # B*num_objects*num_summaries*value_dim + sums = torch.einsum('bkhwq,bkhwc->bkqc', weights, value) + # B*num_objects*H*W*num_summaries -> B*num_objects*num_summaries*1 + area = weights.flatten(start_dim=2, end_dim=3).sum(2).unsqueeze(-1) + + # B*num_objects*num_summaries*value_dim + return sums, area + + +class ObjectSummarizer(nn.Module): + def __init__(self, model_cfg: DictConfig): + super().__init__() + + this_cfg = model_cfg.object_summarizer + self.value_dim = model_cfg.value_dim + self.embed_dim = this_cfg.embed_dim + self.num_summaries = this_cfg.num_summaries + self.add_pe = this_cfg.add_pe + self.pixel_pe_scale = model_cfg.pixel_pe_scale + self.pixel_pe_temperature = model_cfg.pixel_pe_temperature + + if self.add_pe: + self.pos_enc = PositionalEncoding(self.embed_dim, + scale=self.pixel_pe_scale, + temperature=self.pixel_pe_temperature) + + self.input_proj = nn.Linear(self.value_dim, self.embed_dim) + self.feature_pred = 
nn.Sequential( + nn.Linear(self.embed_dim, self.embed_dim), + nn.ReLU(inplace=True), + nn.Linear(self.embed_dim, self.embed_dim), + ) + self.weights_pred = nn.Sequential( + nn.Linear(self.embed_dim, self.embed_dim), + nn.ReLU(inplace=True), + nn.Linear(self.embed_dim, self.num_summaries), + ) + + def forward(self, + masks: torch.Tensor, + value: torch.Tensor, + need_weights: bool = False) -> (torch.Tensor, Optional[torch.Tensor]): + # masks: B*num_objects*(H0)*(W0) + # value: B*num_objects*value_dim*H*W + # -> B*num_objects*H*W*value_dim + h, w = value.shape[-2:] + masks = F.interpolate(masks, size=(h, w), mode='area') + masks = masks.unsqueeze(-1) + inv_masks = 1 - masks + repeated_masks = torch.cat([ + masks.expand(-1, -1, -1, -1, self.num_summaries // 2), + inv_masks.expand(-1, -1, -1, -1, self.num_summaries // 2), + ], + dim=-1) + + value = value.permute(0, 1, 3, 4, 2) + value = self.input_proj(value) + if self.add_pe: + pe = self.pos_enc(value) + value = value + pe + + with torch.amp.autocast("cuda",enabled=False): + value = value.float() + feature = self.feature_pred(value) + logits = self.weights_pred(value) + sums, area = _weighted_pooling(repeated_masks, feature, logits) + + summaries = torch.cat([sums, area], dim=-1) + + if need_weights: + return summaries, logits + else: + return summaries, None diff --git a/hf_space/third_party/matanyone/model/transformer/object_transformer.py b/hf_space/third_party/matanyone/model/transformer/object_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..6d1ba3b0fb57161e544eb846de92e17c64b4fe28 --- /dev/null +++ b/hf_space/third_party/matanyone/model/transformer/object_transformer.py @@ -0,0 +1,206 @@ +from typing import Dict, Optional +from omegaconf import DictConfig + +import torch +import torch.nn as nn +from matanyone.model.group_modules import GConv2d +from matanyone.utils.tensor_utils import aggregate +from matanyone.model.transformer.positional_encoding import PositionalEncoding +from matanyone.model.transformer.transformer_layers import CrossAttention, SelfAttention, FFN, PixelFFN + + +class QueryTransformerBlock(nn.Module): + def __init__(self, model_cfg: DictConfig): + super().__init__() + + this_cfg = model_cfg.object_transformer + self.embed_dim = this_cfg.embed_dim + self.num_heads = this_cfg.num_heads + self.num_queries = this_cfg.num_queries + self.ff_dim = this_cfg.ff_dim + + self.read_from_pixel = CrossAttention(self.embed_dim, + self.num_heads, + add_pe_to_qkv=this_cfg.read_from_pixel.add_pe_to_qkv) + self.self_attn = SelfAttention(self.embed_dim, + self.num_heads, + add_pe_to_qkv=this_cfg.query_self_attention.add_pe_to_qkv) + self.ffn = FFN(self.embed_dim, self.ff_dim) + self.read_from_query = CrossAttention(self.embed_dim, + self.num_heads, + add_pe_to_qkv=this_cfg.read_from_query.add_pe_to_qkv, + norm=this_cfg.read_from_query.output_norm) + self.pixel_ffn = PixelFFN(self.embed_dim) + + def forward( + self, + x: torch.Tensor, + pixel: torch.Tensor, + query_pe: torch.Tensor, + pixel_pe: torch.Tensor, + attn_mask: torch.Tensor, + need_weights: bool = False) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor): + # x: (bs*num_objects)*num_queries*embed_dim + # pixel: bs*num_objects*C*H*W + # query_pe: (bs*num_objects)*num_queries*embed_dim + # pixel_pe: (bs*num_objects)*(H*W)*C + # attn_mask: (bs*num_objects*num_heads)*num_queries*(H*W) + + # bs*num_objects*C*H*W -> (bs*num_objects)*(H*W)*C + pixel_flat = pixel.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous() + x, q_weights = 
self.read_from_pixel(x, + pixel_flat, + query_pe, + pixel_pe, + attn_mask=attn_mask, + need_weights=need_weights) + x = self.self_attn(x, query_pe) + x = self.ffn(x) + + pixel_flat, p_weights = self.read_from_query(pixel_flat, + x, + pixel_pe, + query_pe, + need_weights=need_weights) + pixel = self.pixel_ffn(pixel, pixel_flat) + + if need_weights: + bs, num_objects, _, h, w = pixel.shape + q_weights = q_weights.view(bs, num_objects, self.num_heads, self.num_queries, h, w) + p_weights = p_weights.transpose(2, 3).view(bs, num_objects, self.num_heads, + self.num_queries, h, w) + + return x, pixel, q_weights, p_weights + + +class QueryTransformer(nn.Module): + def __init__(self, model_cfg: DictConfig): + super().__init__() + + this_cfg = model_cfg.object_transformer + self.value_dim = model_cfg.value_dim + self.embed_dim = this_cfg.embed_dim + self.num_heads = this_cfg.num_heads + self.num_queries = this_cfg.num_queries + + # query initialization and embedding + self.query_init = nn.Embedding(self.num_queries, self.embed_dim) + self.query_emb = nn.Embedding(self.num_queries, self.embed_dim) + + # projection from object summaries to query initialization and embedding + self.summary_to_query_init = nn.Linear(self.embed_dim, self.embed_dim) + self.summary_to_query_emb = nn.Linear(self.embed_dim, self.embed_dim) + + self.pixel_pe_scale = model_cfg.pixel_pe_scale + self.pixel_pe_temperature = model_cfg.pixel_pe_temperature + self.pixel_init_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1) + self.pixel_emb_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1) + self.spatial_pe = PositionalEncoding(self.embed_dim, + scale=self.pixel_pe_scale, + temperature=self.pixel_pe_temperature, + channel_last=False, + transpose_output=True) + + # transformer blocks + self.num_blocks = this_cfg.num_blocks + self.blocks = nn.ModuleList( + QueryTransformerBlock(model_cfg) for _ in range(self.num_blocks)) + self.mask_pred = nn.ModuleList( + nn.Sequential(nn.ReLU(), GConv2d(self.embed_dim, 1, kernel_size=1)) + for _ in range(self.num_blocks + 1)) + + self.act = nn.ReLU(inplace=True) + + def forward(self, + pixel: torch.Tensor, + obj_summaries: torch.Tensor, + selector: Optional[torch.Tensor] = None, + need_weights: bool = False, + seg_pass=False) -> (torch.Tensor, Dict[str, torch.Tensor]): + # pixel: B*num_objects*embed_dim*H*W + # obj_summaries: B*num_objects*T*num_queries*embed_dim + T = obj_summaries.shape[2] + bs, num_objects, _, H, W = pixel.shape + + # normalize object values + # the last channel is the cumulative area of the object + obj_summaries = obj_summaries.view(bs * num_objects, T, self.num_queries, + self.embed_dim + 1) + # sum over time + # during inference, T=1 as we already did streaming average in memory_manager + obj_sums = obj_summaries[:, :, :, :-1].sum(dim=1) + obj_area = obj_summaries[:, :, :, -1:].sum(dim=1) + obj_values = obj_sums / (obj_area + 1e-4) + obj_init = self.summary_to_query_init(obj_values) + obj_emb = self.summary_to_query_emb(obj_values) + + # positional embeddings for object queries + query = self.query_init.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_init + query_emb = self.query_emb.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_emb + + # positional embeddings for pixel features + pixel_init = self.pixel_init_proj(pixel) + pixel_emb = self.pixel_emb_proj(pixel) + pixel_pe = self.spatial_pe(pixel.flatten(0, 1)) + pixel_emb = pixel_emb.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous() + pixel_pe = pixel_pe.flatten(1, 2) + 
pixel_emb + + pixel = pixel_init + + # run the transformer + aux_features = {'logits': []} + + # first aux output + aux_logits = self.mask_pred[0](pixel).squeeze(2) + attn_mask = self._get_aux_mask(aux_logits, selector, seg_pass=seg_pass) + aux_features['logits'].append(aux_logits) + for i in range(self.num_blocks): + query, pixel, q_weights, p_weights = self.blocks[i](query, + pixel, + query_emb, + pixel_pe, + attn_mask, + need_weights=need_weights) + + if self.training or i <= self.num_blocks - 1 or need_weights: + aux_logits = self.mask_pred[i + 1](pixel).squeeze(2) + attn_mask = self._get_aux_mask(aux_logits, selector, seg_pass=seg_pass) + aux_features['logits'].append(aux_logits) + + aux_features['q_weights'] = q_weights # last layer only + aux_features['p_weights'] = p_weights # last layer only + + if self.training: + # no need to save all heads + aux_features['attn_mask'] = attn_mask.view(bs, num_objects, self.num_heads, + self.num_queries, H, W)[:, :, 0] + + return pixel, aux_features + + def _get_aux_mask(self, logits: torch.Tensor, selector: torch.Tensor, seg_pass=False) -> torch.Tensor: + # logits: batch_size*num_objects*H*W + # selector: batch_size*num_objects*1*1 + # returns a mask of shape (batch_size*num_objects*num_heads)*num_queries*(H*W) + # where True means the attention is blocked + + if selector is None: + prob = logits.sigmoid() + else: + prob = logits.sigmoid() * selector + logits = aggregate(prob, dim=1) + + is_foreground = (logits[:, 1:] >= logits.max(dim=1, keepdim=True)[0]) + foreground_mask = is_foreground.bool().flatten(start_dim=2) + inv_foreground_mask = ~foreground_mask + inv_background_mask = foreground_mask + + aux_foreground_mask = inv_foreground_mask.unsqueeze(2).unsqueeze(2).repeat( + 1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2) + aux_background_mask = inv_background_mask.unsqueeze(2).unsqueeze(2).repeat( + 1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2) + + aux_mask = torch.cat([aux_foreground_mask, aux_background_mask], dim=1) + + aux_mask[torch.where(aux_mask.sum(-1) == aux_mask.shape[-1])] = False + + return aux_mask diff --git a/hf_space/third_party/matanyone/model/transformer/positional_encoding.py b/hf_space/third_party/matanyone/model/transformer/positional_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..d9b200813170c65a24da95a2ff646d759bb6e225 --- /dev/null +++ b/hf_space/third_party/matanyone/model/transformer/positional_encoding.py @@ -0,0 +1,108 @@ +# Reference: +# https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/transformer_decoder/position_encoding.py +# https://github.com/tatp22/multidim-positional-encoding/blob/master/positional_encodings/torch_encodings.py + +import math + +import numpy as np +import torch +from torch import nn + + +def get_emb(sin_inp: torch.Tensor) -> torch.Tensor: + """ + Gets a base embedding for one dimension with sin and cos intertwined + """ + emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1) + return torch.flatten(emb, -2, -1) + + +class PositionalEncoding(nn.Module): + def __init__(self, + dim: int, + scale: float = math.pi * 2, + temperature: float = 10000, + normalize: bool = True, + channel_last: bool = True, + transpose_output: bool = False): + super().__init__() + dim = int(np.ceil(dim / 4) * 2) + self.dim = dim + inv_freq = 1.0 / (temperature**(torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + self.normalize = normalize + self.scale 
= scale + self.eps = 1e-6 + self.channel_last = channel_last + self.transpose_output = transpose_output + + self.cached_penc = None # the cache is irrespective of the number of objects + + def forward(self, tensor: torch.Tensor) -> torch.Tensor: + """ + :param tensor: A 4/5d tensor of size + channel_last=True: (batch_size, h, w, c) or (batch_size, k, h, w, c) + channel_last=False: (batch_size, c, h, w) or (batch_size, k, c, h, w) + :return: positional encoding tensor that has the same shape as the input if the input is 4d + if the input is 5d, the output is broadcastable along the k-dimension + """ + if len(tensor.shape) != 4 and len(tensor.shape) != 5: + raise RuntimeError(f'The input tensor has to be 4/5d, got {tensor.shape}!') + + if len(tensor.shape) == 5: + # take a sample from the k dimension + num_objects = tensor.shape[1] + tensor = tensor[:, 0] + else: + num_objects = None + + if self.channel_last: + batch_size, h, w, c = tensor.shape + else: + batch_size, c, h, w = tensor.shape + + if self.cached_penc is not None and self.cached_penc.shape == tensor.shape: + if num_objects is None: + return self.cached_penc + else: + return self.cached_penc.unsqueeze(1) + + self.cached_penc = None + + pos_y = torch.arange(h, device=tensor.device, dtype=self.inv_freq.dtype) + pos_x = torch.arange(w, device=tensor.device, dtype=self.inv_freq.dtype) + if self.normalize: + pos_y = pos_y / (pos_y[-1] + self.eps) * self.scale + pos_x = pos_x / (pos_x[-1] + self.eps) * self.scale + + sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq) + sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) + emb_y = get_emb(sin_inp_y).unsqueeze(1) + emb_x = get_emb(sin_inp_x) + + emb = torch.zeros((h, w, self.dim * 2), device=tensor.device, dtype=tensor.dtype) + emb[:, :, :self.dim] = emb_x + emb[:, :, self.dim:] = emb_y + + if not self.channel_last and self.transpose_output: + # cancelled out + pass + elif (not self.channel_last) or (self.transpose_output): + emb = emb.permute(2, 0, 1) + + self.cached_penc = emb.unsqueeze(0).repeat(batch_size, 1, 1, 1) + if num_objects is None: + return self.cached_penc + else: + return self.cached_penc.unsqueeze(1) + + +if __name__ == '__main__': + pe = PositionalEncoding(8).cuda() + input = torch.ones((1, 8, 8, 8)).cuda() + output = pe(input) + # print(output) + print(output[0, :, 0, 0]) + print(output[0, :, 0, 5]) + print(output[0, 0, :, 0]) + print(output[0, 0, 0, :]) diff --git a/hf_space/third_party/matanyone/model/transformer/transformer_layers.py b/hf_space/third_party/matanyone/model/transformer/transformer_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..9a4a0cc67e4c82626dafc7772e9e53df4540bd54 --- /dev/null +++ b/hf_space/third_party/matanyone/model/transformer/transformer_layers.py @@ -0,0 +1,161 @@ +# Modified from PyTorch nn.Transformer + +from typing import List, Callable + +import torch +from torch import Tensor +import torch.nn as nn +import torch.nn.functional as F +from matanyone.model.channel_attn import CAResBlock + + +class SelfAttention(nn.Module): + def __init__(self, + dim: int, + nhead: int, + dropout: float = 0.0, + batch_first: bool = True, + add_pe_to_qkv: List[bool] = [True, True, False]): + super().__init__() + self.self_attn = nn.MultiheadAttention(dim, nhead, dropout=dropout, batch_first=batch_first) + self.norm = nn.LayerNorm(dim) + self.dropout = nn.Dropout(dropout) + self.add_pe_to_qkv = add_pe_to_qkv + + def forward(self, + x: torch.Tensor, + pe: torch.Tensor, + attn_mask: bool = None, + key_padding_mask: bool 
= None) -> torch.Tensor: + x = self.norm(x) + if any(self.add_pe_to_qkv): + x_with_pe = x + pe + q = x_with_pe if self.add_pe_to_qkv[0] else x + k = x_with_pe if self.add_pe_to_qkv[1] else x + v = x_with_pe if self.add_pe_to_qkv[2] else x + else: + q = k = v = x + + r = x + x = self.self_attn(q, k, v, attn_mask=attn_mask, key_padding_mask=key_padding_mask)[0] + return r + self.dropout(x) + + +# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention +class CrossAttention(nn.Module): + def __init__(self, + dim: int, + nhead: int, + dropout: float = 0.0, + batch_first: bool = True, + add_pe_to_qkv: List[bool] = [True, True, False], + residual: bool = True, + norm: bool = True): + super().__init__() + self.cross_attn = nn.MultiheadAttention(dim, + nhead, + dropout=dropout, + batch_first=batch_first) + if norm: + self.norm = nn.LayerNorm(dim) + else: + self.norm = nn.Identity() + self.dropout = nn.Dropout(dropout) + self.add_pe_to_qkv = add_pe_to_qkv + self.residual = residual + + def forward(self, + x: torch.Tensor, + mem: torch.Tensor, + x_pe: torch.Tensor, + mem_pe: torch.Tensor, + attn_mask: bool = None, + *, + need_weights: bool = False) -> (torch.Tensor, torch.Tensor): + x = self.norm(x) + if self.add_pe_to_qkv[0]: + q = x + x_pe + else: + q = x + + if any(self.add_pe_to_qkv[1:]): + mem_with_pe = mem + mem_pe + k = mem_with_pe if self.add_pe_to_qkv[1] else mem + v = mem_with_pe if self.add_pe_to_qkv[2] else mem + else: + k = v = mem + r = x + x, weights = self.cross_attn(q, + k, + v, + attn_mask=attn_mask, + need_weights=need_weights, + average_attn_weights=False) + + if self.residual: + return r + self.dropout(x), weights + else: + return self.dropout(x), weights + + +class FFN(nn.Module): + def __init__(self, dim_in: int, dim_ff: int, activation=F.relu): + super().__init__() + self.linear1 = nn.Linear(dim_in, dim_ff) + self.linear2 = nn.Linear(dim_ff, dim_in) + self.norm = nn.LayerNorm(dim_in) + + if isinstance(activation, str): + self.activation = _get_activation_fn(activation) + else: + self.activation = activation + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r = x + x = self.norm(x) + x = self.linear2(self.activation(self.linear1(x))) + x = r + x + return x + + +class PixelFFN(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + self.conv = CAResBlock(dim, dim) + + def forward(self, pixel: torch.Tensor, pixel_flat: torch.Tensor) -> torch.Tensor: + # pixel: batch_size * num_objects * dim * H * W + # pixel_flat: (batch_size*num_objects) * (H*W) * dim + bs, num_objects, _, h, w = pixel.shape + pixel_flat = pixel_flat.view(bs * num_objects, h, w, self.dim) + pixel_flat = pixel_flat.permute(0, 3, 1, 2).contiguous() + + x = self.conv(pixel_flat) + x = x.view(bs, num_objects, self.dim, h, w) + return x + + +class OutputFFN(nn.Module): + def __init__(self, dim_in: int, dim_out: int, activation=F.relu): + super().__init__() + self.linear1 = nn.Linear(dim_in, dim_out) + self.linear2 = nn.Linear(dim_out, dim_out) + + if isinstance(activation, str): + self.activation = _get_activation_fn(activation) + else: + self.activation = activation + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear2(self.activation(self.linear1(x))) + return x + + +def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + + raise RuntimeError("activation should be 
relu/gelu, not {}".format(activation)) diff --git a/hf_space/third_party/matanyone/model/utils/__init__.py b/hf_space/third_party/matanyone/model/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hf_space/third_party/matanyone/model/utils/__pycache__/__init__.cpython-313.pyc b/hf_space/third_party/matanyone/model/utils/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b469a20cd781091a6324eca3470ddd14e8d97cad Binary files /dev/null and b/hf_space/third_party/matanyone/model/utils/__pycache__/__init__.cpython-313.pyc differ diff --git a/hf_space/third_party/matanyone/model/utils/__pycache__/memory_utils.cpython-313.pyc b/hf_space/third_party/matanyone/model/utils/__pycache__/memory_utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21db3ac012ab22fa26a5d458c42fda50dd98eb1b Binary files /dev/null and b/hf_space/third_party/matanyone/model/utils/__pycache__/memory_utils.cpython-313.pyc differ diff --git a/hf_space/third_party/matanyone/model/utils/__pycache__/resnet.cpython-313.pyc b/hf_space/third_party/matanyone/model/utils/__pycache__/resnet.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..327a8c2d7d3007d34087a64164c96fdeefa9bee2 Binary files /dev/null and b/hf_space/third_party/matanyone/model/utils/__pycache__/resnet.cpython-313.pyc differ diff --git a/hf_space/third_party/matanyone/model/utils/memory_utils.py b/hf_space/third_party/matanyone/model/utils/memory_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dd85ce0d3b21a12888068c28d241a77e46811e0a --- /dev/null +++ b/hf_space/third_party/matanyone/model/utils/memory_utils.py @@ -0,0 +1,107 @@ +import math +import torch +from typing import Optional, Union, Tuple + + +# @torch.jit.script +def get_similarity(mk: torch.Tensor, + ms: torch.Tensor, + qk: torch.Tensor, + qe: torch.Tensor, + add_batch_dim: bool = False, + uncert_mask = None) -> torch.Tensor: + # used for training/inference and memory reading/memory potentiation + # mk: B x CK x [N] - Memory keys + # ms: B x 1 x [N] - Memory shrinkage + # qk: B x CK x [HW/P] - Query keys + # qe: B x CK x [HW/P] - Query selection + # Dimensions in [] are flattened + # Return: B*N*HW + if add_batch_dim: + mk, ms = mk.unsqueeze(0), ms.unsqueeze(0) + qk, qe = qk.unsqueeze(0), qe.unsqueeze(0) + + CK = mk.shape[1] + + mk = mk.flatten(start_dim=2) + ms = ms.flatten(start_dim=1).unsqueeze(2) if ms is not None else None + qk = qk.flatten(start_dim=2) + qe = qe.flatten(start_dim=2) if qe is not None else None + + # query token selection based on temporal sparsity + if uncert_mask is not None: + uncert_mask = uncert_mask.flatten(start_dim=2) + uncert_mask = uncert_mask.expand(-1, 64, -1) + qk = qk * uncert_mask + qe = qe * uncert_mask + + if qe is not None: + # See XMem's appendix for derivation + mk = mk.transpose(1, 2) + a_sq = (mk.pow(2) @ qe) + two_ab = 2 * (mk @ (qk * qe)) + b_sq = (qe * qk.pow(2)).sum(1, keepdim=True) + similarity = (-a_sq + two_ab - b_sq) + else: + # similar to STCN if we don't have the selection term + a_sq = mk.pow(2).sum(1).unsqueeze(2) + two_ab = 2 * (mk.transpose(1, 2) @ qk) + similarity = (-a_sq + two_ab) + + if ms is not None: + similarity = similarity * ms / math.sqrt(CK) # B*N*HW + else: + similarity = similarity / math.sqrt(CK) # B*N*HW + + return similarity + + +def do_softmax( + similarity: torch.Tensor, + top_k: 
Optional[int] = None, + inplace: bool = False, + return_usage: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + # normalize similarity with top-k softmax + # similarity: B x N x [HW/P] + # use inplace with care + if top_k is not None: + values, indices = torch.topk(similarity, k=top_k, dim=1) + + x_exp = values.exp_() + x_exp /= torch.sum(x_exp, dim=1, keepdim=True) + if inplace: + similarity.zero_().scatter_(1, indices, x_exp) # B*N*HW + affinity = similarity + else: + affinity = torch.zeros_like(similarity).scatter_(1, indices, x_exp) # B*N*HW + else: + maxes = torch.max(similarity, dim=1, keepdim=True)[0] + x_exp = torch.exp(similarity - maxes) + x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True) + affinity = x_exp / x_exp_sum + indices = None + + if return_usage: + return affinity, affinity.sum(dim=2) + + return affinity + + +def get_affinity(mk: torch.Tensor, ms: torch.Tensor, qk: torch.Tensor, + qe: torch.Tensor, uncert_mask = None) -> torch.Tensor: + # shorthand used in training with no top-k + similarity = get_similarity(mk, ms, qk, qe, uncert_mask=uncert_mask) + affinity = do_softmax(similarity) + return affinity + +def readout(affinity: torch.Tensor, mv: torch.Tensor, uncert_mask: torch.Tensor=None) -> torch.Tensor: + B, CV, T, H, W = mv.shape + + mo = mv.view(B, CV, T * H * W) + mem = torch.bmm(mo, affinity) + if uncert_mask is not None: + uncert_mask = uncert_mask.flatten(start_dim=2).expand(-1, CV, -1) + mem = mem * uncert_mask + mem = mem.view(B, CV, H, W) + + return mem diff --git a/hf_space/third_party/matanyone/model/utils/parameter_groups.py b/hf_space/third_party/matanyone/model/utils/parameter_groups.py new file mode 100644 index 0000000000000000000000000000000000000000..a248aa44fe952b2cdfca38700c40d79882b67df0 --- /dev/null +++ b/hf_space/third_party/matanyone/model/utils/parameter_groups.py @@ -0,0 +1,72 @@ +import logging + +log = logging.getLogger() + + +def get_parameter_groups(model, stage_cfg, print_log=False): + """ + Assign different weight decays and learning rates to different parameters. + Returns a parameter group which can be passed to the optimizer. 
+ """ + weight_decay = stage_cfg.weight_decay + embed_weight_decay = stage_cfg.embed_weight_decay + backbone_lr_ratio = stage_cfg.backbone_lr_ratio + base_lr = stage_cfg.learning_rate + + backbone_params = [] + embed_params = [] + other_params = [] + + embedding_names = ['summary_pos', 'query_init', 'query_emb', 'obj_pe'] + embedding_names = [e + '.weight' for e in embedding_names] + + # inspired by detectron2 + memo = set() + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + # Avoid duplicating parameters + if param in memo: + continue + memo.add(param) + + if name.startswith('module'): + name = name[7:] + + inserted = False + if name.startswith('pixel_encoder.'): + backbone_params.append(param) + inserted = True + if print_log: + log.info(f'{name} counted as a backbone parameter.') + else: + for e in embedding_names: + if name.endswith(e): + embed_params.append(param) + inserted = True + if print_log: + log.info(f'{name} counted as an embedding parameter.') + break + + if not inserted: + other_params.append(param) + + parameter_groups = [ + { + 'params': backbone_params, + 'lr': base_lr * backbone_lr_ratio, + 'weight_decay': weight_decay + }, + { + 'params': embed_params, + 'lr': base_lr, + 'weight_decay': embed_weight_decay + }, + { + 'params': other_params, + 'lr': base_lr, + 'weight_decay': weight_decay + }, + ] + + return parameter_groups diff --git a/hf_space/third_party/matanyone/model/utils/resnet.py b/hf_space/third_party/matanyone/model/utils/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..b82da357652338bad5cdd74bc71e853cd4ce03ee --- /dev/null +++ b/hf_space/third_party/matanyone/model/utils/resnet.py @@ -0,0 +1,179 @@ +""" +resnet.py - A modified ResNet structure +We append extra channels to the first conv by some network surgery +""" + +from collections import OrderedDict +import math + +import torch +import torch.nn as nn +from torch.utils import model_zoo + + +def load_weights_add_extra_dim(target, source_state, extra_dim=1): + new_dict = OrderedDict() + + for k1, v1 in target.state_dict().items(): + if 'num_batches_tracked' not in k1: + if k1 in source_state: + tar_v = source_state[k1] + + if v1.shape != tar_v.shape: + # Init the new segmentation channel with zeros + # print(v1.shape, tar_v.shape) + c, _, w, h = v1.shape + pads = torch.zeros((c, extra_dim, w, h), device=tar_v.device) + nn.init.orthogonal_(pads) + tar_v = torch.cat([tar_v, pads], 1) + + new_dict[k1] = tar_v + + target.load_state_dict(new_dict) + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1, dilation=1): + return nn.Conv2d(in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if 
self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, + planes, + kernel_size=3, + stride=stride, + dilation=dilation, + padding=dilation, + bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + def __init__(self, block, layers=(3, 4, 23, 3), extra_dim=0): + self.inplanes = 64 + super(ResNet, self).__init__() + self.conv1 = nn.Conv2d(3 + extra_dim, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1, dilation=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [block(self.inplanes, planes, stride, downsample)] + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, dilation=dilation)) + + return nn.Sequential(*layers) + + +def resnet18(pretrained=True, extra_dim=0): + model = ResNet(BasicBlock, [2, 2, 2, 2], extra_dim) + if pretrained: + load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet18']), extra_dim) + return model + + +def resnet50(pretrained=True, extra_dim=0): + model = ResNet(Bottleneck, [3, 4, 6, 3], extra_dim) + if pretrained: + load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet50']), extra_dim) + return model diff --git a/hf_space/third_party/matanyone/utils/__init__.py b/hf_space/third_party/matanyone/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hf_space/third_party/matanyone/utils/__pycache__/__init__.cpython-313.pyc b/hf_space/third_party/matanyone/utils/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a323181852e4b8cf1d7a024b4b133db0c3e71418 Binary files /dev/null and b/hf_space/third_party/matanyone/utils/__pycache__/__init__.cpython-313.pyc differ diff --git a/hf_space/third_party/matanyone/utils/__pycache__/inference_utils.cpython-313.pyc b/hf_space/third_party/matanyone/utils/__pycache__/inference_utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d3784accace2a767e1473ba6bf59722e11986fc Binary files /dev/null and b/hf_space/third_party/matanyone/utils/__pycache__/inference_utils.cpython-313.pyc differ diff --git a/hf_space/third_party/matanyone/utils/__pycache__/tensor_utils.cpython-313.pyc b/hf_space/third_party/matanyone/utils/__pycache__/tensor_utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4c6a9661a3eefeaf632e27603597f0fd57d7a2a Binary files /dev/null and b/hf_space/third_party/matanyone/utils/__pycache__/tensor_utils.cpython-313.pyc differ diff --git a/hf_space/third_party/matanyone/utils/get_default_model.py b/hf_space/third_party/matanyone/utils/get_default_model.py new file mode 100644 index 0000000000000000000000000000000000000000..1bf5801b317281c12d4b359f2aa2e28177641670 --- /dev/null +++ b/hf_space/third_party/matanyone/utils/get_default_model.py @@ -0,0 +1,27 @@ +""" +A helper function to get a default model for quick testing +""" +from omegaconf import open_dict +from hydra import compose, initialize + +import torch +from matanyone.model.matanyone import MatAnyone + +def get_matanyone_model(ckpt_path, device=None) -> MatAnyone: + initialize(version_base='1.3.2', config_path="../config", job_name="eval_our_config") + cfg = compose(config_name="eval_matanyone_config") + + with open_dict(cfg): + cfg['weights'] = ckpt_path + + # Load the network weights + if device is not None: + matanyone = MatAnyone(cfg, single_object=True).to(device).eval() + model_weights = torch.load(cfg.weights, map_location=device) + else: # if device is not specified, 
`.cuda()` by default
+        matanyone = MatAnyone(cfg, single_object=True).cuda().eval()
+        model_weights = torch.load(cfg.weights)
+
+    matanyone.load_weights(model_weights)
+
+    return matanyone
diff --git a/hf_space/third_party/matanyone/utils/inference_utils.py b/hf_space/third_party/matanyone/utils/inference_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..364007fa2d43ff7d9872d310ca5d10d2a571b484
--- /dev/null
+++ b/hf_space/third_party/matanyone/utils/inference_utils.py
@@ -0,0 +1,54 @@
+import os
+import cv2
+import random
+import numpy as np
+
+import torch
+import torchvision
+
+IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG')
+VIDEO_EXTENSIONS = ('.mp4', '.mov', '.avi', '.MP4', '.MOV', '.AVI')
+
+def read_frame_from_videos(frame_root):
+    if frame_root.endswith(VIDEO_EXTENSIONS): # Video file path
+        video_name = os.path.basename(frame_root)[:-4]
+        frames, _, info = torchvision.io.read_video(filename=frame_root, pts_unit='sec', output_format='TCHW') # RGB
+        fps = info['video_fps']
+    else:
+        video_name = os.path.basename(frame_root)
+        frames = []
+        fr_lst = sorted(os.listdir(frame_root))
+        for fr in fr_lst:
+            frame = cv2.imread(os.path.join(frame_root, fr))[...,[2,1,0]] # RGB, HWC
+            frames.append(frame)
+        fps = 24 # default
+        frames = torch.Tensor(np.array(frames)).permute(0, 3, 1, 2).contiguous() # TCHW
+
+    length = frames.shape[0]
+
+    return frames, fps, length, video_name
+
+def get_video_paths(input_root):
+    video_paths = []
+    for root, _, files in os.walk(input_root):
+        for file in files:
+            if file.lower().endswith(VIDEO_EXTENSIONS):
+                video_paths.append(os.path.join(root, file))
+    return sorted(video_paths)
+
+def str_to_list(value):
+    return list(map(int, value.split(',')))
+
+def gen_dilate(alpha, min_kernel_size, max_kernel_size):
+    kernel_size = random.randint(min_kernel_size, max_kernel_size)
+    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size,kernel_size))
+    fg_and_unknown = np.array(np.not_equal(alpha, 0).astype(np.float32))
+    dilate = cv2.dilate(fg_and_unknown, kernel, iterations=1)*255
+    return dilate.astype(np.float32)
+
+def gen_erosion(alpha, min_kernel_size, max_kernel_size):
+    kernel_size = random.randint(min_kernel_size, max_kernel_size)
+    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size,kernel_size))
+    fg = np.array(np.equal(alpha, 255).astype(np.float32))
+    erode = cv2.erode(fg, kernel, iterations=1)*255
+    return erode.astype(np.float32)
diff --git a/hf_space/third_party/matanyone/utils/tensor_utils.py b/hf_space/third_party/matanyone/utils/tensor_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..47768a9222e793c9797a4634431d34f89106b285
--- /dev/null
+++ b/hf_space/third_party/matanyone/utils/tensor_utils.py
@@ -0,0 +1,62 @@
+from typing import List, Iterable
+import torch
+import torch.nn.functional as F
+
+
+# STM
+def pad_divide_by(in_img: torch.Tensor, d: int) -> (torch.Tensor, Iterable[int]):
+    h, w = in_img.shape[-2:]
+
+    if h % d > 0:
+        new_h = h + d - h % d
+    else:
+        new_h = h
+    if w % d > 0:
+        new_w = w + d - w % d
+    else:
+        new_w = w
+    lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2)
+    lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2)
+    pad_array = (int(lw), int(uw), int(lh), int(uh))
+    out = F.pad(in_img, pad_array)
+    return out, pad_array
+
+
+def unpad(img: torch.Tensor, pad: Iterable[int]) -> torch.Tensor:
+    if len(img.shape) == 4:
+        if pad[2] + pad[3] > 0:
+            img = img[:, :, pad[2]:-pad[3], :]
+        if pad[0] + pad[1] > 0:
+            img = img[:, :, :, pad[0]:-pad[1]]
+    elif len(img.shape) == 3:
+        if pad[2] + pad[3] > 0:
+            img = img[:, pad[2]:-pad[3], :]
+        if pad[0] + pad[1] > 0:
+            img = img[:, :, pad[0]:-pad[1]]
+    elif len(img.shape) == 5:
+        if pad[2] + pad[3] > 0:
+            img = img[:, :, :, pad[2]:-pad[3], :]
+        if pad[0] + pad[1] > 0:
+            img = img[:, :, :, :, pad[0]:-pad[1]]
+    else:
+        raise NotImplementedError
+    return img
+
+
+# @torch.jit.script
+def aggregate(prob: torch.Tensor, dim: int) -> torch.Tensor:
+    with torch.amp.autocast("cuda",enabled=False):
+        prob = prob.float()
+        new_prob = torch.cat([torch.prod(1 - prob, dim=dim, keepdim=True), prob],
+                             dim).clamp(1e-7, 1 - 1e-7)
+        logits = torch.log((new_prob / (1 - new_prob)))  # (0, 1) --> (-inf, inf)
+
+    return logits
+
+
+# @torch.jit.script
+def cls_to_one_hot(cls_gt: torch.Tensor, num_objects: int) -> torch.Tensor:
+    # cls_gt: B*1*H*W
+    B, _, H, W = cls_gt.shape
+    one_hot = torch.zeros(B, num_objects + 1, H, W, device=cls_gt.device).scatter_(1, cls_gt, 1)
+    return one_hot
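
pad_divide_by and unpad above are inverses over the spatial dimensions: the first pads H and W up to multiples of d, the second crops the padding back off. A minimal sketch, assuming a hypothetical 250x403 frame:

import torch
from matanyone.utils.tensor_utils import pad_divide_by, unpad

frame = torch.randn(1, 3, 250, 403)              # hypothetical B x C x H x W input
padded, pads = pad_divide_by(frame, 16)          # pads = (left, right, top, bottom)
assert padded.shape[-2] % 16 == 0 and padded.shape[-1] % 16 == 0
restored = unpad(padded, pads)                   # crops back to the original size
assert restored.shape == frame.shape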
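
gen_dilate and gen_erosion in inference_utils.py expect a mask with values in [0, 255] and return float32 maps; passing min_kernel_size == max_kernel_size makes the kernel size deterministic. A small sketch on a synthetic mask (sizes are hypothetical):

import numpy as np
from matanyone.utils.inference_utils import gen_dilate, gen_erosion

mask = np.zeros((64, 64), dtype=np.float32)      # synthetic binary mask
mask[16:48, 16:48] = 255
dilated = gen_dilate(mask, 10, 10)               # grows the non-zero region
eroded = gen_erosion(mask, 10, 10)               # shrinks the 255-valued region
assert dilated.dtype == np.float32 and eroded.dtype == np.float32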
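
The memory-reading helpers in memory_utils.py compose as get_similarity -> do_softmax -> readout. A sketch with hypothetical tensor sizes and a top-30 softmax; qe and uncert_mask are left at None for the simple path:

import torch
from matanyone.model.utils.memory_utils import get_similarity, do_softmax, readout

B, CK, CV, T, H, W = 1, 64, 256, 4, 30, 54       # hypothetical sizes
mk = torch.randn(B, CK, T * H * W)               # memory keys, flattened over T*H*W
ms = torch.rand(B, 1, T * H * W)                 # memory shrinkage
qk = torch.randn(B, CK, H * W)                   # query keys
mv = torch.randn(B, CV, T, H, W)                 # memory values

sim = get_similarity(mk, ms, qk, qe=None)        # B x (T*H*W) x (H*W)
aff, usage = do_softmax(sim, top_k=30, return_usage=True)
mem = readout(aff, mv)                           # weighted sum of memory values
assert mem.shape == (B, CV, H, W)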
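
When PositionalEncoding is constructed the way QueryTransformer constructs it (channel_last=False, transpose_output=True), a channel-first feature map goes in and a channel-last encoding of the same spatial size comes out, ready to be flattened for the attention layers. A sketch with hypothetical sizes:

import torch
from matanyone.model.transformer.positional_encoding import PositionalEncoding

pe = PositionalEncoding(256, channel_last=False, transpose_output=True)
feat = torch.randn(2, 256, 30, 54)               # B x C x H x W feature map
enc = pe(feat)                                   # returned channel-last: B x H x W x C
assert enc.shape == (2, 30, 54, 256)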
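
The vendored ResNet defines no forward(); the encoders call the stem and layer blocks directly, and extra_dim widens conv1 beyond RGB. A sketch with one hypothetical extra channel and pretrained=False to avoid the weight download:

import torch
from matanyone.model.utils.resnet import resnet18

net = resnet18(pretrained=False, extra_dim=1)    # conv1 now takes 3 + 1 = 4 channels
x = torch.randn(1, 4, 224, 224)                  # RGB plus one extra channel
f = net.maxpool(net.relu(net.bn1(net.conv1(x)))) # stem, called explicitly
f = net.layer1(f)                                # first residual stage
assert f.shape == (1, 64, 56, 56)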
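
get_matanyone_model composes the packaged Hydra config with a checkpoint path and returns the network in eval mode. A sketch, assuming a local checkpoint file at a hypothetical path; initialize() touches global Hydra state, so this is normally called once per process:

import torch
from matanyone.utils.get_default_model import get_matanyone_model

ckpt = "pretrained_models/matanyone.pth"         # hypothetical local checkpoint path
device = "cuda" if torch.cuda.is_available() else "cpu"
model = get_matanyone_model(ckpt, device=device) # MatAnyone, eval mode, on `device`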