from models.BaseNetwork import BaseNetwork
from models.transformer_base.ffn_base import FusionFeedForward
from models.transformer_base.attention_flow import SWMHSA_depthGlobalWindowConcatLN_qkFlow_reweightFlow
from models.transformer_base.attention_base import TMHSA
import torch
import torch.nn as nn
from functools import reduce
import torch.nn.functional as F
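

# FGT (flow-guided transformer) generator for video inpainting: a CNN frame
# encoder and a flow encoder produce token sequences that pass through
# alternating temporal and spatial (flow-guided, window-based) transformer
# blocks before a CNN decoder restores the completed frames.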
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.net = FGT(config['tw'], config['sw'], config['gd'], config['input_resolution'], config['in_channel'],
                       config['cnum'], config['flow_inChannel'], config['flow_cnum'], config['frame_hidden'],
                       config['flow_hidden'], config['PASSMASK'],
                       config['numBlocks'], config['kernel_size'], config['stride'], config['padding'],
                       config['num_head'], config['conv_type'], config['norm'],
                       config['use_bias'], config['ape'],
                       config['mlp_ratio'], config['drop'], config['init_weights'])

    def forward(self, frames, flows, masks):
        ret = self.net(frames, flows, masks)
        return ret
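

# Frame encoder: strided convolutions downsample the masked frames to 1/4
# resolution, then group convolutions repeatedly fuse the running features
# with the features cached at layer index 8 (see forward below).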
class Encoder(nn.Module):
    def __init__(self, in_channels):
        super(Encoder, self).__init__()
        self.group = [1, 2, 4, 8, 1]
        self.layers = nn.ModuleList([
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1),
            nn.LeakyReLU(0.2, inplace=True)
        ])
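
    # From layer index 10 on, every conv input is the grouped concatenation of
    # the 256-channel features cached at index 8 (x0) with the current features,
    # grouped according to self.group; this is why those convs expect
    # 640/768/640/512 input channels instead of 384/512/384/256.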
    def forward(self, x):
        bt, c, h, w = x.size()
        h, w = h // 4, w // 4
        out = x
        for i, layer in enumerate(self.layers):
            if i == 8:
                x0 = out
            if i > 8 and i % 2 == 0:
                g = self.group[(i - 8) // 2]
                x = x0.view(bt, g, -1, h, w)
                o = out.view(bt, g, -1, h, w)
                out = torch.cat([x, o], 2).view(bt, -1, h, w)
            out = layer(out)
        return out
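

# Learnable positional embedding: a depth-wise 3x3 convolution over the token
# map is added back to the features (in the spirit of conditional positional
# encodings), so the embedding adapts to the input resolution; h, w default to
# the training token size and can be overridden at inference.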
class AddPosEmb(nn.Module):
    def __init__(self, h, w, in_channels, out_channels):
        super(AddPosEmb, self).__init__()
        self.proj = nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=True, groups=out_channels)
        self.h, self.w = h, w

    def forward(self, x, h=0, w=0):
        B, N, C = x.shape
        if h == 0 and w == 0:
            assert N == self.h * self.w, 'Wrong input size'
        else:
            assert N == h * w, 'Wrong input size during inference'
        feat_token = x
        if h == 0 and w == 0:
            cnn_feat = feat_token.transpose(1, 2).view(B, C, self.h, self.w)
        else:
            cnn_feat = feat_token.transpose(1, 2).view(B, C, h, w)
        x = self.proj(cnn_feat) + cnn_feat
        x = x.flatten(2).transpose(1, 2)
        return x
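

# Vec2Patch projects each token back to a flattened patch and uses fold (the
# inverse of unfold) to reassemble the overlapping patches into a feature map.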
class Vec2Patch(nn.Module):
    def __init__(self, channel, hidden, output_size, kernel_size, stride, padding):
        super(Vec2Patch, self).__init__()
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        c_out = reduce((lambda x, y: x * y), kernel_size) * channel
        self.embedding = nn.Linear(hidden, c_out)
        self.restore = nn.Fold(output_size=output_size, kernel_size=kernel_size, stride=stride, padding=padding)
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def forward(self, x, output_h=0, output_w=0):
        feat = self.embedding(x)
        feat = feat.permute(0, 2, 1)
        if output_h != 0 or output_w != 0:
            feat = F.fold(feat, output_size=(output_h, output_w), kernel_size=self.kernel_size, stride=self.stride,
                          padding=self.padding)
        else:
            feat = self.restore(feat)
        return feat
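

# Temporal transformer layer: pre-norm temporal multi-head self attention
# (TMHSA) across the t frames, followed by the fusion feed-forward network.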
class TemporalTransformer(nn.Module):
    def __init__(self, token_size, frame_hidden, num_heads, t_groupSize, mlp_ratio, dropout, n_vecs,
                 t2t_params):
        super(TemporalTransformer, self).__init__()
        self.attention = TMHSA(token_size=token_size, group_size=t_groupSize, d_model=frame_hidden, head=num_heads,
                               p=dropout)
        self.ffn = FusionFeedForward(frame_hidden, mlp_ratio, n_vecs, t2t_params, p=dropout)
        self.norm1 = nn.LayerNorm(frame_hidden)
        self.norm2 = nn.LayerNorm(frame_hidden)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, t, h, w, output_size):
        token_size = h * w
        s = self.norm1(x)
        x = x + self.dropout(self.attention(s, t, h, w))
        y = self.norm2(x)
        x = x + self.ffn(y, token_size, output_size[0], output_size[1])
        return x
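

# Spatial transformer layer: window-based multi-head self attention within
# each frame, guided by the flow tokens (the attention class name indicates
# that flow enters the query/key computation and re-weights the attention),
# followed by the fusion feed-forward network.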
class SpatialTransformer(nn.Module):
    def __init__(self, token_size, frame_hidden, flow_hidden, num_heads, s_windowSize, g_downSize, mlp_ratio,
                 dropout, n_vecs, t2t_params):
        super(SpatialTransformer, self).__init__()
        self.attention = SWMHSA_depthGlobalWindowConcatLN_qkFlow_reweightFlow(token_size=token_size,
                                                                              window_size=s_windowSize,
                                                                              kernel_size=g_downSize,
                                                                              d_model=frame_hidden,
                                                                              flow_dModel=flow_hidden,
                                                                              head=num_heads, p=dropout)
        self.ffn = FusionFeedForward(frame_hidden, mlp_ratio, n_vecs, t2t_params, p=dropout)
        self.norm = nn.LayerNorm(frame_hidden)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, f, t, h, w, output_size):
        token_size = h * w
        x = x + self.dropout(self.attention(x, f, t, h, w))
        y = self.norm(x)
        x = x + self.ffn(y, token_size, output_size[0], output_size[1])
        return x
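

# One FGT block: temporal attention over corresponding tokens across frames,
# then flow-guided spatial attention inside each frame. Inputs and outputs are
# packed in a dict so blocks can be chained with nn.Sequential.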
class TransformerBlock(nn.Module):
    def __init__(self, token_size, frame_hidden, flow_hidden, num_heads, t_groupSize, s_windowSize, g_downSize,
                 mlp_ratio, dropout, n_vecs, t2t_params):
        super(TransformerBlock, self).__init__()
        self.t_transformer = TemporalTransformer(token_size=token_size, frame_hidden=frame_hidden,
                                                 num_heads=num_heads, t_groupSize=t_groupSize, mlp_ratio=mlp_ratio,
                                                 dropout=dropout, n_vecs=n_vecs,
                                                 t2t_params=t2t_params)  # temporal multi-head self attention
        self.s_transformer = SpatialTransformer(token_size=token_size, frame_hidden=frame_hidden,
                                                flow_hidden=flow_hidden, num_heads=num_heads,
                                                s_windowSize=s_windowSize, g_downSize=g_downSize,
                                                mlp_ratio=mlp_ratio, dropout=dropout, n_vecs=n_vecs,
                                                t2t_params=t2t_params)

    def forward(self, inputs):
        x, f, t = inputs['x'], inputs['f'], inputs['t']
        h, w = inputs['h'], inputs['w']
        output_size = inputs['output_size']
        x = self.t_transformer(x, t, h, w, output_size)
        x = self.s_transformer(x, f, t, h, w, output_size)
        return {'x': x, 'f': f, 't': t, 'h': h, 'w': w, 'output_size': output_size}
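

# Decoder: two DeconvBlock (upsampling) stages and a final convolution map the
# fused 1/4-resolution features back to the input resolution and produce the
# RGB output.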
class Decoder(BaseNetwork):
    def __init__(self, conv_type, in_channels, out_channels, use_bias, norm=None):
        super(Decoder, self).__init__(conv_type)
        self.layer1 = self.DeconvBlock(in_channels, in_channels, kernel_size=3, padding=1, norm=norm,
                                       bias=use_bias)
        self.layer2 = self.ConvBlock(in_channels, in_channels // 2, kernel_size=3, stride=1, padding=1, norm=norm,
                                     bias=use_bias)
        self.layer3 = self.DeconvBlock(in_channels // 2, in_channels // 2, kernel_size=3, padding=1, norm=norm,
                                       bias=use_bias)
        self.final = self.ConvBlock(in_channels // 2, out_channels, kernel_size=3, stride=1, padding=1, norm=norm,
                                    bias=use_bias, activation=None)

    def forward(self, features):
        feat1 = self.layer1(features)
        feat2 = self.layer2(feat1)
        feat3 = self.layer3(feat2)
        output = self.final(feat3)
        return output
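

# Full FGT generator: encode frames and flows, turn the 1/4-resolution feature
# maps into overlapping-patch tokens, run them through the transformer stack,
# fold the tokens back into a feature map, add it residually to the encoder
# features, and decode the completed frames.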
class FGT(BaseNetwork):
    def __init__(self, t_groupSize, s_windowSize, g_downSize, input_resolution, in_channels, cnum, flow_inChannel,
                 flow_cnum, frame_hidden, flow_hidden, passmask, numBlocks, kernel_size, stride, padding, num_heads,
                 conv_type, norm, use_bias, ape, mlp_ratio=4, drop=0, init_weights=True):
        super(FGT, self).__init__(conv_type)
        self.in_channels = in_channels
        self.passmask = passmask
        self.ape = ape
        self.frame_endoder = Encoder(in_channels)
        self.flow_encoder = nn.Sequential(
            nn.ReplicationPad2d(2),
            self.ConvBlock(flow_inChannel, flow_cnum, kernel_size=5, stride=1, padding=0, bias=use_bias, norm=norm),
            self.ConvBlock(flow_cnum, flow_cnum * 2, kernel_size=3, stride=2, padding=1, bias=use_bias, norm=norm),
            self.ConvBlock(flow_cnum * 2, flow_cnum * 2, kernel_size=3, stride=1, padding=1, bias=use_bias, norm=norm),
            self.ConvBlock(flow_cnum * 2, flow_cnum * 2, kernel_size=3, stride=2, padding=1, bias=use_bias, norm=norm)
        )
        # patch to vector operation
        self.patch2vec = nn.Conv2d(cnum * 2, frame_hidden, kernel_size=kernel_size, stride=stride, padding=padding)
        self.f_patch2vec = nn.Conv2d(flow_cnum * 2, flow_hidden, kernel_size=kernel_size, stride=stride,
                                     padding=padding)
        # initialize transformer blocks for frame completion
        n_vecs = 1
        token_size = []
        output_shape = (input_resolution[0] // 4, input_resolution[1] // 4)
        for i, d in enumerate(kernel_size):
            token_nums = int((output_shape[i] + 2 * padding[i] - kernel_size[i]) / stride[i] + 1)
            n_vecs *= token_nums
            token_size.append(token_nums)
        # Add positional embedding to the encoded features
        if self.ape:
            self.add_pos_emb = AddPosEmb(token_size[0], token_size[1], frame_hidden, frame_hidden)
        self.token_size = token_size
        # initialize transformer blocks
        blocks = []
        t2t_params = {'kernel_size': kernel_size, 'stride': stride, 'padding': padding, 'output_size': output_shape}
        for i in range(numBlocks // 2 - 1):
            layer = TransformerBlock(token_size, frame_hidden, flow_hidden, num_heads, t_groupSize, s_windowSize,
                                     g_downSize, mlp_ratio, drop, n_vecs, t2t_params)
            blocks.append(layer)
        self.first_t_transformer = TemporalTransformer(token_size, frame_hidden, num_heads, t_groupSize, mlp_ratio,
                                                       drop, n_vecs, t2t_params)
        self.first_s_transformer = SpatialTransformer(token_size, frame_hidden, flow_hidden, num_heads, s_windowSize,
                                                      g_downSize, mlp_ratio, drop, n_vecs, t2t_params)
        self.transformer = nn.Sequential(*blocks)
        # vector to patch operation
        self.vec2patch = Vec2Patch(cnum * 2, frame_hidden, output_shape, kernel_size, stride, padding)
        # decoder
        self.decoder = Decoder(conv_type, cnum * 2, 3, use_bias, norm)
        if init_weights:
            self.init_weights()
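
    # forward: (masked frames [+ masks]) and flows -> encoder features ->
    # overlapping-patch tokens -> temporal / spatial transformer stack ->
    # fold back to a feature map -> residual add -> decode to frames in [-1, 1].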
    def forward(self, masked_frames, flows, masks):
        b, t, c, h, w = masked_frames.shape
        cf = flows.shape[2]
        output_shape = (h // 4, w // 4)
        if self.passmask:
            inputs = torch.cat((masked_frames, masks), dim=2)
        else:
            inputs = masked_frames
        inputs = inputs.view(b * t, self.in_channels, h, w)
        flows = flows.view(b * t, cf, h, w)
        enc_feats = self.frame_endoder(inputs)
        flow_feats = self.flow_encoder(flows)
        trans_feat = self.patch2vec(enc_feats)
        flow_patches = self.f_patch2vec(flow_feats)
        _, c, h, w = trans_feat.shape
        cf = flow_patches.shape[1]
        if h != self.token_size[0] or w != self.token_size[1]:
            new_h, new_w = h, w
        else:
            new_h, new_w = 0, 0
            output_shape = (0, 0)
        trans_feat = trans_feat.view(b * t, c, -1).permute(0, 2, 1)
        flow_patches = flow_patches.view(b * t, cf, -1).permute(0, 2, 1)
        trans_feat = self.first_t_transformer(trans_feat, t, new_h, new_w, output_shape)
        if self.ape:  # positional embedding is only constructed when ape is enabled
            trans_feat = self.add_pos_emb(trans_feat, new_h, new_w)
        trans_feat = self.first_s_transformer(trans_feat, flow_patches, t, new_h, new_w, output_shape)
        inputs_trans_feat = {'x': trans_feat, 'f': flow_patches, 't': t, 'h': new_h, 'w': new_w,
                             'output_size': output_shape}
        trans_feat = self.transformer(inputs_trans_feat)['x']
        trans_feat = self.vec2patch(trans_feat, output_shape[0], output_shape[1])
        enc_feats = enc_feats + trans_feat
        output = self.decoder(enc_feats)
        output = torch.tanh(output)
        return output
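

# --------------------------------------------------------------------------- #
# Hypothetical usage sketch (not part of the original file). The keys below
# mirror the config lookups in Model.__init__; the values are illustrative
# guesses, not the project's actual configuration, which comes from its own
# config file.
# --------------------------------------------------------------------------- #
# config = {
#     'tw': 2, 'sw': 8, 'gd': 4,                        # temporal group / spatial window / global downsample sizes (guesses)
#     'input_resolution': (240, 432),                   # (H, W), divisible by 4
#     'in_channel': 4,                                  # RGB + mask when PASSMASK is True
#     'cnum': 64, 'flow_inChannel': 2, 'flow_cnum': 64,
#     'frame_hidden': 512, 'flow_hidden': 256,
#     'PASSMASK': True, 'numBlocks': 8,
#     'kernel_size': (7, 7), 'stride': (3, 3), 'padding': (3, 3),
#     'num_head': 4, 'conv_type': 'vanilla', 'norm': None,
#     'use_bias': True, 'ape': True,
#     'mlp_ratio': 40, 'drop': 0.0, 'init_weights': True,
# }
# model = Model(config)
# frames = torch.zeros(1, 5, 3, 240, 432)   # (B, T, C, H, W) masked frames
# flows = torch.zeros(1, 5, 2, 240, 432)    # completed optical flows
# masks = torch.zeros(1, 5, 1, 240, 432)    # binary inpainting masks
# out = model(frames, flows, masks)         # (B*T, 3, 240, 432), values in [-1, 1]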