import os
import random

import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import torchvision.transforms as T
from PIL import Image
from timm.models.layers import to_2tuple, DropPath
from timm.models.vision_transformer import Mlp, PatchEmbed, Block


class Attention(nn.Module):
    """Multi-head self-attention over a token sequence.

    Args:
        dim: token embedding dimension C.
        num_heads: number of attention heads (C must be divisible by it).
        qkv_bias: whether the fused q/k/v projection has a bias.
        qk_scale: optional override of the 1/sqrt(head_dim) attention scale.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability on the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Map a (B, N, C) token sequence to a (B, N, C) sequence."""
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, C // heads)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


def get_2d_sincos_pos_embed(embed_dim, grid_h_size, grid_w_size, cls_token=False):
    """Build a fixed 2D sin-cos positional embedding for a grid_h_size x grid_w_size patch grid.

    Args:
        embed_dim: output embedding dimension; must be divisible by 4.
        grid_h_size: number of patch rows.
        grid_w_size: number of patch columns.
        cls_token: if True, prepend an all-zero row for a class token.

    Returns:
        np.ndarray of shape [grid_h_size*grid_w_size, embed_dim], or
        [1+grid_h_size*grid_w_size, embed_dim] with cls_token, positions in row-major order.
    """
    grid_h = np.arange(grid_h_size, dtype=float)
    grid_w = np.arange(grid_w_size, dtype=float)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first; each array has shape (grid_h_size, grid_w_size)
    grid = np.stack(grid, axis=0)
    # only the flattened order matters downstream, so use the natural (h, w) layout
    grid = grid.reshape([2, 1, grid_h_size, grid_w_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Concatenate 1D sin-cos embeddings of grid[0] and grid[1], each using half of embed_dim.

    grid: array of shape (2, ...) holding the two coordinate maps (as built by
    get_2d_sincos_pos_embed, grid[0] is the column index and grid[1] the row index).
    Returns an array of shape (M, embed_dim), M = number of positions.
    """
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position (must be even)
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.
    omega = 1. / 10000 ** omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    """Bicubically resize checkpoint_model['pos_embed'] (in place) to match model's patch grid.

    Expects `model.patch_embed.num_patches` and `model.pos_embed`, and assumes square
    grids with the extra (cls/dist) tokens stored first; extra tokens are copied unchanged.
    """
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed


class PatchEmbed(nn.Module):
    """Non-overlapping conv patch embedding: (B, C, H, W) -> (B, H'*W', embed_dim), row-major patches."""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()

        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class Block(nn.Module):
    """Pre-norm transformer block whose attention/MLP weights are shared across modalities.

    Three sets of LayerNorms are kept: the default pair (norm1/norm2) for the joint
    audio-visual stream, and per-modality pairs (norm*_a, norm*_v) selected via `modality`.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.norm1_a = norm_layer(dim)
        self.norm1_v = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.norm2_a = norm_layer(dim)
        self.norm2_v = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, modality=None):
        """Apply the block; modality is None (joint norms), 'a' (audio norms) or 'v' (visual norms)."""
        if modality is None:
            norm1, norm2 = self.norm1, self.norm2
        elif modality == 'a':
            norm1, norm2 = self.norm1_a, self.norm2_a
        elif modality == 'v':
            norm1, norm2 = self.norm1_v, self.norm2_v
        else:
            # previously an unknown modality silently returned x untouched
            raise ValueError(f"Unknown modality {modality!r}; expected None, 'a' or 'v'")
        x = x + self.drop_path(self.attn(norm1(x)))
        x = x + self.drop_path(self.mlp(norm2(x)))
        return x


class _CAVMAEEncoderMixin:
    """Embedding / unimodal-encoding helpers shared by CAVMAE and CAVMAEFT.

    Relies on attributes created by the subclass __init__: patch_embed_a/v, pos_embed_a/v,
    modality_a/v, blocks_a/v/u and norm_a/v.
    """

    def _embed_audio(self, a):
        """Patchify an audio fbank batch and add positional and modality embeddings.

        Assumes `a` is (B, time_frames, mel_bins), as produced by _wav2fbank (batched);
        it is transposed to (B, 1, mel_bins, time_frames) before patching.
        """
        a = a.unsqueeze(1)
        a = a.transpose(2, 3)
        a = self.patch_embed_a(a)
        a = a + self.pos_embed_a
        a = a + self.modality_a
        return a

    def _embed_video(self, v):
        """Patchify an image batch (B, C, H, W) and add positional and modality embeddings."""
        v = self.patch_embed_v(v)
        v = v + self.pos_embed_v
        v = v + self.modality_v
        return v

    def _encode_audio(self, a):
        """Audio-only encoder: audio blocks, then shared blocks with audio-specific norms."""
        a = self._embed_audio(a)
        for blk in self.blocks_a:
            a = blk(a)
        for blk in self.blocks_u:
            a = blk(a, 'a')
        return self.norm_a(a)

    def _encode_video(self, v):
        """Visual-only encoder: visual blocks, then shared blocks with visual-specific norms."""
        v = self._embed_video(v)
        for blk in self.blocks_v:
            v = blk(v)
        for blk in self.blocks_u:
            v = blk(v, 'v')
        return self.norm_v(v)

    def _init_weights(self, m):
        """Xavier-uniform for nn.Linear (zero bias); unit weight / zero bias for nn.LayerNorm."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)


# our main proposed model, for pretraining only, for finetuning, use CAVMAEFT class
class CAVMAE(_CAVMAEEncoderMixin, nn.Module):
    """ CAV-MAE Model

    Audio and visual patches go through modality-specific blocks (blocks_a / blocks_v),
    then through shared blocks (blocks_u). Pretraining combines masked reconstruction of
    both modalities with an audio-visual contrastive loss.
    """

    def __init__(self, img_size=224, audio_length=1024, patch_size=16, in_chans=3,
                 embed_dim=768, modality_specific_depth=11, num_heads=12,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False, tr_pos=False):
        super().__init__()
        print('A CAV-MAE Model')
        print('Use norm_pix_loss: ', norm_pix_loss)
        print('Learnable Positional Embedding: ', tr_pos)

        # the encoder part
        # overide the timm package
        timm.models.vision_transformer.PatchEmbed = PatchEmbed
        timm.models.vision_transformer.Block = Block

        self.patch_embed_a = PatchEmbed(img_size, patch_size, 1, embed_dim)
        self.patch_embed_v = PatchEmbed(img_size, patch_size, in_chans, embed_dim)

        # audio input is 128 mel bins x audio_length frames -> (128/16) * (audio_length/16) patches
        self.patch_embed_a.num_patches = int(audio_length * 128 / 256)
        print('Number of Audio Patches: {:d}, Visual Patches: {:d}'.format(
            self.patch_embed_a.num_patches, self.patch_embed_v.num_patches))

        self.modality_a = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.modality_v = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.pos_embed_a = nn.Parameter(torch.zeros(1, self.patch_embed_a.num_patches, embed_dim),
                                        requires_grad=tr_pos)  # fixed sin-cos embedding
        self.pos_embed_v = nn.Parameter(torch.zeros(1, self.patch_embed_v.num_patches, embed_dim),
                                        requires_grad=tr_pos)  # fixed sin-cos embedding

        # audio-branch
        self.blocks_a = nn.ModuleList(
            [Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
             for i in range(modality_specific_depth)])
        # visual-branch
        self.blocks_v = nn.ModuleList(
            [Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
             for i in range(modality_specific_depth)])
        # unified branch
        self.blocks_u = nn.ModuleList(
            [Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
             for i in range(12 - modality_specific_depth)])

        # independent normalization layer for audio, visual, and audio-visual
        self.norm_a, self.norm_v, self.norm = norm_layer(embed_dim), norm_layer(embed_dim), norm_layer(embed_dim)

        # the decoder part
        # Project to lower dimension for the decoder
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        # token used for masking
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_modality_a = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_modality_v = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed_a = nn.Parameter(torch.zeros(1, self.patch_embed_a.num_patches, decoder_embed_dim),
                                                requires_grad=tr_pos)  # fixed sin-cos embedding
        self.decoder_pos_embed_v = nn.Parameter(torch.zeros(1, self.patch_embed_v.num_patches, decoder_embed_dim),
                                                requires_grad=tr_pos)  # fixed sin-cos embedding

        self.decoder_blocks = nn.ModuleList(
            [Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None,
                   norm_layer=norm_layer)
             for i in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)

        # project channel is different for two modality, use two projection head
        self.decoder_pred_a = nn.Linear(decoder_embed_dim, patch_size ** 2 * 1, bias=True)  # decoder to patch
        self.decoder_pred_v = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True)  # decoder to patch

        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()

        print('Audio Positional Embedding Shape:', self.pos_embed_a.shape)
        print('Visual Positional Embedding Shape:', self.pos_embed_v.shape)

    def initialize_weights(self):
        """Fill positional embeddings with sin-cos values and initialize the remaining weights."""
        # initialize (and freeze) pos_embed by sin-cos embedding, opt the cls token, add by myself
        # the audio patch grid has 8 frequency rows
        pos_embed_a = get_2d_sincos_pos_embed(self.pos_embed_a.shape[-1], 8,
                                              int(self.patch_embed_a.num_patches / 8), cls_token=False)
        self.pos_embed_a.data.copy_(torch.from_numpy(pos_embed_a).float().unsqueeze(0))

        pos_embed_v = get_2d_sincos_pos_embed(self.pos_embed_v.shape[-1], int(self.patch_embed_v.num_patches ** .5),
                                              int(self.patch_embed_v.num_patches ** .5), cls_token=False)
        self.pos_embed_v.data.copy_(torch.from_numpy(pos_embed_v).float().unsqueeze(0))

        decoder_pos_embed_a = get_2d_sincos_pos_embed(self.decoder_pos_embed_a.shape[-1], 8,
                                                      int(self.patch_embed_a.num_patches / 8), cls_token=False)
        self.decoder_pos_embed_a.data.copy_(torch.from_numpy(decoder_pos_embed_a).float().unsqueeze(0))

        decoder_pos_embed_v = get_2d_sincos_pos_embed(self.decoder_pos_embed_v.shape[-1],
                                                      int(self.patch_embed_v.num_patches ** .5),
                                                      int(self.patch_embed_v.num_patches ** .5), cls_token=False)
        self.decoder_pos_embed_v.data.copy_(torch.from_numpy(decoder_pos_embed_v).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed_a.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        w = self.patch_embed_v.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.modality_a, std=.02)
        torch.nn.init.normal_(self.modality_v, std=.02)
        torch.nn.init.normal_(self.decoder_modality_a, std=.02)
        torch.nn.init.normal_(self.decoder_modality_v, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def patchify(self, imgs, c, h, w, p=16):
        """
        imgs: (N, c, h*p, w*p)
        x: (N, h*w, p**2 * c)
        """
        x = imgs.reshape(shape=(imgs.shape[0], c, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * c))
        return x

    def unpatchify(self, x, c, h, w, p=16):
        """
        x: (N, h*w, p**2 * c)
        imgs: (N, c, h*p, w*p)
        """
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def random_masking_unstructured(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        Returns (x_masked [N, len_keep, D], mask [N, L] with 0=keep / 1=remove, ids_restore [N, L]).
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def random_masking_structured(self, x, mask_ratio, t=64, f=8, mode='time'):
        """
        Perform per-sample random masking by per-sample shuffling, biased towards whole
        time columns ('time'), frequency rows ('freq') or both ('tf').
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence with L == f * t
        Returns the same triple as random_masking_unstructured.
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
        assert L == f * t
        noise = noise.reshape(N, f, t)  # the audio patch is in shape [f,t], not [t,f]
        if mode == 'time':
            for i in range(N):
                mask_t_list = random.sample(range(t), int(t * mask_ratio))
                for k in mask_t_list:
                    noise[i, :, k] = 1.1  # large value will be removed
        elif mode == 'freq':
            for i in range(N):
                mask_f_list = random.sample(range(f), int(f * mask_ratio))
                for k in mask_f_list:
                    noise[i, k, :] = 1.1  # large value will be removed
        elif mode == 'tf':
            for i in range(N):
                mask_t_list = random.sample(range(t), int(t * mask_ratio * 0.7))
                for k in mask_t_list:
                    noise[i, :, k] = 1.1  # large value will be removed
            for i in range(N):
                mask_f_list = random.sample(range(f), int(f * mask_ratio * 0.7))
                for k in mask_f_list:
                    noise[i, k, :] = 1.1  # large value will be removed
        noise = noise.reshape(N, L)

        # sort noise for each sample, only need to manuplate these two ids_shuffle, ids_restore
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, a, v, mask_ratio_a, mask_ratio_v, mask_mode='unstructured'):
        """Mask and encode both modalities.

        Returns (x, mask_a, ids_restore_a, mask_v, ids_restore_v, ca, cv), where x is the joint
        stream (for reconstruction) and ca / cv are the modality-normalized streams (for contrast).
        """
        # embed patches
        a = self._embed_audio(a)
        v = self._embed_video(v)

        # by default, we always use unstructured masking
        if mask_mode == 'unstructured':
            a, mask_a, ids_restore_a = self.random_masking_unstructured(a, mask_ratio_a)
        # in ablation study, we tried time/freq/tf masking. mode in ['freq', 'time', 'tf']
        else:
            a, mask_a, ids_restore_a = self.random_masking_structured(a, mask_ratio_a, t=64, f=8, mode=mask_mode)

        # visual branch always use unstructured masking
        v, mask_v, ids_restore_v = self.random_masking_unstructured(v, mask_ratio_v)

        # audio and visual stream, independent blocks
        for blk in self.blocks_a:
            a = blk(a)

        for blk in self.blocks_v:
            v = blk(v)

        x = torch.cat((a, v), dim=1)

        # unified stream, shared blocks_u, but independent normalization layers
        for blk in self.blocks_u:
            x = blk(x)
        x = self.norm(x)

        # chain the shared blocks (previously each block re-read `a`/`v`, so only the last
        # block took effect, and an empty blocks_u left ca/cv undefined)
        ca = a
        for blk in self.blocks_u:
            ca = blk(ca, 'a')
        ca = self.norm_a(ca)

        cv = v
        for blk in self.blocks_u:
            cv = blk(cv, 'v')
        cv = self.norm_v(cv)

        return x, mask_a, ids_restore_a, mask_v, ids_restore_v, ca, cv

    def forward_decoder(self, x, mask_a, ids_restore_a, mask_v, ids_restore_v):
        """Re-insert mask tokens, unshuffle, decode, and predict per-patch pixels for each modality."""
        num_patches_a = self.patch_embed_a.num_patches
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        # mask_tokens_a in shape [B, #a_mask_token, mask_token_dim], get the number of masked samples from
        # mask_a[0], which is the first example of the batch, all samples should have same number of masked tokens
        num_masked_a = int(mask_a[0].sum())
        num_masked_v = int(mask_v[0].sum())
        num_kept_a = num_patches_a - num_masked_a  # visible audio tokens come first in x

        mask_tokens_a = self.mask_token.repeat(x.shape[0], num_masked_a, 1)
        a_ = torch.cat([x[:, :num_kept_a, :], mask_tokens_a], dim=1)  # no cls token
        a_ = torch.gather(a_, dim=1, index=ids_restore_a.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle

        # similar for the visual modality
        mask_tokens_v = self.mask_token.repeat(x.shape[0], num_masked_v, 1)
        v_ = torch.cat([x[:, num_kept_a:, :], mask_tokens_v], dim=1)  # no cls token
        v_ = torch.gather(v_, dim=1, index=ids_restore_v.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle

        # concatenate audio and visual tokens
        x = torch.cat([a_, v_], dim=1)

        decoder_pos_embed = torch.cat([self.decoder_pos_embed_a, self.decoder_pos_embed_v], dim=1)
        x = x + decoder_pos_embed

        # add modality indication tokens
        x[:, 0:num_patches_a, :] = x[:, 0:num_patches_a, :] + self.decoder_modality_a
        x[:, num_patches_a:, :] = x[:, num_patches_a:, :] + self.decoder_modality_v

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x_a = self.decoder_pred_a(x[:, :num_patches_a, :])
        x_v = self.decoder_pred_v(x[:, num_patches_a:, :])

        # return audio and video tokens
        return x_a, x_v

    def forward_contrastive(self, audio_rep, video_rep, bidirect_contrast=False):
        """InfoNCE loss (temperature 0.05) and retrieval accuracy between paired (B, D) representations."""
        # calculate nce loss for mean-visual representation and mean-audio representation
        audio_rep = torch.nn.functional.normalize(audio_rep, dim=-1)
        video_rep = torch.nn.functional.normalize(video_rep, dim=-1)
        total = torch.mm(audio_rep, torch.transpose(video_rep, 0, 1)) / 0.05

        def nce_and_acc(sim):
            # column-wise softmax; the matching pair sits on the diagonal
            nce = -torch.mean(torch.diag(torch.nn.functional.log_softmax(sim, dim=0)))
            target = torch.arange(0, sim.shape[0], device=audio_rep.device)
            hits = torch.eq(torch.argmax(torch.nn.functional.softmax(sim, dim=0), dim=0), target)
            return nce, torch.sum(hits) / sim.shape[0]

        # by default we use single directional
        if not bidirect_contrast:
            return nce_and_acc(total)

        nce_1, c_acc_1 = nce_and_acc(total)
        nce_2, c_acc_2 = nce_and_acc(total.t())
        return (nce_1 + nce_2) / 2, (c_acc_1 + c_acc_2) / 2

    def forward_mae_loss(self, input, pred, mask, modality):
        """Mean squared reconstruction error over masked patches only.

        modality: 'a' (input is an fbank batch) or 'v' (input is an image batch).
        """
        if modality == 'a':
            # for audio, need to adjust the shape
            input = input.unsqueeze(1)
            input = input.transpose(2, 3)
            patch_size = self.patch_embed_a.patch_size
        elif modality == 'v':
            patch_size = self.patch_embed_v.patch_size
        else:
            raise ValueError(f"Unknown modality {modality!r}; expected 'a' or 'v'")
        target = self.patchify(input, input.shape[1], int(input.shape[2] / patch_size[0]),
                               int(input.shape[3] / patch_size[1]), patch_size[0])

        # patch-wise normalization might minorly improve the classification performance,
        # but will make the model lose inpainting function
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6) ** .5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def forward(self, audio, imgs, mask_ratio_a=0.75, mask_ratio_v=0.75, mae_loss_weight=1.,
                contrast_loss_weight=0.01, mask_mode='unstructured'):
        """Pretraining step.

        Returns (loss, loss_mae, loss_mae_a, loss_mae_v, loss_c, mask_a, mask_v, c_acc).
        """
        # latent is used for reconstruction (mae), latent_c_{a,v} are used for contrastive learning
        latent, mask_a, ids_restore_a, mask_v, ids_restore_v, latent_c_a, latent_c_v = self.forward_encoder(
            audio, imgs, mask_ratio_a, mask_ratio_v, mask_mode=mask_mode)
        # if mae loss is used
        if mae_loss_weight != 0:
            pred_a, pred_v = self.forward_decoder(latent, mask_a, ids_restore_a, mask_v, ids_restore_v)
            loss_mae_a = self.forward_mae_loss(audio, pred_a, mask_a, 'a')
            loss_mae_v = self.forward_mae_loss(imgs, pred_v, mask_v, 'v')
            loss_mae = mae_loss_weight * (loss_mae_a + loss_mae_v)
        else:
            loss_mae_a = torch.tensor(0.0, device=audio.device)
            loss_mae_v = torch.tensor(0.0, device=audio.device)
            loss_mae = torch.tensor(0.0, device=audio.device)

        # if contrastive loss is used
        if contrast_loss_weight != 0:
            # note this is single directional
            loss_c, c_acc = self.forward_contrastive(latent_c_a.mean(dim=1), latent_c_v.mean(dim=1))
            loss_c = contrast_loss_weight * loss_c
        else:
            loss_c, c_acc = torch.tensor(0.0, device=audio.device), torch.tensor(0.0, device=audio.device)

        loss = loss_mae + loss_c

        return loss, loss_mae, loss_mae_a, loss_mae_v, loss_c, mask_a, mask_v, c_acc

    # used only for inpainting, ignore if inpainting is not of interest
    def forward_inpaint(self, audio, imgs, mask_ratio_a=0.75, mask_ratio_v=0.75, mask_mode='unstructured'):
        latent, mask_a, ids_restore_a, mask_v, ids_restore_v, latent_c_a, latent_c_v = self.forward_encoder(
            audio, imgs, mask_ratio_a, mask_ratio_v, mask_mode=mask_mode)
        pred_a, pred_v = self.forward_decoder(latent, mask_a, ids_restore_a, mask_v, ids_restore_v)  # [N, L, p*p*3]
        loss_pixel_a = self.forward_mae_loss(audio, pred_a, mask_a, 'a')
        loss_pixel_v = self.forward_mae_loss(imgs, pred_v, mask_v, 'v')
        return pred_a, pred_v, mask_a, mask_v, loss_pixel_a, loss_pixel_v

    # used for retrieval, ignore if retrieval is not of interest
    def forward_feat(self, a, v):
        """Unmasked per-token features for each modality, using modality-specific normalization."""
        return self._encode_audio(a), self._encode_video(v)

    def forward_audio(self, a):
        """Audio patch features reshaped to a (B, C, freq_patches, time_patches) map."""
        a = self._encode_audio(a)
        num_freq_patches = 128 // self.patch_embed_a.patch_size[0]  # 128 mel bins
        return a.reshape(a.shape[0], num_freq_patches, -1, a.shape[-1]).permute(0, 3, 1, 2)

    def forward_video(self, v):
        """Visual patch features reshaped to a (B, C, H/patch, W/patch) map."""
        v = self._encode_video(v)
        grid_h = self.patch_embed_v.img_size[0] // self.patch_embed_v.patch_size[0]
        grid_w = self.patch_embed_v.img_size[1] // self.patch_embed_v.patch_size[1]
        return v.reshape(v.shape[0], grid_h, grid_w, v.shape[-1]).permute(0, 3, 1, 2)


# the finetuned CAV-MAE model
class CAVMAEFT(_CAVMAEEncoderMixin, nn.Module):
    """CAV-MAE encoder with a linear classification head, for fine-tuning and inference."""

    def __init__(self, label_dim, img_size=224, audio_length=1024, patch_size=16, in_chans=3,
                 embed_dim=768, modality_specific_depth=11, num_heads=12, mlp_ratio=4., norm_layer=nn.LayerNorm,
                 norm_pix_loss=False, tr_pos=True):
        super().__init__()
        print('Use norm_pix_loss: ', norm_pix_loss)

        timm.models.vision_transformer.PatchEmbed = PatchEmbed
        timm.models.vision_transformer.Block = Block

        self.patch_embed_a = PatchEmbed(img_size, patch_size, 1, embed_dim)
        self.patch_embed_v = PatchEmbed(img_size, patch_size, in_chans, embed_dim)

        self.patch_embed_a.num_patches = int(audio_length * 128 / 256)
        print('Number of Audio Patches: {:d}, Visual Patches: {:d}'.format(
            self.patch_embed_a.num_patches, self.patch_embed_v.num_patches))

        self.modality_a = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.modality_v = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.pos_embed_a = nn.Parameter(torch.zeros(1, self.patch_embed_a.num_patches, embed_dim),
                                        requires_grad=tr_pos)  # fixed sin-cos embedding
        self.pos_embed_v = nn.Parameter(torch.zeros(1, self.patch_embed_v.num_patches, embed_dim),
                                        requires_grad=tr_pos)  # fixed sin-cos embedding

        self.blocks_a = nn.ModuleList(
            [Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
             for i in range(modality_specific_depth)])
        self.blocks_v = nn.ModuleList(
            [Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
             for i in range(modality_specific_depth)])
        self.blocks_u = nn.ModuleList(
            [Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
             for i in range(12 - modality_specific_depth)])

        self.norm_a = norm_layer(embed_dim)
        self.norm_v = norm_layer(embed_dim)
        self.norm = norm_layer(embed_dim)

        self.mlp_head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, label_dim))

        self.initialize_weights()

        print('Audio Positional Embedding Shape:', self.pos_embed_a.shape)
        print('Visual Positional Embedding Shape:', self.pos_embed_v.shape)

    def get_patch_num(self, input_shape, stride):
        """Return (rows, cols, rows * cols) of patches a 16x16 conv with `stride` yields on input_shape."""
        test_input = torch.zeros(1, 1, input_shape[0], input_shape[1])
        test_proj = torch.nn.Conv2d(1, 4, kernel_size=(16, 16), stride=(stride, stride))
        test_output = test_proj(test_input)
        print(test_output.shape)
        # previously indexed the tensor itself (test_output[3]) instead of its shape
        return test_output.shape[2], test_output.shape[3], test_output.shape[2] * test_output.shape[3]

    def initialize_weights(self):
        """Fill positional embeddings with sin-cos values and initialize the remaining weights."""
        pos_embed_a = get_2d_sincos_pos_embed(self.pos_embed_a.shape[-1], 8,
                                              int(self.patch_embed_a.num_patches / 8), cls_token=False)
        self.pos_embed_a.data.copy_(torch.from_numpy(pos_embed_a).float().unsqueeze(0))

        pos_embed_v = get_2d_sincos_pos_embed(self.pos_embed_v.shape[-1], int(self.patch_embed_v.num_patches ** .5),
                                              int(self.patch_embed_v.num_patches ** .5), cls_token=False)
        self.pos_embed_v.data.copy_(torch.from_numpy(pos_embed_v).float().unsqueeze(0))

        w = self.patch_embed_a.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        w = self.patch_embed_v.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        torch.nn.init.normal_(self.modality_a, std=.02)
        torch.nn.init.normal_(self.modality_v, std=.02)

        self.apply(self._init_weights)

    def _fuse_missing_modality(self, tokens, modality, modality_norm):
        """Average the pooled outputs of two passes through blocks_u.

        One pass uses the unified normalization (self.norm), the other uses the
        modality-specific norms selected by `modality` and `modality_norm`.
        """
        u = tokens
        for blk in self.blocks_u:
            u = blk(u)
        # note here use unified normalization
        u = self.norm(u).mean(dim=1)

        m = tokens
        for blk in self.blocks_u:
            m = blk(m, modality)
        # note here use modality-specific normalization
        m = modality_norm(m).mean(dim=1)

        # average the output of the two forward passes
        return (u + m) / 2

    def forward(self, a, v, mode):
        """Classify a batch; mode selects which inputs are used and how.

        Modes: 'multimodal', 'audioonly', 'videoonly', 'missingaudioonly', 'missingvideoonly'.
        """
        # multi-modal fine-tuning, our default method for fine-tuning
        if mode == 'multimodal':
            a = self._embed_audio(a)
            v = self._embed_video(v)
            for blk in self.blocks_a:
                a = blk(a)
            for blk in self.blocks_v:
                v = blk(v)
            x = torch.cat((a, v), dim=1)
            for blk in self.blocks_u:
                x = blk(x)
            x = self.norm(x).mean(dim=1)
        # finetune with only audio (and inference with only audio when the model is finetuned with only audio)
        elif mode == 'audioonly':
            # note here uses the 'a' normalization, it is used in both training and inference, so it is fine
            x = self._encode_audio(a).mean(dim=1)
        # finetune with only image (and inference with only image when the model is finetuned with only image)
        elif mode == 'videoonly':
            # note here uses the 'v' normalization, it is used in both training and inference, so it is fine
            x = self._encode_video(v).mean(dim=1)
        # used in case that the model is finetuned with both modality, but in inference only audio is given
        elif mode == 'missingaudioonly':
            a = self._embed_audio(a)
            for blk in self.blocks_a:
                a = blk(a)
            x = self._fuse_missing_modality(a, 'a', self.norm_a)
        # used in case that the model is fine-tuned with both modality, but in inference only image is given
        elif mode == 'missingvideoonly':
            v = self._embed_video(v)
            for blk in self.blocks_v:
                v = blk(v)
            x = self._fuse_missing_modality(v, 'v', self.norm_v)
        else:
            # previously an unknown mode silently returned None
            raise ValueError(f"Unknown mode {mode!r}")
        return self.mlp_head(x)

    # for retrieval
    def forward_feat(self, a, v, mode='av'):
        """Per-token features: mode 'av' returns (audio, visual); mode 'a' returns audio only."""
        # return both audio and visual
        if mode == 'av':
            return self._encode_audio(a), self._encode_video(v)
        # return only audio
        if mode == 'a':
            return self._encode_audio(a)
        raise ValueError(f"Unknown mode {mode!r}; expected 'av' or 'a'")


