Spaces:
Runtime error
Runtime error
| import pdb | |
| import sys | |
| [sys.path.append(i) for i in ['.', '..']] | |
| sys.path.append("./models/qp_vqvae") | |
| sys.path.append("./models/qp_vqvae/utils") | |
| import numpy as np | |
| import torch as t | |
| import torch.nn as nn | |
| from .qp_vqvae.encdec import Encoder, Decoder, assert_shape | |
| from .qp_vqvae.bottleneck import NoBottleneck, Bottleneck | |
| from .qp_vqvae.utils.logger import average_metrics | |
| from .qp_vqvae.utils.torch_utils import parse_args | |
| import torch.nn.functional as F | |
# Command-line options are parsed once, at import time.
# NOTE(review): importing this module from another entry point therefore parses
# *that* program's argv with this project's parser — confirm that is intended.
args = parse_args()
# Formatting (instead of `'cuda:' + args.gpu`) accepts both str and int GPU ids.
mydevice = t.device(f'cuda:{args.gpu}')
def dont_update(params):
    """Freeze tensors: set ``requires_grad = False`` on each item of ``params``.

    ``params`` may be any iterable, e.g. ``module.parameters()``.
    """
    for tensor in params:
        tensor.requires_grad = False
def update(params):
    """Unfreeze tensors: set ``requires_grad = True`` on each item of ``params``.

    ``params`` may be any iterable, e.g. ``module.parameters()``.
    """
    for tensor in params:
        tensor.requires_grad = True
def calculate_strides(strides, downs):
    """Return ``[strides[i] ** downs[i] for each i]`` as a list.

    Pairs are formed element-wise; if the inputs differ in length, the extra
    entries of the longer one are ignored (same as ``zip``).
    """
    return list(map(pow, strides, downs))
