from typing import Any, Dict, List, Optional, Tuple, Union

import clip
import mmengine
import numpy as np
import torch
import torch.nn.functional as F
from mmengine.dist import all_gather, get_rank
from mmengine.model import BaseModel
from mmengine.structures import LabelData

from mmaction.registry import MODELS

from .adapter import TransformerAdapter


class GatherLayer(torch.autograd.Function):
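    """Gather tensors from all ranks, supporting backward propagation.

    ``mmengine.dist.all_gather`` does not propagate gradients, so this
    autograd ``Function`` gathers in ``forward`` and, in ``backward``,
    routes the gradient of this rank's gathered copy back to the local
    input tensor.
    """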

    @staticmethod
    def forward(ctx: Any, input: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        ctx.save_for_backward(input)
        output = all_gather(input)
        return tuple(output)

    @staticmethod
    def backward(ctx: Any, *grads: torch.Tensor) -> torch.Tensor:
        input, = ctx.saved_tensors
        grad_out = torch.zeros_like(input)
        # Only the gradient w.r.t. this rank's gathered copy flows back.
        grad_out[:] = grads[get_rank()]
        return grad_out


def text_prompt(labels_or_label_file, templates_or_template_file=None):
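    """Tokenize every (template, label) prompt pair with CLIP.

    Returns the stacked tokenized prompts, ordered template-major, together
    with the number of templates.
    """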
    if isinstance(labels_or_label_file, str):
        labels = mmengine.list_from_file(labels_or_label_file)
    elif isinstance(labels_or_label_file, list):
        labels = labels_or_label_file
    else:
        raise ValueError(f'`labels_or_label_file` must be `list` or `str`, '
                         f'but got {type(labels_or_label_file)}')

    if templates_or_template_file is None:
        templates = [
            'a photo of action {}', 'a picture of action {}',
            'Human action of {}', '{}, an action', '{} this is an action',
            '{}, a video of action', 'Playing action of {}', '{}',
            'Playing a kind of action, {}', 'Doing a kind of action, {}',
            'Look, the human is {}', 'Can you recognize the action of {}?',
            'Video classification of {}', 'A video of {}', 'The man is {}',
            'The woman is {}'
        ]
    elif isinstance(templates_or_template_file, str):
        templates = mmengine.list_from_file(templates_or_template_file)
    elif not mmengine.is_seq_of(templates_or_template_file, str):
        raise ValueError(
            f'`templates_or_template_file` must be list of `str`, `str` or '
            f'`None`, but got {type(templates_or_template_file)}')
    else:
        # A valid sequence of template strings was passed in directly.
        templates = templates_or_template_file

    num_prompt = len(templates)
    prompt = torch.cat(
        [clip.tokenize(t.format(c)) for t in templates for c in labels])
    return prompt, num_prompt


@MODELS.register_module()
class ActionClip(BaseModel):
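    """ActionCLIP: a CLIP backbone with a temporal transformer adapter.

    Frames are embedded with the CLIP image encoder, fused over time by the
    adapter, and matched against CLIP text embeddings of label prompts.
    """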

    def __init__(self,
                 clip_arch: str,
                 num_adapter_segs: int,
                 num_adapter_layers: int = 6,
                 to_float32: bool = False,
                 labels_or_label_file: Optional[Union[List[str], str]] = None,
                 templates_or_template_file: Optional[Union[List[str],
                                                            str]] = None,
                 data_preprocessor: Optional[Dict] = None,
                 loss: Dict = dict(type='CrossEntropyLoss', loss_weight=0.5)):
        super().__init__(data_preprocessor=data_preprocessor)
        self.clip = clip.load(clip_arch, device='cpu')[0]
        if to_float32:
            # Optionally cast the CLIP backbone to float32.
            self.clip.float()

        self.adapter = TransformerAdapter(self.clip, num_adapter_segs,
                                          num_adapter_layers)

        self.loss = MODELS.build(loss)

        if labels_or_label_file is not None:
            self.prompt, self.num_prompt = text_prompt(
                labels_or_label_file, templates_or_template_file)

    def encode_video(self, video):
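        """Encode ``(B, N, C, H, W)`` videos: each of the ``N`` frames goes
        through the CLIP image encoder, then the adapter fuses the frame
        features into one feature per video."""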
        b, n, c, h, w = video.shape
        # Fold the frame dimension into the batch for the image encoder.
        video = video.view(-1, c, h, w)
        frames_features = self.encode_image(video)
        frames_features = frames_features.view(b, n, -1)
        video_features = self.adapter(frames_features)
        return video_features

    def encode_image(self, image):
        return self.clip.encode_image(image)

    def encode_text(self, text):
        return self.clip.encode_text(text)

    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[List] = None,
                mode: str = 'tensor'):
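        """Dispatch on ``mode``: ``'tensor'`` returns raw video features,
        ``'predict'`` attaches classification scores to ``data_samples``,
        and ``'loss'`` computes the symmetric contrastive loss."""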
        if mode == 'tensor':
            return self.encode_video(inputs)

        elif mode == 'predict':
            assert hasattr(self, 'prompt'), \
                '`labels_or_label_file` is required to perform prediction.'

            video_features = self.encode_video(inputs)
            video_features = video_features / video_features.norm(
                dim=-1, keepdim=True)

            bsz = len(data_samples)
            num_views = video_features.shape[0] // bsz

            text_features = self.encode_text(self.prompt.to(inputs.device))
            text_features = text_features / text_features.norm(
                dim=-1, keepdim=True)

            # (bsz * num_views, C) x (C, num_prompt * num_classes) ->
            # (bsz, num_views * num_prompt, num_classes)
            similarity = (100.0 * video_features @ text_features.T). \
                view(bsz, num_views * self.num_prompt, -1)

            # Average softmax scores over all views and prompt templates.
            cls_scores = F.softmax(similarity, dim=2).mean(dim=1)

            for data_sample, score in zip(data_samples, cls_scores):
                data_sample.pred_scores = LabelData(item=score)

            return data_samples

        elif mode == 'loss':
            video_features = self.encode_video(inputs)
            video_features = video_features / video_features.norm(
                dim=-1, keepdim=True)

            # For each sample, draw a random template and select the
            # tokenized prompt of its ground-truth label.
            text_id = np.random.randint(
                self.num_prompt, size=len(data_samples))
            real_labels = [x.gt_labels.item.item() for x in data_samples]
            selected_prompt = self.prompt.view(
                self.num_prompt, -1,
                self.prompt.shape[-1])[text_id, real_labels].to(inputs.device)

            text_features = self.encode_text(selected_prompt)
            text_features = text_features / text_features.norm(
                dim=-1, keepdim=True)

            # Gather features from all ranks so the contrastive loss uses
            # negatives from the whole global batch.
            video_features = torch.cat(
                GatherLayer.apply(video_features), dim=0)
            text_features = torch.cat(GatherLayer.apply(text_features), dim=0)

            logit_scale = self.clip.logit_scale.exp()
            logits_per_video = logit_scale * video_features @ text_features.t()
            logits_per_text = logits_per_video.t()
            # Matching video-text pairs sit on the diagonal.
            labels = torch.arange(logits_per_video.shape[0]).to(
                logit_scale.device)

            sim_loss_v2t = self.loss(logits_per_video, labels)
            sim_loss_t2v = self.loss(logits_per_text, labels)

            losses = dict()
            losses['sim_loss_v2t'] = sim_loss_v2t
            losses['sim_loss_t2v'] = sim_loss_t2v
            return losses

        else:
            raise RuntimeError(
                f'Invalid mode "{mode}". '
                'Only supports `predict`, `loss` and `tensor` mode.')
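

# Minimal usage sketch (hypothetical values; in practice the MMEngine runner
# builds the model from a config dict via ``MODELS.build``):
#
#     model = ActionClip(clip_arch='ViT-B/32',
#                        num_adapter_segs=8,
#                        labels_or_label_file='kinetics400_labels.txt')
#     feats = model(videos, mode='tensor')  # videos: (B, 8, 3, 224, 224)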