def _wav2fbank(filename, target_length=1024):
    """Load an audio file as a fixed-length 128-bin Kaldi-compatible fbank.

    The waveform is resampled to 16 kHz and mean-centred, converted with
    torchaudio's Kaldi fbank (10 ms frame shift), then zero-padded or truncated
    along time to `target_length` frames.

    Returns:
        torch.Tensor of shape (target_length, 128).
    """
    sample_rate = 16000
    waveform, sr = torchaudio.load(filename)
    waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=sample_rate)
    waveform = waveform - waveform.mean()

    # the waveform is now at `sample_rate`; passing the file's original `sr` here
    # (as before) mis-computes frame sizes for any non-16 kHz input
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform, htk_compat=True, sample_frequency=sample_rate, use_energy=False,
        window_type='hanning', num_mel_bins=128, dither=0.0, frame_shift=10)

    n_frames = fbank.shape[0]
    p = target_length - n_frames

    # cut and pad
    if p > 0:
        m = torch.nn.ZeroPad2d((0, 0, 0, p))
        fbank = m(fbank)
    elif p < 0:
        fbank = fbank[0:target_length, :]

    return fbank


def pca(image_feats_list, dim=3, fit_pca=None):
    """Project (B, C, H, W) feature maps to `dim` PCA channels, min-max scaled to [0, 1] per channel.

    When fit_pca is None a new PCA is fitted on all maps (resized to the first map's
    height when several are given); otherwise the supplied fitted PCA is reused.
    Returns (list of (B, dim, H, W) tensors on the input device, fitted PCA).
    """
    from sklearn.decomposition import PCA

    device = image_feats_list[0].device

    def flatten(tensor, target_size=None):
        # (B, C, H, W) -> (B*H*W, C) on CPU, optionally resized first
        if target_size is not None:
            # previously the interpolated result was discarded
            tensor = F.interpolate(tensor, (target_size, target_size), mode="bilinear")
        B, C, H, W = tensor.shape
        return tensor.permute(1, 0, 2, 3).reshape(C, B * H * W).permute(1, 0).detach().cpu()

    if fit_pca is None:
        target_size = image_feats_list[0].shape[2] if len(image_feats_list) > 1 else None
        x = torch.cat([flatten(feats, target_size) for feats in image_feats_list], dim=0)
        fit_pca = PCA(n_components=dim).fit(x)

    reduced_feats = []
    for feats in image_feats_list:
        x_red = torch.from_numpy(fit_pca.transform(flatten(feats)))
        x_red -= x_red.min(dim=0, keepdim=True).values
        x_red /= x_red.max(dim=0, keepdim=True).values
        B, C, H, W = feats.shape
        reduced_feats.append(x_red.reshape(B, H, W, dim).permute(0, 3, 1, 2).to(device))

    return reduced_feats, fit_pca