| # def _loss_fn(loss_fn, x_target, x_pred, hps): | |
| # if loss_fn == 'l1': | |
| # return t.mean(t.abs(x_pred - x_target)) / hps.bandwidth['l1'] | |
| # elif loss_fn == 'l2': | |
| # return t.mean((x_pred - x_target) ** 2) / hps.bandwidth['l2'] | |
| # elif loss_fn == 'linf': | |
| # residual = ((x_pred - x_target) ** 2).reshape(x_target.shape[0], -1) | |
| # values, _ = t.topk(residual, hps.linf_k, dim=1) | |
| # return t.mean(values) / hps.bandwidth['l2'] | |
| # elif loss_fn == 'lmix': | |
| # loss = 0.0 | |
| # if hps.lmix_l1: | |
| # loss += hps.lmix_l1 * _loss_fn('l1', x_target, x_pred, hps) | |
| # if hps.lmix_l2: | |
| # loss += hps.lmix_l2 * _loss_fn('l2', x_target, x_pred, hps) | |
| # if hps.lmix_linf: | |
| # loss += hps.lmix_linf * _loss_fn('linf', x_target, x_pred, hps) | |
| # return loss | |
| # else: | |
| # assert False, f"Unknown loss_fn {loss_fn}" | |
| def _loss_fn(x_target, x_pred): | |
| smooth_l1_loss = nn.SmoothL1Loss(reduction='none') | |
| return smooth_l1_loss(x_pred,x_target).mean() | |
| #return t.mean(t.abs(x_pred - x_target)) | |
class VQVAE(nn.Module):
    """Multi-level VQ-VAE over pose sequences.

    Each of ``hps.levels`` levels has its own encoder/decoder pair; all levels
    share one bottleneck (``Bottleneck`` when ``hps.use_bottleneck`` is truthy,
    otherwise ``NoBottleneck``). ``forward`` returns the level-0 reconstruction
    together with the weighted training loss and a dict of detached metrics.
    """

    def __init__(self, hps, input_dim=72):
        """Build encoders, decoders and the bottleneck from ``hps``.

        Args:
            hps: attribute-style config (e.g. an ``EasyDict``). Reads
                ``pose_dims``, ``sample_length``, ``levels``, ``downs_t``,
                ``strides_t``, ``emb_width``, ``l_bins``, ``l_mu``, ``commit``,
                ``hvqvae_multipliers``, ``use_bottleneck``, ``width``,
                ``depth``, ``m_conv``, ``dilation_growth_rate``,
                ``vqvae_reverse_decoder_dilation`` and, optionally,
                ``dilation_cycle``, ``reg``, ``vel`` and ``acc``.
            input_dim: ignored; the channel count always comes from
                ``hps.pose_dims`` (kept for call-site compatibility).
        """
        super().__init__()
        self.hps = hps
        input_dim = hps.pose_dims
        input_shape = (hps.sample_length, input_dim)
        levels = hps.levels
        downs_t = hps.downs_t
        strides_t = hps.strides_t
        emb_width = hps.emb_width
        l_bins = hps.l_bins
        mu = hps.l_mu
        commit = hps.commit
        multipliers = hps.hvqvae_multipliers
        use_bottleneck = hps.use_bottleneck
        if use_bottleneck:
            print('We use bottleneck!')
        else:
            print('We do not use bottleneck!')
        # Side effect: a missing ``dilation_cycle`` is written back onto hps.
        if not hasattr(hps, 'dilation_cycle'):
            hps.dilation_cycle = None
        block_kwargs = dict(
            width=hps.width,
            depth=hps.depth,
            m_conv=hps.m_conv,
            dilation_growth_rate=hps.dilation_growth_rate,
            dilation_cycle=hps.dilation_cycle,
            reverse_decoder_dilation=hps.vqvae_reverse_decoder_dilation,
        )
        self.sample_length = input_shape[0]
        x_shape, x_channels = input_shape[:-1], input_shape[-1]
        self.x_shape = x_shape
        # Per-level downsampling factor; its running product (hop length)
        # fixes the latent length of each level.
        self.downsamples = calculate_strides(strides_t, downs_t)
        self.hop_lengths = np.cumprod(self.downsamples)
        self.z_shapes = [(x_shape[0] // self.hop_lengths[level],) for level in range(levels)]
        self.levels = levels
        if multipliers is None:
            self.multipliers = [1] * levels
        else:
            assert len(multipliers) == levels, "Invalid number of multipliers"
            self.multipliers = multipliers

        def _block_kwargs(level):
            # Width and depth of a level's conv stack scale with its multiplier.
            this_block_kwargs = dict(block_kwargs)
            this_block_kwargs["width"] *= self.multipliers[level]
            this_block_kwargs["depth"] *= self.multipliers[level]
            return this_block_kwargs

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        for level in range(levels):
            # Level ``level`` uses the first ``level + 1`` down/stride stages.
            self.encoders.append(Encoder(x_channels, emb_width, level + 1,
                                         downs_t[:level + 1], strides_t[:level + 1],
                                         **_block_kwargs(level)))
            self.decoders.append(Decoder(x_channels, emb_width, level + 1,
                                         downs_t[:level + 1], strides_t[:level + 1],
                                         **_block_kwargs(level)))
        if use_bottleneck:
            self.bottleneck = Bottleneck(l_bins, emb_width, mu, levels)
        else:
            self.bottleneck = NoBottleneck(levels)
        self.downs_t = downs_t
        self.strides_t = strides_t
        self.l_bins = l_bins
        self.commit = commit
        # Optional loss weights; 0 disables the corresponding term.
        self.reg = getattr(hps, 'reg', 0)
        self.acc = getattr(hps, 'acc', 0)
        self.vel = getattr(hps, 'vel', 0)
        if self.reg == 0:
            print('No motion regularization!')

    def preprocess(self, x):
        """Permute a rank-3 batch from NTC to NCT layout and cast to float32."""
        assert len(x.shape) == 3
        return x.permute(0, 2, 1).float()

    def postprocess(self, x):
        """Permute a batch from NCT back to NTC layout."""
        return x.permute(0, 2, 1)

    def _device(self):
        """Device of the model's parameters (CPU if it has none)."""
        param = next(self.parameters(), None)
        return param.device if param is not None else t.device('cpu')

    def _decode(self, zs, start_level=0, end_level=None):
        """Decode codes for levels ``[start_level, end_level)``.

        Only the lowest requested level (``start_level``) is run through its
        decoder; the result is returned in NTC layout.
        """
        if end_level is None:
            end_level = self.levels
        assert len(zs) == end_level - start_level
        xs_quantised = self.bottleneck.decode(zs, start_level=start_level, end_level=end_level)
        assert len(xs_quantised) == end_level - start_level
        x_out = self.decoders[start_level](xs_quantised[0:1], all_levels=False)
        return self.postprocess(x_out)

    def decode(self, zs, start_level=0, end_level=None, bs_chunks=1):
        """Decode ``zs`` in up to ``bs_chunks`` batch chunks and concatenate.

        ``torch.chunk`` may return fewer than ``bs_chunks`` pieces (e.g. a
        batch of 5 split 4 ways gives 3 chunks), so the loop runs over the
        chunks actually produced rather than ``range(bs_chunks)``.
        """
        z_chunks = [t.chunk(z, bs_chunks, dim=0) for z in zs]
        x_outs = [self._decode(list(zs_i), start_level=start_level, end_level=end_level)
                  for zs_i in zip(*z_chunks)]
        return t.cat(x_outs, dim=0)

    def _encode(self, x, start_level=0, end_level=None):
        """Encode ``x`` at every level and return codes for the requested slice."""
        if end_level is None:
            end_level = self.levels
        x_in = self.preprocess(x)
        xs = [encoder(x_in)[-1] for encoder in self.encoders]
        zs = self.bottleneck.encode(xs)
        return zs[start_level:end_level]

    def encode(self, x, start_level=0, end_level=None, bs_chunks=1):
        """Encode ``x`` in ``bs_chunks`` batch chunks; concatenate codes per level."""
        zs_list = [self._encode(x_i, start_level=start_level, end_level=end_level)
                   for x_i in t.chunk(x, bs_chunks, dim=0)]
        return [t.cat(zs_level_list, dim=0) for zs_level_list in zip(*zs_list)]

    def sample(self, n_samples):
        """Decode ``n_samples`` sequences from uniformly random code indices.

        Codes are drawn on the model's own device (previously the global
        ``mydevice``, which broke when the model lived elsewhere).
        """
        device = self._device()
        zs = [t.randint(0, self.l_bins, size=(n_samples, *z_shape), device=device)
              for z_shape in self.z_shapes]
        return self.decode(zs)

    def forward(self, x):
        """Encode, quantise and decode ``x`` and compute the training loss.

        Args:
            x: rank-3 pose batch in NTC layout (see ``preprocess``).

        Returns:
            dict with
              - ``rec_pose``: level-0 reconstruction in NTC layout,
              - ``loss``: ``recons + commit * commit_loss + reg * regularization
                + vel * velocity_loss + acc * acceleration_loss``,
              - ``metrics``: detached values (per-level/total losses,
                perplexity and the averaged bottleneck metrics),
              - ``embedding_loss``: ``commit_loss * self.commit``.
        """
        metrics = {}
        x_in = self.preprocess(x)
        # The last element of each level's encoder output feeds the bottleneck.
        xs = [encoder(x_in)[-1] for encoder in self.encoders]
        zs, xs_quantised, commit_losses, quantiser_metrics = self.bottleneck(xs)

        x_outs = []
        for level, decoder in enumerate(self.decoders):
            x_out = decoder(xs_quantised[level:level + 1], all_levels=False)
            assert_shape(x_out, x_in.shape)
            x_outs.append(x_out)

        # Accumulators live on the input's device. The previous
        # ``t.zeros(()).cuda()`` used the *current* CUDA device, which fails on
        # CPU and mismatches when the model runs on another GPU (mydevice).
        device = x.device
        recons_loss = t.zeros((), device=device)
        regularization = t.zeros((), device=device)
        velocity_loss = t.zeros((), device=device)
        acceleration_loss = t.zeros((), device=device)
        x_target = x.float()
        for level in reversed(range(self.levels)):
            x_out = self.postprocess(x_outs[level])
            this_recons_loss = _loss_fn(x_target, x_out)
            metrics[f'recons_loss_l{level + 1}'] = this_recons_loss
            recons_loss += this_recons_loss
            # Second finite difference along the time axis (dim 1 in NTC).
            out_acc = x_out[:, 2:] + x_out[:, :-2] - 2 * x_out[:, 1:-1]
            target_acc = x_target[:, 2:] + x_target[:, :-2] - 2 * x_target[:, 1:-1]
            regularization += t.mean(out_acc ** 2)
            velocity_loss += _loss_fn(x_out[:, 1:] - x_out[:, :-1],
                                      x_target[:, 1:] - x_target[:, :-1])
            acceleration_loss += _loss_fn(out_acc, target_acc)
        commit_loss = sum(commit_losses)
        loss = (recons_loss + commit_loss * self.commit + self.reg * regularization
                + self.vel * velocity_loss + self.acc * acceleration_loss)

        # Codebook perplexity of the level-0 codes.
        # NOTE(review): one_hot needs integer indices; assumes zs[0] holds code
        # ids — confirm this also holds for NoBottleneck.
        encodings = F.one_hot(zs[0].reshape(-1), self.l_bins).float()
        avg_probs = t.mean(encodings, dim=0)
        perplexity = t.exp(-t.sum(avg_probs * t.log(avg_probs + 1e-10)))
        with t.no_grad():
            # Smooth-L1 on the level-0 reconstruction, reported as "l1_loss".
            l1_loss = _loss_fn(x_target, x_out)
        quantiser_metrics = average_metrics(quantiser_metrics)
        metrics.update(dict(
            loss=loss,
            recons_loss=recons_loss,
            l1_loss=l1_loss,
            commit_loss=commit_loss,
            regularization=regularization,
            velocity_loss=velocity_loss,
            acceleration_loss=acceleration_loss,
            perplexity=perplexity,
            **quantiser_metrics))
        metrics = {key: val.detach() for key, val in metrics.items()}
        return {
            "rec_pose": x_out,
            "loss": loss,
            "metrics": metrics,
            "embedding_loss": commit_loss * self.commit,
        }
class VQVAE_Encoder(nn.Module):
    """Encoder-side view of the VQ-VAE: ``forward`` returns level-0 codes/latents.

    The module builds the same encoders, decoders and bottleneck as ``VQVAE``
    (so parameter names match), but ``forward`` only encodes and quantises.
    """

    def __init__(self, hps, input_dim=72):
        """Build encoders, decoders and the bottleneck from ``hps``.

        Args:
            hps: attribute-style config; reads the same fields as ``VQVAE``.
            input_dim: ignored; the channel count comes from ``hps.pose_dims``.
        """
        super().__init__()
        self.hps = hps
        input_dim = hps.pose_dims
        input_shape = (hps.sample_length, input_dim)
        levels = hps.levels
        downs_t = hps.downs_t
        strides_t = hps.strides_t
        emb_width = hps.emb_width
        l_bins = hps.l_bins
        mu = hps.l_mu
        commit = hps.commit
        multipliers = hps.hvqvae_multipliers
        use_bottleneck = hps.use_bottleneck
        if use_bottleneck:
            print('We use bottleneck!')
        else:
            print('We do not use bottleneck!')
        # Side effect: a missing ``dilation_cycle`` is written back onto hps.
        if not hasattr(hps, 'dilation_cycle'):
            hps.dilation_cycle = None
        block_kwargs = dict(
            width=hps.width,
            depth=hps.depth,
            m_conv=hps.m_conv,
            dilation_growth_rate=hps.dilation_growth_rate,
            dilation_cycle=hps.dilation_cycle,
            reverse_decoder_dilation=hps.vqvae_reverse_decoder_dilation,
        )
        self.sample_length = input_shape[0]
        x_shape, x_channels = input_shape[:-1], input_shape[-1]
        self.x_shape = x_shape
        self.downsamples = calculate_strides(strides_t, downs_t)
        self.hop_lengths = np.cumprod(self.downsamples)
        self.z_shapes = [(x_shape[0] // self.hop_lengths[level],) for level in range(levels)]
        self.levels = levels
        if multipliers is None:
            self.multipliers = [1] * levels
        else:
            assert len(multipliers) == levels, "Invalid number of multipliers"
            self.multipliers = multipliers

        def _block_kwargs(level):
            # Width and depth of a level's conv stack scale with its multiplier.
            this_block_kwargs = dict(block_kwargs)
            this_block_kwargs["width"] *= self.multipliers[level]
            this_block_kwargs["depth"] *= self.multipliers[level]
            return this_block_kwargs

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        for level in range(levels):
            self.encoders.append(Encoder(x_channels, emb_width, level + 1,
                                         downs_t[:level + 1], strides_t[:level + 1],
                                         **_block_kwargs(level)))
            self.decoders.append(Decoder(x_channels, emb_width, level + 1,
                                         downs_t[:level + 1], strides_t[:level + 1],
                                         **_block_kwargs(level)))
        if use_bottleneck:
            self.bottleneck = Bottleneck(l_bins, emb_width, mu, levels)
        else:
            self.bottleneck = NoBottleneck(levels)
        self.downs_t = downs_t
        self.strides_t = strides_t
        self.l_bins = l_bins
        self.commit = commit
        self.reg = getattr(hps, 'reg', 0)
        self.acc = getattr(hps, 'acc', 0)
        self.vel = getattr(hps, 'vel', 0)
        if self.reg == 0:
            print('No motion regularization!')

    def preprocess(self, x):
        """Permute a rank-3 batch from NTC to NCT layout and cast to float32."""
        assert len(x.shape) == 3
        return x.permute(0, 2, 1).float()

    def postprocess(self, x):
        """Permute a batch from NCT back to NTC layout."""
        return x.permute(0, 2, 1)

    def _device(self):
        """Device of the module's parameters (CPU if it has none)."""
        param = next(self.parameters(), None)
        return param.device if param is not None else t.device('cpu')

    def _decode(self, zs, start_level=0, end_level=None):
        """Decode codes for ``[start_level, end_level)`` with the lowest level's decoder."""
        if end_level is None:
            end_level = self.levels
        assert len(zs) == end_level - start_level
        xs_quantised = self.bottleneck.decode(zs, start_level=start_level, end_level=end_level)
        assert len(xs_quantised) == end_level - start_level
        x_out = self.decoders[start_level](xs_quantised[0:1], all_levels=False)
        return self.postprocess(x_out)

    def decode(self, zs, start_level=0, end_level=None, bs_chunks=1):
        """Decode ``zs`` in up to ``bs_chunks`` batch chunks and concatenate.

        Added so that ``sample`` works: it previously called a ``decode``
        method this class did not define (AttributeError).
        """
        z_chunks = [t.chunk(z, bs_chunks, dim=0) for z in zs]
        x_outs = [self._decode(list(zs_i), start_level=start_level, end_level=end_level)
                  for zs_i in zip(*z_chunks)]
        return t.cat(x_outs, dim=0)

    def sample(self, n_samples):
        """Decode ``n_samples`` sequences from uniformly random code indices."""
        device = self._device()
        zs = [t.randint(0, self.l_bins, size=(n_samples, *z_shape), device=device)
              for z_shape in self.z_shapes]
        return self.decode(zs)

    def forward(self, x):
        """Encode ``x`` and quantise it with the bottleneck.

        Returns:
            ``(zs[0], xs[0], xs_quantised[0])``: the level-0 codes, the level-0
            encoder output before quantisation, and its quantised counterpart.
        """
        x_in = self.preprocess(x)
        xs = [encoder(x_in)[-1] for encoder in self.encoders]
        zs, xs_quantised, _, _ = self.bottleneck(xs)
        return zs[0], xs[0], xs_quantised[0]
class VQVAE_Decoder(nn.Module):
    """Decoder-side view of the VQ-VAE: quantises a latent and decodes it.

    The module builds the same encoders, decoders and bottleneck as ``VQVAE``
    (so parameter names match), but ``forward`` only quantises and decodes.
    """

    def __init__(self, hps, input_dim=72):
        """Build encoders, decoders and the bottleneck from ``hps``.

        Args:
            hps: attribute-style config; reads the same fields as ``VQVAE``.
            input_dim: ignored; the channel count comes from ``hps.pose_dims``.
        """
        super().__init__()
        self.hps = hps
        input_dim = hps.pose_dims
        input_shape = (hps.sample_length, input_dim)
        levels = hps.levels
        downs_t = hps.downs_t
        strides_t = hps.strides_t
        emb_width = hps.emb_width
        l_bins = hps.l_bins
        mu = hps.l_mu
        commit = hps.commit
        multipliers = hps.hvqvae_multipliers
        use_bottleneck = hps.use_bottleneck
        if use_bottleneck:
            print('We use bottleneck!')
        else:
            print('We do not use bottleneck!')
        # Side effect: a missing ``dilation_cycle`` is written back onto hps.
        if not hasattr(hps, 'dilation_cycle'):
            hps.dilation_cycle = None
        block_kwargs = dict(
            width=hps.width,
            depth=hps.depth,
            m_conv=hps.m_conv,
            dilation_growth_rate=hps.dilation_growth_rate,
            dilation_cycle=hps.dilation_cycle,
            reverse_decoder_dilation=hps.vqvae_reverse_decoder_dilation,
        )
        self.sample_length = input_shape[0]
        x_shape, x_channels = input_shape[:-1], input_shape[-1]
        self.x_shape = x_shape
        self.downsamples = calculate_strides(strides_t, downs_t)
        self.hop_lengths = np.cumprod(self.downsamples)
        self.z_shapes = [(x_shape[0] // self.hop_lengths[level],) for level in range(levels)]
        self.levels = levels
        if multipliers is None:
            self.multipliers = [1] * levels
        else:
            assert len(multipliers) == levels, "Invalid number of multipliers"
            self.multipliers = multipliers

        def _block_kwargs(level):
            # Width and depth of a level's conv stack scale with its multiplier.
            this_block_kwargs = dict(block_kwargs)
            this_block_kwargs["width"] *= self.multipliers[level]
            this_block_kwargs["depth"] *= self.multipliers[level]
            return this_block_kwargs

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        for level in range(levels):
            self.encoders.append(Encoder(x_channels, emb_width, level + 1,
                                         downs_t[:level + 1], strides_t[:level + 1],
                                         **_block_kwargs(level)))
            self.decoders.append(Decoder(x_channels, emb_width, level + 1,
                                         downs_t[:level + 1], strides_t[:level + 1],
                                         **_block_kwargs(level)))
        if use_bottleneck:
            self.bottleneck = Bottleneck(l_bins, emb_width, mu, levels)
        else:
            self.bottleneck = NoBottleneck(levels)
        self.downs_t = downs_t
        self.strides_t = strides_t
        self.l_bins = l_bins
        self.commit = commit
        self.reg = getattr(hps, 'reg', 0)
        self.acc = getattr(hps, 'acc', 0)
        self.vel = getattr(hps, 'vel', 0)
        if self.reg == 0:
            print('No motion regularization!')

    def preprocess(self, x):
        """Permute a rank-3 batch from NTC to NCT layout and cast to float32."""
        assert len(x.shape) == 3
        return x.permute(0, 2, 1).float()

    def postprocess(self, x):
        """Permute a batch from NCT back to NTC layout."""
        return x.permute(0, 2, 1)

    def forward(self, xs):
        """Quantise one level-0 latent and decode it back to poses.

        Args:
            xs: a single latent tensor (not a list); it is wrapped in a list
                before being passed to the bottleneck.

        Returns:
            The level-0 reconstruction in NTC layout.
        """
        # The bottleneck is still called for its quantisation (and any
        # internal state updates it performs).
        _, xs_quantised, _, _ = self.bottleneck([xs])
        # Only the level-0 output was ever returned; previously every level's
        # decoder was run and all but level 0 were discarded.
        x_out = self.decoders[0](xs_quantised[0:1], all_levels=False)
        return self.postprocess(x_out)
class Residual_VQVAE(nn.Module):
    """``VQVAE`` variant whose losses weight the first three pose channels.

    Identical architecture to ``VQVAE``; during loss computation both target
    and reconstruction have channels ``[:3]`` multiplied by
    ``hps.root_weight``.
    """

    def __init__(self, hps, input_dim=72):
        """Build encoders, decoders and the bottleneck from ``hps``.

        Args:
            hps: attribute-style config; reads the same fields as ``VQVAE``
                plus ``root_weight``.
            input_dim: ignored; the channel count comes from ``hps.pose_dims``.
        """
        super().__init__()
        self.hps = hps
        input_dim = hps.pose_dims
        input_shape = (hps.sample_length, input_dim)
        levels = hps.levels
        downs_t = hps.downs_t
        strides_t = hps.strides_t
        emb_width = hps.emb_width
        l_bins = hps.l_bins
        mu = hps.l_mu
        commit = hps.commit
        root_weight = hps.root_weight
        multipliers = hps.hvqvae_multipliers
        use_bottleneck = hps.use_bottleneck
        if use_bottleneck:
            print('We use bottleneck!')
        else:
            print('We do not use bottleneck!')
        # Side effect: a missing ``dilation_cycle`` is written back onto hps.
        if not hasattr(hps, 'dilation_cycle'):
            hps.dilation_cycle = None
        block_kwargs = dict(
            width=hps.width,
            depth=hps.depth,
            m_conv=hps.m_conv,
            dilation_growth_rate=hps.dilation_growth_rate,
            dilation_cycle=hps.dilation_cycle,
            reverse_decoder_dilation=hps.vqvae_reverse_decoder_dilation,
        )
        self.sample_length = input_shape[0]
        x_shape, x_channels = input_shape[:-1], input_shape[-1]
        self.x_shape = x_shape
        self.downsamples = calculate_strides(strides_t, downs_t)
        self.hop_lengths = np.cumprod(self.downsamples)
        self.z_shapes = [(x_shape[0] // self.hop_lengths[level],) for level in range(levels)]
        self.levels = levels
        if multipliers is None:
            self.multipliers = [1] * levels
        else:
            assert len(multipliers) == levels, "Invalid number of multipliers"
            self.multipliers = multipliers

        def _block_kwargs(level):
            # Width and depth of a level's conv stack scale with its multiplier.
            this_block_kwargs = dict(block_kwargs)
            this_block_kwargs["width"] *= self.multipliers[level]
            this_block_kwargs["depth"] *= self.multipliers[level]
            return this_block_kwargs

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        for level in range(levels):
            self.encoders.append(Encoder(x_channels, emb_width, level + 1,
                                         downs_t[:level + 1], strides_t[:level + 1],
                                         **_block_kwargs(level)))
            self.decoders.append(Decoder(x_channels, emb_width, level + 1,
                                         downs_t[:level + 1], strides_t[:level + 1],
                                         **_block_kwargs(level)))
        if use_bottleneck:
            self.bottleneck = Bottleneck(l_bins, emb_width, mu, levels)
        else:
            self.bottleneck = NoBottleneck(levels)
        self.downs_t = downs_t
        self.strides_t = strides_t
        self.l_bins = l_bins
        self.commit = commit
        self.root_weight = root_weight
        # Optional loss weights; 0 disables the corresponding term.
        self.reg = getattr(hps, 'reg', 0)
        self.acc = getattr(hps, 'acc', 0)
        self.vel = getattr(hps, 'vel', 0)
        if self.reg == 0:
            print('No motion regularization!')

    def preprocess(self, x):
        """Permute a rank-3 batch from NTC to NCT layout and cast to float32."""
        assert len(x.shape) == 3
        return x.permute(0, 2, 1).float()

    def postprocess(self, x):
        """Permute a batch from NCT back to NTC layout."""
        return x.permute(0, 2, 1)

    def _device(self):
        """Device of the model's parameters (CPU if it has none)."""
        param = next(self.parameters(), None)
        return param.device if param is not None else t.device('cpu')

    def _decode(self, zs, start_level=0, end_level=None):
        """Decode codes for ``[start_level, end_level)`` with the lowest level's decoder."""
        if end_level is None:
            end_level = self.levels
        assert len(zs) == end_level - start_level
        xs_quantised = self.bottleneck.decode(zs, start_level=start_level, end_level=end_level)
        assert len(xs_quantised) == end_level - start_level
        x_out = self.decoders[start_level](xs_quantised[0:1], all_levels=False)
        return self.postprocess(x_out)

    def decode(self, zs, start_level=0, end_level=None, bs_chunks=1):
        """Decode ``zs`` in up to ``bs_chunks`` batch chunks and concatenate.

        Iterates over the chunks ``torch.chunk`` actually returns, which can be
        fewer than ``bs_chunks``.
        """
        z_chunks = [t.chunk(z, bs_chunks, dim=0) for z in zs]
        x_outs = [self._decode(list(zs_i), start_level=start_level, end_level=end_level)
                  for zs_i in zip(*z_chunks)]
        return t.cat(x_outs, dim=0)

    def _encode(self, x, start_level=0, end_level=None):
        """Encode ``x`` at every level and return codes for the requested slice."""
        if end_level is None:
            end_level = self.levels
        x_in = self.preprocess(x)
        xs = [encoder(x_in)[-1] for encoder in self.encoders]
        zs = self.bottleneck.encode(xs)
        return zs[start_level:end_level]

    def encode(self, x, start_level=0, end_level=None, bs_chunks=1):
        """Encode ``x`` in ``bs_chunks`` batch chunks; concatenate codes per level."""
        zs_list = [self._encode(x_i, start_level=start_level, end_level=end_level)
                   for x_i in t.chunk(x, bs_chunks, dim=0)]
        return [t.cat(zs_level_list, dim=0) for zs_level_list in zip(*zs_list)]

    def sample(self, n_samples):
        """Decode ``n_samples`` sequences from uniformly random code indices on the model's device."""
        device = self._device()
        zs = [t.randint(0, self.l_bins, size=(n_samples, *z_shape), device=device)
              for z_shape in self.z_shapes]
        return self.decode(zs)

    def forward(self, x):
        """Encode, quantise and decode ``x`` and compute the root-weighted loss.

        Args:
            x: rank-3 pose batch in NTC layout (see ``preprocess``).

        Returns:
            dict with ``rec_pose`` (unweighted level-0 reconstruction, NTC),
            ``loss`` (total weighted loss), ``metrics`` (detached values) and
            ``embedding_loss`` (``commit_loss * self.commit``).
        """
        metrics = {}
        x_in = self.preprocess(x)
        xs = [encoder(x_in)[-1] for encoder in self.encoders]
        zs, xs_quantised, commit_losses, quantiser_metrics = self.bottleneck(xs)

        x_outs = []
        for level, decoder in enumerate(self.decoders):
            x_out = decoder(xs_quantised[level:level + 1], all_levels=False)
            assert_shape(x_out, x_in.shape)
            x_outs.append(x_out)

        # Accumulators on the input's device (not whatever `.cuda()` picks).
        device = x.device
        recons_loss = t.zeros((), device=device)
        regularization = t.zeros((), device=device)
        velocity_loss = t.zeros((), device=device)
        acceleration_loss = t.zeros((), device=device)
        # Channels [:3] are weighted by root_weight for every loss term. The
        # target is weighted ONCE here; it used to be re-weighted inside the
        # level loop, compounding to root_weight ** k on multi-level models.
        scale_factor = t.ones(self.hps.pose_dims, device=device)
        scale_factor[:3] = self.root_weight
        x_target = x.float() * scale_factor
        rec_pose = None
        for level in reversed(range(self.levels)):
            rec_pose = self.postprocess(x_outs[level])
            x_out = rec_pose * scale_factor
            this_recons_loss = _loss_fn(x_target, x_out)
            metrics[f'recons_loss_l{level + 1}'] = this_recons_loss
            recons_loss += this_recons_loss
            # Second finite difference along the time axis (dim 1 in NTC).
            out_acc = x_out[:, 2:] + x_out[:, :-2] - 2 * x_out[:, 1:-1]
            target_acc = x_target[:, 2:] + x_target[:, :-2] - 2 * x_target[:, 1:-1]
            regularization += t.mean(out_acc ** 2)
            velocity_loss += _loss_fn(x_out[:, 1:] - x_out[:, :-1],
                                      x_target[:, 1:] - x_target[:, :-1])
            acceleration_loss += _loss_fn(out_acc, target_acc)
        commit_loss = sum(commit_losses)
        loss = (recons_loss + commit_loss * self.commit + self.reg * regularization
                + self.vel * velocity_loss + self.acc * acceleration_loss)

        # Codebook perplexity of the level-0 codes.
        # NOTE(review): one_hot needs integer indices; assumes zs[0] holds code
        # ids — confirm this also holds for NoBottleneck.
        encodings = F.one_hot(zs[0].reshape(-1), self.l_bins).float()
        avg_probs = t.mean(encodings, dim=0)
        perplexity = t.exp(-t.sum(avg_probs * t.log(avg_probs + 1e-10)))
        with t.no_grad():
            # Smooth-L1 on the weighted level-0 reconstruction, reported as "l1_loss".
            l1_loss = _loss_fn(x_target, x_out)
        quantiser_metrics = average_metrics(quantiser_metrics)
        metrics.update(dict(
            loss=loss,
            recons_loss=recons_loss,
            l1_loss=l1_loss,
            commit_loss=commit_loss,
            regularization=regularization,
            velocity_loss=velocity_loss,
            acceleration_loss=acceleration_loss,
            perplexity=perplexity,
            **quantiser_metrics))
        metrics = {key: val.detach() for key, val in metrics.items()}
        return {
            # Unweighted reconstruction, consistent with VQVAE (previously the
            # root_weight-scaled tensor was returned).
            "rec_pose": rec_pose,
            "loss": loss,
            "metrics": metrics,
            "embedding_loss": commit_loss * self.commit,
        }
if __name__ == '__main__':
    # Smoke test. Usage (from the original notes):
    #   cd codebook/
    #   python vqvae.py --config=./codebook.yml --train --no_cuda 2 --gpu 2
    # NOTE(review): this module uses package-relative imports
    # (`from .qp_vqvae ...`), so running it as a plain script likely fails;
    # `python -m <package>.<module>` may be required — confirm.
    import yaml
    from pprint import pprint
    from easydict import EasyDict

    with open(args.config) as f:
        config = yaml.safe_load(f)
    # Command-line arguments override values from the YAML file.
    config.update(vars(args))
    pprint(config)
    config = EasyDict(config)

    x = t.rand(32, 40, 15 * 9).to(mydevice)
    model = VQVAE(config.VQVAE, 15 * 9)  # n_joints * n_channels
    # int() rather than eval(): the device ids come straight from the command line.
    model = nn.DataParallel(model, device_ids=[int(i) for i in config.no_cuda])
    model = model.to(mydevice)
    model = model.train()
    # forward() returns a dict; unpacking it into three names raised ValueError.
    out = model(x)
    output, loss, metrics = out['rec_pose'], out['loss'], out['metrics']
    pdb.set_trace()