import sys
import time

import numpy as np
import torch
import torch.nn as nn


def get_model_from_config(model_type, config):
    """Instantiate the model described by `config` for the given `model_type`."""
    if model_type == 'mel_band_roformer':
        from models.mel_band_roformer import MelBandRoformer
        model = MelBandRoformer(**dict(config.model))
    else:
        print('Unknown model: {}'.format(model_type))
        model = None
    return model


def get_windowing_array(window_size, fade_size, device):
    """Build a window of ones with linear fade-in/fade-out ramps of `fade_size` samples."""
    fadein = torch.linspace(0, 1, fade_size)
    fadeout = torch.linspace(1, 0, fade_size)
    window = torch.ones(window_size)
    window[-fade_size:] *= fadeout
    window[:fade_size] *= fadein
    return window.to(device)


def demix_track(config, model, mix, device, first_chunk_time=None):
    """Separate `mix` into sources by running the model on overlapping chunks.

    Returns a dict mapping instrument names to separated waveforms, plus the
    measured time of the first chunk (reused for progress estimates on later tracks).
    """
    C = config.inference.chunk_size
    N = config.inference.num_overlap
    step = C // N
    fade_size = C // 10
    border = C - step

    # Reflect-pad both ends so the first and last chunks see full context.
    if mix.shape[1] > 2 * border and border > 0:
        mix = nn.functional.pad(mix, (border, border), mode='reflect')

    windowing_array = get_windowing_array(C, fade_size, device)

    with torch.cuda.amp.autocast():
        with torch.no_grad():
            if config.training.target_instrument is not None:
                req_shape = (1,) + tuple(mix.shape)
            else:
                req_shape = (len(config.training.instruments),) + tuple(mix.shape)

            mix = mix.to(device)
            result = torch.zeros(req_shape, dtype=torch.float32).to(device)
            counter = torch.zeros(req_shape, dtype=torch.float32).to(device)

            i = 0
            total_length = mix.shape[1]
            num_chunks = (total_length + step - 1) // step

            # Time the first chunk of the first track to estimate total processing time.
            if first_chunk_time is None:
                start_time = time.time()
                first_chunk = True
            else:
                start_time = None
                first_chunk = False

            while i < total_length:
                part = mix[:, i:i + C]
                length = part.shape[-1]
                if length < C:
                    # Pad the final, shorter chunk up to the full chunk size.
                    if length > C // 2 + 1:
                        part = nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
                    else:
                        part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)

                if first_chunk and i == 0:
                    chunk_start_time = time.time()

                x = model(part.unsqueeze(0))[0]

                # Cross-fade overlapping chunks; keep the outer edges of the first and
                # last chunks at full weight since they have nothing to blend with.
                window = windowing_array.clone()
                if i == 0:
                    window[:fade_size] = 1
                elif i + C >= total_length:
                    window[-fade_size:] = 1

                result[..., i:i + length] += x[..., :length] * window[..., :length]
                counter[..., i:i + length] += window[..., :length]
                i += step

                if first_chunk and i == step:
                    chunk_time = time.time() - chunk_start_time
                    first_chunk_time = chunk_time
                    estimated_total_time = chunk_time * num_chunks
                    print(f"Estimated total processing time for this track: {estimated_total_time:.2f} seconds")
                    first_chunk = False

                if first_chunk_time is not None and i > step:
                    chunks_processed = i // step
                    time_remaining = first_chunk_time * (num_chunks - chunks_processed)
                    sys.stdout.write(f"\rEstimated time remaining: {time_remaining:.2f} seconds")
                    sys.stdout.flush()

    print()

    # Normalize by the accumulated window weights and clean up any NaNs.
    estimated_sources = result / counter
    estimated_sources = estimated_sources.cpu().numpy()
    np.nan_to_num(estimated_sources, copy=False, nan=0.0)

    # Trim the reflect padding added above, if any.
    if mix.shape[1] > 2 * border and border > 0:
        estimated_sources = estimated_sources[..., border:-border]

    if config.training.target_instrument is None:
        return {k: v for k, v in zip(config.training.instruments, estimated_sources)}, first_chunk_time
    else:
        return {k: v for k, v in zip([config.training.target_instrument], estimated_sources)}, first_chunk_time
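

# Example usage (a minimal sketch, not part of the original module): the config object,
# checkpoint path, and torchaudio-based loading below are assumptions; how the config
# and weights are actually obtained depends on the surrounding project.
#
#   import torchaudio
#   model = get_model_from_config('mel_band_roformer', config)                   # config assumed loaded elsewhere
#   model.load_state_dict(torch.load('checkpoint.ckpt', map_location=device))    # hypothetical checkpoint path
#   model = model.to(device).eval()
#   mix, sr = torchaudio.load('song.wav')                                        # (channels, samples) waveform
#   sources, first_chunk_time = demix_track(config, model, mix, device)
#   # `sources` maps each instrument name to a numpy array of the separated audio.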