def _load_pretrained_cavmae(model_path):
    """Build the pretraining CAVMAE and load a DataParallel-saved checkpoint from model_path.

    The model is moved to CUDA when available, otherwise kept on CPU.
    NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    """
    audio_model = CAVMAE(
        audio_length=1024,
        modality_specific_depth=11,
        norm_pix_loss=True,
        tr_pos=False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    mdl_weight = torch.load(model_path, map_location=device)
    # the checkpoint keys carry the DataParallel "module." prefix
    audio_model = torch.nn.DataParallel(audio_model)
    audio_model.load_state_dict(mdl_weight, strict=True)
    # previously `.cuda()` unconditionally, which crashed on CPU-only hosts
    return audio_model.module.to(device)


class CAVMAEAudioFeaturizer(nn.Module):
    """Wraps a CAVMAE to produce audio patch-feature maps (no cls token)."""

    def __init__(self, output_path, model_name="base", model=None):
        super().__init__()
        if model is not None:
            self.model = model
        else:
            if model_name == "base":
                model_path = os.path.join(output_path, 'models/audio_model.21.pth')
            else:
                raise ValueError(f"Unknown model type {model_name}")
            self.model = _load_pretrained_cavmae(model_path)

    def forward(self, audio, include_cls):
        cls_token = None  # CAV-MAE has no cls token
        patch_tokens = self.model.forward_audio(audio.squeeze(1))

        if include_cls:
            return patch_tokens, cls_token
        return patch_tokens


class CAVMAEImageFeaturizer(nn.Module):
    """Wraps a CAVMAE to produce image patch-feature maps (no cls token)."""

    def __init__(self, output_path, model=None, model_name="base"):
        super().__init__()
        if model is not None:
            self.model: CAVMAE = model
        else:
            if model_name == "base":
                model_path = os.path.join(output_path, 'models/audio_model.21.pth')
            else:
                raise ValueError(f"Unknown model type {model_name}")
            self.model: CAVMAE = _load_pretrained_cavmae(model_path)

    def forward(self, image, include_cls):
        cls_token = None  # CAV-MAE has no cls token
        patch_tokens = self.model.forward_video(image)

        if include_cls:
            return patch_tokens, cls_token
        return patch_tokens


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model: CAVMAE = _load_pretrained_cavmae(os.path.join("../../", 'models/audio_model.21.pth'))

    image_paths = ["../../samples/dog_image.jpg", "../../samples/car_image.jpg", "../../samples/bird_image.jpg"]
    audio_paths = ["../../samples/dog_audio.wav", "../../samples/car_audio.wav", "../../samples/bird_audio.wav"]

    preprocess = T.Compose([
        T.Resize(224, interpolation=T.InterpolationMode.BICUBIC),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(
            mean=[0.4850, 0.4560, 0.4060],
            std=[0.2290, 0.2240, 0.2250]
        )])

    images = []
    for image_path in image_paths:
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        images.append(preprocess(image).unsqueeze(0).to(device))

    audios = []
    for audio_path in audio_paths:
        a = _wav2fbank(audio_path).to(device).unsqueeze(0)
        # fbank normalization constants (mean -5.081, std 4.4849)
        a = (a + 5.081) / (4.4849)
        audios.append(a)

    with torch.no_grad():
        audio_feats, image_feats = model.forward_feat(torch.cat(audios, dim=0), torch.cat(images, dim=0))

    audio_feats = F.normalize(audio_feats.mean(1), dim=1)
    image_feats = F.normalize(image_feats.mean(1), dim=1)

    sims = torch.einsum("bc,dc->bd", image_feats, audio_feats)
    print(sims)