|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
from torch.nn.utils.parametrize import register_parametrization |
|
from torchcomp import ms2coef, coef2ms, db2amp, amp2db |
|
from torchaudio.transforms import Spectrogram, InverseSpectrogram |
|
|
|
from typing import List, Tuple, Union, Any, Optional, Callable |
|
import math |
|
from torch_fftconv import fft_conv1d |
|
from functools import reduce |
|
|
|
from .functional import ( |
|
compressor_expander, |
|
lowpass_biquad, |
|
highpass_biquad, |
|
equalizer_biquad, |
|
lowshelf_biquad, |
|
highshelf_biquad, |
|
lowpass_biquad_coef, |
|
highpass_biquad_coef, |
|
highshelf_biquad_coef, |
|
lowshelf_biquad_coef, |
|
equalizer_biquad_coef, |
|
) |
|
from .utils import chain_functions |
|
|
|
|
|
class Clip(nn.Module):
    """Bound a tensor from below and/or above.

    Args:
        max: Upper bound, or ``None`` to leave values unbounded above.
        min: Lower bound, or ``None`` to leave values unbounded below.
    """

    def __init__(self, max: Optional[float] = None, min: Optional[float] = None):
        super().__init__()
        self.max = max
        self.min = min

    def forward(self, x):
        lower, upper = self.min, self.max
        # The lower bound is applied first, then the upper bound.
        if lower is not None:
            x = torch.clamp(x, min=lower)
        if upper is not None:
            x = torch.clamp(x, max=upper)
        return x
|
|
|
|
|
def clip_delay_eq_Q(m: nn.Module, Q: float):
    """Cap the ``Q`` of a delay's low-pass EQ at ``Q``.

    A ``Clip(max=Q)`` parametrization is registered on ``m.eq.params["Q"]``
    when ``m`` is a ``Delay`` whose ``eq`` is a ``LowPass``; any other module
    is left untouched.  ``m`` is returned in both cases.
    """
    is_target = isinstance(m, Delay) and isinstance(m.eq, LowPass)
    if not is_target:
        return m

    register_parametrization(m.eq.params, "Q", Clip(max=Q))
    return m
|
|
|
|
|
def float2param(x):
    """Wrap ``x`` in an ``nn.Parameter``.

    Tensors are wrapped as they are; any other value is first converted to a
    ``float32`` tensor.
    """
    if isinstance(x, torch.Tensor):
        return nn.Parameter(x)
    return nn.Parameter(torch.tensor(x, dtype=torch.float32))
|
|
|
# Square root of two: the divisor of the sum/difference transform and the
# multiplier of the panning gains defined further down in this file.
STEREO_NORM = math.sqrt(2.0)
|
|
|
|
|
def broadcast2stereo(m, args):
    """Forward pre-hook that widens a single-channel input to two channels.

    Only the first positional input is inspected and returned: when its
    dimension 1 has size 1 it is expanded to size 2, otherwise it is passed
    through unchanged.  ``m`` (the hooked module) is not used.
    """
    signal, *_ = args
    if signal.shape[1] != 1:
        return signal
    return signal.expand(-1, 2, -1)
|
|
|
|
|
def hadamard(x):
    """Return the scaled sum and difference of the channels of ``x``.

    Output channel 0 is the sum over dimension 1 and channel 1 is
    ``x[:, 0] - x[:, 1]``; both are divided by ``STEREO_NORM``.
    """
    total = x.sum(1)
    difference = x[:, 0] - x[:, 1]
    return torch.stack([total, difference], 1) / STEREO_NORM
|
|
|
|
|
class Hadamard(nn.Module):
    """Module wrapper around the ``hadamard`` function."""

    def forward(self, x):
        transformed = hadamard(x)
        return transformed
|
|
|
|
|
class FX(nn.Module):
    """Base class for effects whose parameters live in ``self.params``.

    Every keyword argument is passed through ``float2param`` and stored under
    its own name in an ``nn.ParameterDict``.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__()

        wrapped = {}
        for name, value in kwargs.items():
            wrapped[name] = float2param(value)
        self.params = nn.ParameterDict(wrapped)

    def toJSON(self) -> dict[str, Any]:
        """Return every single-element parameter as a plain Python number."""
        scalars = {}
        for name, value in self.params.items():
            if value.numel() == 1:
                scalars[name] = value.item()
        return scalars
|
|
|
|
|
class SmoothingCoef(nn.Module):
    """Parametrization that maps an unconstrained tensor through a sigmoid."""

    def forward(self, x):
        return torch.sigmoid(x)

    def right_inverse(self, y):
        """Return ``log(y / (1 - y))``, the inverse of the sigmoid."""
        odds = y / (1 - y)
        return torch.log(odds)
|
|
|
|
|
class CompRatio(nn.Module):
    """Parametrization computing ``exp(x) + 1``."""

    def forward(self, x):
        return torch.exp(x) + 1

    def right_inverse(self, y):
        """Return ``log(y - 1)``, the inverse of :meth:`forward`."""
        shifted = y - 1
        return shifted.log()
|
|
|
|
|
class MinMax(nn.Module):
    """Parametrization that rescales a sigmoid to lie between ``min`` and ``max``.

    A bound given as a tensor is registered as a non-persistent buffer; any
    other bound is stored as a plain attribute.
    """

    def __init__(self, min=0.0, max: Union[float, torch.Tensor] = 1.0):
        super().__init__()
        for bound_name, bound in (("min", min), ("max", max)):
            if isinstance(bound, torch.Tensor):
                self.register_buffer(bound_name, bound, persistent=False)
            else:
                setattr(self, bound_name, bound)

        self._m = SmoothingCoef()

    def forward(self, x):
        span = self.max - self.min
        return self._m(x) * span + self.min

    def right_inverse(self, y):
        """Map a value between the bounds back to the unconstrained domain."""
        unit = (y - self.min) / (self.max - self.min)
        return self._m.right_inverse(unit)
|
|
|
|
|
class WrappedPositive(nn.Module):
    """Parametrization returning the absolute value modulo ``period``."""

    def __init__(self, period):
        super().__init__()
        self.period = period

    def forward(self, x):
        magnitude = torch.abs(x)
        return torch.remainder(magnitude, self.period)

    def right_inverse(self, y):
        # The value is stored exactly as given.
        return y
|
|
|
|
|
class CompressorExpander(FX):
    """Learnable compressor/expander built on ``compressor_expander``.

    Attack and release are given in milliseconds and stored as coefficients
    obtained from ``ms2coef``.  When ``lookahead`` is true, a learnable
    look-ahead (in milliseconds, wrapped into ``[0, max_lookahead)``) is
    realised by convolving the gain with a sinc kernel shifted by that amount.

    Args:
        sr: Sample rate in Hz.
        cmp_ratio: Initial compressor ratio, parametrized between
            ``cmp_ratio_min`` and ``cmp_ratio_max``.
        exp_ratio: Initial expander ratio (sigmoid-parametrized).
        at_ms: Initial attack time in milliseconds.
        rt_ms: Initial release time in milliseconds.
        avg_coef: Initial averaging coefficient (sigmoid-parametrized).
        cmp_th: Initial compressor threshold in dB.
        exp_th: Initial expander threshold in dB.
        make_up: Initial make-up gain in dB.
        delay: Stored on the instance; not used by this class itself.
        lookahead: Whether to add the learnable look-ahead.
        max_lookahead: Period of the look-ahead wrap, in milliseconds.
    """

    cmp_ratio_min: float = 1
    cmp_ratio_max: float = 20

    def __init__(
        self,
        sr: int,
        cmp_ratio: float = 2.0,
        exp_ratio: float = 0.5,
        at_ms: float = 50.0,
        rt_ms: float = 50.0,
        avg_coef: float = 0.3,
        cmp_th: float = -18.0,
        exp_th: float = -54.0,
        make_up: float = 0.0,
        delay: int = 0,
        lookahead: bool = False,
        max_lookahead: float = 15.0,
    ):
        super().__init__(
            cmp_th=cmp_th,
            exp_th=exp_th,
            make_up=make_up,
            avg_coef=avg_coef,
            cmp_ratio=cmp_ratio,
            exp_ratio=exp_ratio,
        )

        self.delay = delay
        self.sr = sr

        for name, time_ms in (("at", at_ms), ("rt", rt_ms)):
            self.params[name] = nn.Parameter(ms2coef(torch.tensor(time_ms), sr))

        if lookahead:
            # The initial look-ahead is one sample, expressed in milliseconds.
            self.params["lookahead"] = nn.Parameter(torch.ones(1) / sr * 1000)
            register_parametrization(
                self.params, "lookahead", WrappedPositive(max_lookahead)
            )
            kernel_size = int(sr * (max_lookahead + 1) * 0.001) + 1
            pad_left = int(sr * 0.001)
            self._pad_size = (pad_left, kernel_size - pad_left - 1)
            self.register_buffer(
                "_arange",
                torch.arange(kernel_size) - pad_left,
                persistent=False,
            )
        self.lookahead = lookahead

        for name in ("at", "rt", "avg_coef"):
            register_parametrization(self.params, name, SmoothingCoef())
        ratio_bounds = MinMax(self.cmp_ratio_min, self.cmp_ratio_max)
        register_parametrization(self.params, "cmp_ratio", ratio_bounds)
        register_parametrization(self.params, "exp_ratio", SmoothingCoef())

    def extra_repr(self) -> str:
        p = self.params
        with torch.no_grad():
            text = (
                f"attack: {coef2ms(p.at, self.sr).item()} (ms)\n"
                f"release: {coef2ms(p.rt, self.sr).item()} (ms)\n"
                f"avg_coef: {p.avg_coef.item()}\n"
                f"compressor_ratio: {p.cmp_ratio.item()}\n"
                f"expander_ratio: {p.exp_ratio.item()}\n"
                f"compressor_threshold: {p.cmp_th.item()} (dB)\n"
                f"expander_threshold: {p.exp_th.item()} (dB)\n"
                f"make_up: {p.make_up.item()} (dB)"
            )
            if self.lookahead:
                text += f"\nlookahead: {p.lookahead.item()} (ms)"
        return text

    def toJSON(self) -> dict[str, Any]:
        """Return the current settings under human-readable keys."""
        p = self.params
        settings = {
            "Attack (ms)": coef2ms(p.at, self.sr).item(),
            "Release (ms)": coef2ms(p.rt, self.sr).item(),
            "Average Coefficient": p.avg_coef.item(),
            "Compressor Ratio": p.cmp_ratio.item(),
            "Expander Ratio": p.exp_ratio.item(),
            "Compressor Threshold (dB)": p.cmp_th.item(),
            "Expander Threshold (dB)": p.exp_th.item(),
            "Make Up (dB)": p.make_up.item(),
        }
        if self.lookahead:
            settings["Lookahead (ms)"] = p.lookahead.item()
        return settings

    def forward(self, x):
        if self.lookahead:
            shift_in_samples = self.params.lookahead * 0.001 * self.sr
            kernel = torch.sinc(self._arange - shift_in_samples)

            def lookahead_func(gain):
                # Flatten to (N, 1, T), pad by replication, filter, restore shape.
                padded = F.pad(
                    gain.view(-1, 1, gain.size(-1)), self._pad_size, mode="replicate"
                )
                return F.conv1d(padded, kernel[None, None, :]).view(*gain.shape)

        else:

            def lookahead_func(gain):
                return gain

        flat_input = x.reshape(-1, x.shape[-1])
        # The look-ahead is consumed above; everything else goes to the kernel.
        settings = {
            name: value for name, value in self.params.items() if name != "lookahead"
        }
        processed = compressor_expander(
            flat_input, lookahead_func=lookahead_func, **settings
        )
        return processed.view(*x.shape)
|
|
|
|
|
class Panning(FX):
    """Learnable stereo panner.

    Args:
        pan: Initial position between -100 and 100.  It is stored as
            ``(pan + 100) / 200`` behind a sigmoid parametrization.
    """

    def __init__(self, pan: float = 0.0):
        assert -100 <= pan <= 100
        # NOTE(review): pan = +/-100 maps to exactly 1 or 0, for which
        # SmoothingCoef.right_inverse (a logit) is infinite -- confirm that
        # the extremes are never requested.
        super().__init__(pan=(pan + 100) / 200)

        register_parametrization(self.params, "pan", SmoothingCoef())

        self.register_forward_pre_hook(broadcast2stereo)

    def extra_repr(self) -> str:
        with torch.no_grad():
            description = f"pan: {self.params.pan.item() * 200 - 100}"
        return description

    def toJSON(self) -> dict[str, Any]:
        """Return the pan position on the -100..100 scale."""
        position = self.params.pan.item() * 200 - 100
        return {
            "Pan": position,
        }

    def forward(self, x: torch.Tensor):
        # The stored value in (0, 1) becomes an angle in (0, pi / 2); its
        # cosine and sine, scaled by STEREO_NORM, are the two channel gains.
        theta = self.params.pan.view(1) * torch.pi * 0.5
        gains = torch.cat([theta.cos(), theta.sin()]).view(2, 1) * STEREO_NORM
        return x * gains
|
|
|
|
|
class StereoWidth(Panning):
    """Combines ``hadamard`` and the inherited panning via ``chain_functions``."""

    def forward(self, x: torch.Tensor):
        # hadamard, the parent's forward and hadamard again are handed to
        # chain_functions; the resulting callable is applied to the input.
        pan_forward = super().forward
        pipeline = chain_functions(hadamard, pan_forward, hadamard)
        return pipeline(x)
|
|
|
|
|
class ImpulseResponse(nn.Module):
    """Prepend a one to the last dimension of the input."""

    def forward(self, h):
        leading_one = torch.ones_like(h[..., :1])
        return torch.cat((leading_one, h), dim=-1)
|
|
|
|
|
class FIR(FX):
    """Learnable per-channel FIR filter with an implicit unit first tap.

    Only the ``length - 1`` trailing taps are stored (initialised to zero).
    The output is the input plus the input delayed by one sample and filtered
    with those taps.

    Args:
        length: Total filter length, including the implicit first tap.
        channels: Number of channels (used as the convolution ``groups``).
        conv_method: ``"direct"`` for ``F.conv1d`` or ``"fft"`` for
            ``fft_conv1d``.

    Raises:
        ValueError: If ``conv_method`` is neither ``"direct"`` nor ``"fft"``.
    """

    def __init__(
        self,
        length: int,
        channels: int = 2,
        conv_method: str = "direct",
    ):
        super().__init__(kernel=torch.zeros(channels, length - 1))
        self._padding = length - 1
        self.channels = channels

        if conv_method == "direct":
            self.conv_func = F.conv1d
        elif conv_method == "fft":
            self.conv_func = fft_conv1d
        else:
            raise ValueError(f"Unknown conv_method: {conv_method}")

        if channels == 2:
            self.register_forward_pre_hook(broadcast2stereo)

    def forward(self, x: torch.Tensor):
        # Dropping the last sample and left-padding with zeros delays the
        # signal so that the first stored tap acts on the previous sample.
        delayed = F.pad(x[..., :-1], (self._padding, 0), "constant", 0)
        taps = self.params.kernel.flip(1).unsqueeze(1)
        return x + self.conv_func(delayed, taps, groups=self.channels)
|
|
|
|
|
class QFactor(nn.Module):
    """Parametrization storing a value as its natural logarithm."""

    def forward(self, x):
        return torch.exp(x)

    def right_inverse(self, y):
        return torch.log(y)
|
|
|
|
|
class LowPass(FX):
    """Biquad low-pass filter with learnable cutoff frequency and Q.

    Args:
        sr: Sample rate in Hz.
        freq: Initial cutoff frequency.
        Q: Initial quality factor.
        min_freq, max_freq: Bounds of the ``MinMax`` parametrization of ``freq``.
        min_Q, max_Q: Bounds of the ``MinMax`` parametrization of ``Q``.
    """

    def __init__(
        self,
        sr: int,
        freq: float = 17500.0,
        Q: float = 0.707,
        min_freq: float = 200.0,
        max_freq: float = 18000,
        min_Q: float = 0.5,
        max_Q: float = 10.0,
    ):
        super().__init__(freq=freq, Q=Q)

        self.sr = sr
        bounds = {"freq": (min_freq, max_freq), "Q": (min_Q, max_Q)}
        for name, (low, high) in bounds.items():
            register_parametrization(self.params, name, MinMax(low, high))

    def forward(self, x):
        cutoff = self.params.freq
        quality = self.params.Q
        return lowpass_biquad(x, sample_rate=self.sr, cutoff_freq=cutoff, Q=quality)

    def extra_repr(self) -> str:
        with torch.no_grad():
            description = (
                f"freq: {self.params.freq.item():.4f}, Q: {self.params.Q.item():.4f}"
            )
        return description

    def toJSON(self) -> dict[str, Any]:
        """Return the current frequency and Q as plain numbers."""
        return {
            "Frequency (Hz)": self.params.freq.item(),
            "Q": self.params.Q.item(),
        }
|
|
|
|
|
class HighPass(LowPass):
    """Biquad high-pass filter; same parameters as ``LowPass``, other defaults."""

    def __init__(
        self,
        *args,
        freq: float = 200.0,
        min_freq: float = 16.0,
        max_freq: float = 5300.0,
        **kwargs,
    ):
        frequency_settings = dict(freq=freq, min_freq=min_freq, max_freq=max_freq)
        super().__init__(*args, **frequency_settings, **kwargs)

    def forward(self, x):
        cutoff = self.params.freq
        quality = self.params.Q
        return highpass_biquad(x, sample_rate=self.sr, cutoff_freq=cutoff, Q=quality)
|
|
|
|
|
class Peak(FX):
    """Biquad peaking equalizer with learnable frequency, Q and gain.

    Args:
        sr: Sample rate in Hz.
        gain: Initial gain in dB (not parametrized).
        freq: Initial centre frequency.
        Q: Initial quality factor.
        min_freq, max_freq: Bounds of the ``MinMax`` parametrization of ``freq``.
        min_Q, max_Q: Bounds of the ``MinMax`` parametrization of ``Q``.
    """

    def __init__(
        self,
        sr: int,
        gain: float = 0.0,
        freq: float = 2000.0,
        Q: float = 0.707,
        min_freq: float = 33.0,
        max_freq: float = 17500.0,
        min_Q: float = 0.2,
        max_Q: float = 20,
    ):
        super().__init__(freq=freq, Q=Q, gain=gain)

        self.sr = sr

        bounds = {"freq": (min_freq, max_freq), "Q": (min_Q, max_Q)}
        for name, (low, high) in bounds.items():
            register_parametrization(self.params, name, MinMax(low, high))

    def forward(self, x):
        p = self.params
        return equalizer_biquad(
            x,
            sample_rate=self.sr,
            center_freq=p.freq,
            Q=p.Q,
            gain=p.gain,
        )

    def extra_repr(self) -> str:
        with torch.no_grad():
            description = f"freq: {self.params.freq.item():.4f}, gain: {self.params.gain.item():.4f}, Q: {self.params.Q.item():.4f}"
        return description

    def toJSON(self) -> dict[str, Any]:
        """Return the current frequency, gain and Q as plain numbers."""
        p = self.params
        return {
            "Frequency (Hz)": p.freq.item(),
            "Gain (dB)": p.gain.item(),
            "Q": p.Q.item(),
        }
|
|
|
|
|
class LowShelf(FX):
    """Biquad low-shelf filter with learnable frequency and gain.

    ``Q`` is a fixed, non-persistent buffer holding 0.707.

    Args:
        sr: Sample rate in Hz.
        gain: Initial gain in dB (not parametrized).
        freq: Initial cutoff frequency.
        min_freq, max_freq: Bounds of the ``MinMax`` parametrization of ``freq``.
    """

    def __init__(
        self,
        sr: int,
        gain: float = 0.0,
        freq: float = 115.0,
        min_freq: float = 30,
        max_freq: float = 200,
    ):
        super().__init__(freq=freq, gain=gain)

        self.sr = sr
        frequency_bounds = MinMax(min_freq, max_freq)
        register_parametrization(self.params, "freq", frequency_bounds)

        self.register_buffer("Q", torch.tensor(0.707), persistent=False)

    def forward(self, x):
        p = self.params
        return lowshelf_biquad(
            x,
            sample_rate=self.sr,
            cutoff_freq=p.freq,
            gain=p.gain,
            Q=self.Q,
        )

    def extra_repr(self) -> str:
        with torch.no_grad():
            description = f"freq: {self.params.freq.item():.4f}, gain: {self.params.gain.item():.4f}"
        return description

    def toJSON(self) -> dict[str, Any]:
        """Return the current frequency and gain as plain numbers."""
        p = self.params
        return {
            "Frequency (Hz)": p.freq.item(),
            "Gain (dB)": p.gain.item(),
        }
|
|
|
|
|
class HighShelf(LowShelf):
    """Biquad high-shelf filter; same parameters as ``LowShelf``, other defaults."""

    def __init__(
        self,
        *args,
        freq: float = 4525,
        min_freq: float = 750,
        max_freq: float = 8300,
        **kwargs,
    ):
        frequency_settings = dict(freq=freq, min_freq=min_freq, max_freq=max_freq)
        super().__init__(*args, **frequency_settings, **kwargs)

    def forward(self, x):
        p = self.params
        return highshelf_biquad(
            x,
            sample_rate=self.sr,
            cutoff_freq=p.freq,
            gain=p.gain,
            Q=self.Q,
        )
|
|
|
|
|
def module2coeffs( |
|
m: Union[LowPass, HighPass, Peak, LowShelf, HighShelf], |
|
) -> Tuple[ |
|
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor |
|
]: |
|
match m: |
|
case LowPass(): |
|
return lowpass_biquad_coef(m.sr, m.params.freq, m.params.Q) |
|
case HighPass(): |
|
return highpass_biquad_coef(m.sr, m.params.freq, m.params.Q) |
|
case Peak(): |
|
return equalizer_biquad_coef(m.sr, m.params.freq, m.params.Q, m.params.gain) |
|
case LowShelf(): |
|
return lowshelf_biquad_coef(m.sr, m.params.freq, m.params.gain, m.Q) |
|
case HighShelf(): |
|
return highshelf_biquad_coef(m.sr, m.params.freq, m.params.gain, m.Q) |
|
case _: |
|
raise ValueError(f"Unknown module: {m}") |
|
|
|
|
|
class AlwaysNegative(nn.Module):
    """Parametrization mapping an unconstrained tensor to negative values.

    ``forward`` computes ``-softplus(x)``; ``right_inverse`` maps a negative
    value back to the unconstrained domain.
    """

    def forward(self, x):
        return -F.softplus(x)

    def right_inverse(self, y):
        """Return ``log(exp(-y) - 1)`` for negative ``y``.

        ``expm1`` is used instead of ``exp(...) - 1`` because the subtraction
        loses most significant digits when ``y`` is close to zero, which is
        how this parametrization is initialised elsewhere in this file
        (values such as -0.0001).
        """
        return torch.log(torch.expm1(-y))
|
|
|
|
|
class Reverb(FX):
    """Reverb that filters every STFT bin of the input along the frame axis.

    Each channel and frequency bin has a complex filter of ``steps`` frames
    whose log-magnitude at frame ``k`` is ``log_mag + log_mag_delta * k``
    (both kept negative by ``AlwaysNegative``) and whose phase is
    ``_noise_angle``, initialised uniformly in ``[0, 2 * pi)``.

    Args:
        ir_len: Length from which the number of filter frames is derived.
        n_fft: FFT size of the forward and inverse spectrogram.
        hop_length: Hop size of the forward and inverse spectrogram.
        downsample_factor: When greater than 1, the magnitude parameters are
            stored on a coarser grid and linearly interpolated in ``forward``.
    """

    def __init__(self, ir_len=60000, n_fft=384, hop_length=192, downsample_factor=1):
        coarse_bins = n_fft // downsample_factor // 2 + 1
        super().__init__(
            log_mag=torch.full((2, coarse_bins), -1.0),
            log_mag_delta=torch.full((2, coarse_bins), -5.0),
        )

        self.steps = (ir_len - n_fft + hop_length - 1) // hop_length
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.downsample_factor = downsample_factor

        random_phase = torch.rand(2, n_fft // 2 + 1, self.steps) * 2 * torch.pi
        self._noise_angle = nn.Parameter(random_phase)

        self.register_buffer(
            "_arange", torch.arange(self.steps, dtype=torch.float32), persistent=False
        )
        self.spec_forward = Spectrogram(n_fft, hop_length=hop_length, power=None)
        self.spec_inverse = InverseSpectrogram(
            n_fft,
            hop_length=hop_length,
        )

        for name in ("log_mag", "log_mag_delta"):
            register_parametrization(self.params, name, AlwaysNegative())

        self.register_forward_pre_hook(broadcast2stereo)

    def forward(self, x):
        spec = self.spec_forward(x)

        log_mag = self.params.log_mag
        log_mag_delta = self.params.log_mag_delta

        if self.downsample_factor > 1:
            # Bring the coarse magnitude curves up to the full number of bins.
            full_bins = self._noise_angle.size(1)

            def upsample(curve):
                return F.interpolate(
                    curve.unsqueeze(0),
                    size=full_bins,
                    align_corners=True,
                    mode="linear",
                ).squeeze(0)

            log_mag = upsample(log_mag)
            log_mag_delta = upsample(log_mag_delta)

        ir_2d = torch.exp(
            log_mag.unsqueeze(-1)
            + log_mag_delta.unsqueeze(-1) * self._arange
            + self._noise_angle * 1j
        )

        # Channels and bins are merged so that each one gets its own filter
        # through a grouped convolution over the (left-padded) frame axis.
        padded_spec = F.pad(spec.flatten(1, 2), (ir_2d.shape[-1] - 1, 0))
        kernels = hadamard(ir_2d.unsqueeze(0)).flatten(1, 2).flip(-1).transpose(0, 1)
        filtered = F.conv1d(
            padded_spec,
            kernels,
            groups=spec.shape[2] * 2,
        ).view(*spec.shape)

        return self.spec_inverse(filtered)
|
|
|
|
|
class Delay(FX):
    """Feedback delay rendered as a finite impulse response.

    The response consists of sinc-interpolated echoes at integer multiples of
    the delay time, echo ``k`` being scaled by ``feedback ** (k - 1)``.  The
    1st, 3rd, ... and the 2nd, 4th, ... echoes are summed into two separate
    filters whose outputs go through ``odd_pan`` and ``even_pan``.

    Args:
        sr: Sample rate in Hz.
        delay: Initial delay time in milliseconds, parametrized between
            ``min_delay`` and ``max_delay``.
        feedback: Initial feedback gain (sigmoid-parametrized).
        gain: Initial output gain (sigmoid-parametrized).
        ir_duration: Impulse-response duration in seconds; at least
            ``max_delay * 0.002`` seconds are always used.
        eq: Optional module applied to the filters.
        recursive_eq: If true (and ``eq`` is given), ``eq`` is applied to the
            first echo and its spectrum is raised to the echo number, instead
            of filtering the summed echoes once.
    """

    min_delay: float = 100
    max_delay: float = 1000

    def __init__(
        self,
        sr: int,
        delay=200.0,
        feedback=0.1,
        gain=0.1,
        ir_duration: float = 2,
        eq: Optional[nn.Module] = None,
        recursive_eq=False,
    ):
        super().__init__(
            delay=delay,
            feedback=feedback,
            gain=gain,
        )
        self.sr = sr
        self.ir_length = int(sr * max(ir_duration, self.max_delay * 0.002))

        delay_bounds = MinMax(self.min_delay, self.max_delay)
        register_parametrization(self.params, "delay", delay_bounds)
        for name in ("feedback", "gain"):
            register_parametrization(self.params, name, SmoothingCoef())

        self.eq = eq
        self.recursive_eq = recursive_eq

        self.register_buffer(
            "_arange", torch.arange(self.ir_length, dtype=torch.float32)
        )

        self.odd_pan = Panning(0)
        self.even_pan = Panning(0)

    def forward(self, x):
        assert x.size(1) == 1, x.size()
        delay_samples = self.sr * self.params.delay * 0.001
        num_echoes = self.ir_length // int(delay_samples.item() + 1)
        echo_numbers = torch.arange(1, num_echoes + 1, device=x.device)
        echo_gains = self.params.feedback ** (echo_numbers - 1)

        if self.recursive_eq and self.eq is not None:
            first_echo = torch.sinc(self._arange - delay_samples)
            H = torch.fft.rfft(self.eq(first_echo))
            # Raise the equalised single-echo spectrum to the echo number.
            exponents = echo_numbers.unsqueeze(-1)
            H_powered = torch.polar(H.abs() ** exponents, H.angle() * exponents)
            echo_filters = torch.fft.irfft(H_powered, n=self.ir_length)
        else:
            echo_delays = delay_samples * echo_numbers
            echo_filters = torch.sinc(self._arange - echo_delays.unsqueeze(-1))

        return self._filter(x, echo_filters * echo_gains.unsqueeze(-1))

    def _filter(self, x: torch.Tensor, decayed_sinc_filters: torch.Tensor):
        """Convolve ``x`` with the summed odd/even echo filters and pan them."""
        odd_sum = torch.sum(decayed_sinc_filters[::2], 0)
        even_sum = torch.sum(decayed_sinc_filters[1::2], 0)
        filters = torch.stack([odd_sum, even_sum])

        if self.eq is not None and not self.recursive_eq:
            filters = self.eq(filters)

        filters = filters * self.params.gain
        padded_x = F.pad(x, (filters.size(-1) - 1, 0))
        conv1d = F.conv1d if x.size(-1) > 44100 * 20 else fft_conv1d

        odd_wet, even_wet = conv1d(
            padded_x,
            filters.flip(-1).unsqueeze(1),
        ).chunk(2, 1)
        return sum([self.odd_pan(odd_wet), self.even_pan(even_wet)])

    def extra_repr(self) -> str:
        p = self.params
        with torch.no_grad():
            description = (
                f"delay: {self.sr * p.delay.item() * 0.001} (samples)\n"
                f"feedback: {p.feedback.item()}\n"
                f"gain: {p.gain.item()}"
            )
        return description

    def toJSON(self) -> dict[str, Any]:
        """Return the settings, with feedback and gain converted to dB."""
        p = self.params
        return {
            "Delay (ms)": p.delay.item(),
            "Feedback (dB)": p.feedback.log10().mul(20).item(),
            "Gain (dB)": p.gain.log10().mul(20).item(),
            "Odd delays": self.odd_pan.toJSON(),
            "Even delays": self.even_pan.toJSON(),
        }
|
|
|
|
|
class SurrogateDelay(Delay):
    """``Delay`` whose training pass uses damped complex exponentials.

    In evaluation mode this behaves exactly like ``Delay``.  In training mode
    the sinc echoes are replaced by the inverse FFT of damped complex
    exponentials (damping ``log_damp``, kept negative), and each batch item
    falls back to the exact ``Delay.forward`` with probability ``dropout``.

    Args:
        dropout: Per-item probability of using the exact delay while training.
        straight_through: If true, the forward value of the surrogate filter
            is replaced by the exact sinc filter while the gradient is still
            taken through the surrogate.
    """

    def __init__(self, *args, dropout=0.5, straight_through=False, **kwargs):
        super().__init__(*args, **kwargs)

        self.dropout = dropout
        self.straight_through = straight_through
        self.log_damp = nn.Parameter(torch.ones(1) * -0.01)
        register_parametrization(self, "log_damp", AlwaysNegative())

    def forward(self, x):
        assert x.size(1) == 1, x.size()
        if not self.training:
            return super().forward(x)

        log_damp = self.log_damp
        delay_samples = self.sr * self.params.delay * 0.001
        num_echoes = self.ir_length // int(delay_samples.item() + 1)
        echo_numbers = torch.arange(1, num_echoes + 1, device=x.device)
        echo_gains = self.params.feedback ** (echo_numbers - 1)
        bins = self._arange[: self.ir_length // 2 + 1]

        if self.recursive_eq and self.eq is not None:
            damped_exp = torch.exp(
                log_damp * bins
                - 1j * delay_samples / self.ir_length * 2 * torch.pi * bins
            )
            soft_filter = torch.fft.irfft(damped_exp, n=self.ir_length)
            if self.straight_through:
                hard_filter = torch.sinc(self._arange - delay_samples)
                soft_filter = soft_filter + (hard_filter - soft_filter).detach()

            H = torch.fft.rfft(self.eq(soft_filter))

            # Raise the equalised single-echo spectrum to the echo number.
            exponents = echo_numbers.unsqueeze(-1)
            H_powered = torch.polar(H.abs() ** exponents, H.angle() * exponents)
            echo_filters = torch.fft.irfft(H_powered, n=self.ir_length)
        else:
            scaled_bins = echo_numbers.unsqueeze(-1) * bins
            damped_exps = torch.exp(
                log_damp * scaled_bins
                - 1j * delay_samples / self.ir_length * 2 * torch.pi * scaled_bins
            )
            echo_filters = torch.fft.irfft(damped_exps, n=self.ir_length)
            if self.straight_through:
                echo_delays = delay_samples * echo_numbers
                hard_filters = torch.sinc(self._arange - echo_delays.unsqueeze(-1))
                echo_filters = echo_filters + (hard_filters - echo_filters).detach()

        decayed_filters = echo_filters * echo_gains.unsqueeze(-1)

        # Items flagged here are processed with the exact delay instead.
        use_exact = torch.rand(x.size(0), device=x.device) < self.dropout
        if not torch.any(use_exact):
            return self._filter(x, decayed_filters)
        elif torch.all(use_exact):
            return super().forward(x)

        out = torch.zeros((x.size(0), 2, x.size(2)), device=x.device)
        out[~use_exact] = self._filter(x[~use_exact], decayed_filters)
        out[use_exact] = super().forward(x[use_exact])
        return out

    def extra_repr(self) -> str:
        with torch.no_grad():
            damping = self.log_damp.exp().item()
            return super().extra_repr() + f"\ndamp: {damping}"
|
|
|
|
|
class FSDelay(FX):
    """Feedback delay whose impulse response is sampled in the frequency domain.

    With ``D = exp(-1j * w * delay)`` and feedback ``g`` (multiplied by the EQ
    response when ``recursive_eq`` is set), the two filters built by
    :meth:`_get_h` are ``D / (1 - g**2 * D**2)`` and
    ``g * D**2 / (1 - g**2 * D**2)``.  Their outputs go through ``odd_pan``
    and ``even_pan`` respectively.

    Args:
        sr: Sample rate in Hz.
        delay: Initial delay time in milliseconds, parametrized between
            ``Delay.min_delay`` and ``Delay.max_delay``.
        feedback: Initial feedback gain, parametrized between 0 and a maximum
            derived from ``ir_duration`` and ``Delay.max_delay``.
        gain: Initial output gain (sigmoid-parametrized).
        ir_duration: Impulse-response duration in seconds; at least
            ``Delay.max_delay * 0.002`` seconds are always used.
        eq: Optional EQ module.
        recursive_eq: If true (and ``eq`` is given), the EQ response is folded
            into the feedback path; otherwise ``eq`` filters the finished
            impulse responses.
    """

    def __init__(
        self,
        sr: int,
        delay=200.0,
        feedback=0.1,
        gain=0.1,
        ir_duration: float = 6,
        eq: Optional[LowPass] = None,
        recursive_eq=False,
    ):
        super().__init__(
            delay=delay,
            feedback=feedback,
            gain=gain,
        )
        self.sr = sr
        self.ir_length = int(sr * max(ir_duration, Delay.max_delay * 0.002))

        register_parametrization(
            self.params, "delay", MinMax(Delay.min_delay, Delay.max_delay)
        )
        register_parametrization(self.params, "gain", SmoothingCoef())

        # Upper bound of the feedback: the gain of a loop of the maximum delay
        # length that loses 60 dB within 75 % of the impulse-response duration.
        T_60 = ir_duration * 0.75
        max_delay_in_samples = sr * Delay.max_delay * 0.001
        maximum_decay = db2amp(torch.tensor(-60 / sr / T_60 * max_delay_in_samples))
        register_parametrization(self.params, "feedback", MinMax(0, maximum_decay))

        self.eq = eq
        self.recursive_eq = recursive_eq

        self.odd_pan = Panning(0)
        self.even_pan = Panning(0)

        self.register_buffer(
            "_arange", torch.arange(self.ir_length, dtype=torch.float32)
        )

    def _get_h(self):
        """Return the two impulse responses (odd and even echoes), gain applied."""
        freqs = self._arange[: self.ir_length // 2 + 1] / self.ir_length * 2 * torch.pi
        delay_in_samples = self.sr * self.params.delay * 0.001
        feedback = self.params.feedback

        Dinv = torch.exp(1j * freqs * delay_in_samples)
        Dinv2 = torch.exp(2j * freqs * delay_in_samples)
        if self.recursive_eq and self.eq is not None:
            b0, b1, b2, a0, a1, a2 = module2coeffs(self.eq)
            z_inv = torch.exp(-1j * freqs)
            z_inv2 = torch.exp(-2j * freqs)
            eq_H = (b0 + b1 * z_inv + b2 * z_inv2) / (a0 + a1 * z_inv + a2 * z_inv2)
            damp = eq_H * feedback
            det = Dinv2 - damp * damp
        else:
            # Broadcast the feedback tensor rather than passing it to
            # torch.full_like: fill_value is documented as a number, so a
            # tensor there is at best read out as a constant (no gradient) and
            # may be rejected outright when it requires grad.  The multiply
            # yields the same complex values and stays on the autograd graph.
            damp = feedback * torch.ones_like(Dinv)
            det = Dinv2 - feedback.square()
        inv_Dinv_m_A = torch.stack([Dinv, damp], 0) / det
        h = torch.fft.irfft(inv_Dinv_m_A, n=self.ir_length) * self.params.gain

        if self.eq is not None and not self.recursive_eq:
            h = self.eq(h)
        return h

    def forward(self, x):
        assert x.size(1) == 1, x.size()
        h = self._get_h()

        padded_x = F.pad(x, (h.size(-1) - 1, 0))
        conv1d = F.conv1d if x.size(-1) > 44100 * 20 else fft_conv1d
        return sum(
            [
                panner(s)
                for panner, s in zip(
                    [self.odd_pan, self.even_pan],
                    conv1d(
                        padded_x,
                        h.flip(-1).unsqueeze(1),
                    ).chunk(2, 1),
                )
            ]
        )

    def extra_repr(self) -> str:
        with torch.no_grad():
            s = (
                f"delay: {self.sr * self.params.delay.item() * 0.001} (samples)\n"
                f"feedback: {self.params.feedback.item()}\n"
                f"gain: {self.params.gain.item()}"
            )
        return s
|
|
|
|
|
class FSSurrogateDelay(FSDelay):
    """``FSDelay`` whose training pass uses a damped delay operator.

    In evaluation mode the parent's ``_get_h`` is used unchanged.  In training
    mode the delay term ``exp(-1j * w * delay)`` is multiplied by
    ``exp(log_damp * k)`` for frequency bin ``k`` (``log_damp`` is kept
    negative).

    Args:
        straight_through: If true, the forward value of the impulse response
            is the undamped one while the gradient is still taken through the
            damped one.
    """

    def __init__(self, *args, straight_through=False, **kwargs):
        super().__init__(*args, **kwargs)

        self.straight_through = straight_through
        self.log_damp = nn.Parameter(torch.ones(1) * -0.0001)
        register_parametrization(self, "log_damp", AlwaysNegative())

    def _get_h(self):
        """Return the two impulse responses (odd and even echoes), gain applied."""
        if not self.training:
            return super()._get_h()

        log_damp = self.log_damp
        delay_in_samples = self.sr * self.params.delay * 0.001

        exp_factor = self._arange[: self.ir_length // 2 + 1]
        freqs = exp_factor / self.ir_length * 2 * torch.pi
        D = torch.exp(log_damp * exp_factor - 1j * delay_in_samples * freqs)
        D2 = torch.exp(log_damp * exp_factor * 2 - 2j * delay_in_samples * freqs)

        if self.straight_through:
            # Carry the damped and the undamped operator side by side
            # (index 0 / 1 of the new leading axis); they are separated again
            # after the inverse FFT.
            D_orig = torch.exp(-1j * delay_in_samples * freqs)
            D2_orig = torch.exp(-2j * delay_in_samples * freqs)
            D = torch.stack([D, D_orig], 0)
            D2 = torch.stack([D2, D2_orig], 0)

        if self.recursive_eq and self.eq is not None:
            b0, b1, b2, a0, a1, a2 = module2coeffs(self.eq)
            z_inv = torch.exp(-1j * freqs)
            z_inv2 = torch.exp(-2j * freqs)
            eq_H = (b0 + b1 * z_inv + b2 * z_inv2) / (a0 + a1 * z_inv + a2 * z_inv2)
            damp = eq_H * self.params.feedback
            odd_H = D / (1 - damp * damp * D2)
            even_H = odd_H * D * damp
        else:
            # The feedback tensor is used directly here.  An unused
            # torch.full_like(D, feedback) temporary was removed: it had no
            # effect on the result and hands a tensor to an argument that is
            # documented as a number.
            feedback = self.params.feedback
            odd_H = D / (1 - feedback.square() * D2)
            even_H = odd_H * D * feedback

        inv_Dinv_m_A = torch.stack([odd_H, even_H], 0)
        h = torch.fft.irfft(inv_Dinv_m_A, n=self.ir_length)

        if self.straight_through:
            damped_h, orig_h = h.unbind(1)
            h = damped_h + (orig_h - damped_h).detach()

        if self.eq is not None and not self.recursive_eq:
            h = self.eq(h)
        return h * self.params.gain

    def extra_repr(self) -> str:
        with torch.no_grad():
            return super().extra_repr() + f"\ndamp: {self.log_damp.exp().item()}"
|
|
|
|
|
class SendFXsAndSum(FX):
    """Run several effects on one input and sum their outputs.

    With ``cross_send`` enabled, the output of effect ``i`` is additionally
    mixed, scaled by the learnable gains ``sends_i``, into the inputs of all
    later effects.  ``forward`` returns ``(direct, wet)`` where ``direct`` is
    the (optionally panned) input and ``wet`` the summed effect outputs.

    Args:
        *args: The effect modules, in processing order.
        cross_send: Whether earlier effects feed later ones.
        pan_direct: Whether to pass the direct signal through a ``Panning``.
    """

    def __init__(self, *args, cross_send=True, pan_direct=False):
        num_effects = len(args)
        send_init = {}
        if cross_send:
            for i in range(num_effects - 1):
                # One gain per later effect.
                send_init[f"sends_{i}"] = torch.full([num_effects - i - 1], 0.01)
        super().__init__(**send_init)

        self.effects = nn.ModuleList(args)
        if pan_direct:
            self.pan = Panning()

        for name in send_init:
            register_parametrization(self.params, name, SmoothingCoef())

    def forward(self, x):
        direct = self.pan(x) if hasattr(self, "pan") else x

        if len(self.params) == 0:
            # No cross sends: every effect sees the same input.  Outputs are
            # cropped to their common length before being added.
            def add_cropped(acc, new):
                return acc[..., : new.shape[-1]] + new[..., : acc.shape[-1]]

            outputs = map(lambda effect: effect(x), self.effects)
            return direct, reduce(add_cropped, outputs)

        num_effects = len(self.effects)
        gains_per_effect = [self.params[f"sends_{i}"] for i in range(num_effects - 1)]
        gains_per_effect = gains_per_effect + [None]

        wet = torch.zeros_like(x)
        # pending[k] is the input of the k-th effect that has not run yet.
        pending = x.unsqueeze(0).expand(num_effects, -1, -1, -1)
        for effect, send_gains in zip(self.effects, gains_per_effect):
            h = effect(pending[0])
            wet = wet[..., : h.shape[-1]] + h[..., : wet.shape[-1]]
            if pending.size(0) == 1:
                pending = None
            else:
                pending = (
                    pending[1:, ..., : h.shape[-1]]
                    + send_gains[:, None, None, None] * h[..., : pending.shape[-1]]
                )

        return direct, wet
|
|
|
|
|
class UniLossLess(nn.Module):
    """Parametrization returning the matrix exponential of a skew-symmetric matrix.

    Only the strictly upper triangle of the input is used.
    """

    def forward(self, x):
        upper = torch.triu(x, diagonal=1)
        skew = upper - upper.T
        return torch.linalg.matrix_exp(skew)
|
|
|
|
|
class FDN(FX):
    """Feedback delay network evaluated in the frequency domain.

    The impulse response is the inverse FFT of ``c @ (D^-1 - A)^-1 @ b`` with
    ``D^-1 = diag(exp(1j * w * delays))`` and ``A = U * gamma``, where ``U``
    is the matrix exponential of a skew-symmetric matrix and ``gamma`` holds
    bounded per-delay (optionally per-frequency) attenuations.

    Args:
        sr: Sample rate in Hz.
        ir_duration: Impulse-response duration in seconds.
        delays: Delay-line lengths in samples.
        trainable_delay: If true, the delays become a parameter stored in
            milliseconds and bounded to ``(0, max_delay)``; otherwise they are
            a fixed buffer in samples.
        num_decay_freq: Number of frequency points of ``gamma``; more than one
            is linearly interpolated over the frequency bins.
        delay_independent_decay: If true, one attenuation is shared by all
            delay lines and scaled by ``delays / delays.min()`` as an exponent.
        eq: Optional module applied to the impulse response.
    """

    # Upper bound of a trainable delay, in milliseconds.
    max_delay = 100

    def __init__(
        self,
        sr: int,
        ir_duration: float = 1.0,
        delays=(997, 1153, 1327, 1559, 1801, 2099),
        trainable_delay=False,
        num_decay_freq=1,
        delay_independent_decay=False,
        eq: Optional[nn.Module] = None,
    ):

        num_delays = len(delays)
        super().__init__(
            b=torch.ones(num_delays, 2) / num_delays,
            c=torch.zeros(2, num_delays),
            U=torch.randn(num_delays, num_delays) / num_delays**0.5,
            gamma=torch.rand(
                num_decay_freq, num_delays if not delay_independent_decay else 1
            )
            * 0.2
            + 0.4,
        )

        self.sr = sr
        self.ir_length = int(sr * ir_duration)

        # Largest attenuation per pass through a delay line such that the
        # response still loses 60 dB within 75 % of the impulse-response length.
        T_60 = ir_duration * 0.75
        delays = torch.tensor(delays)
        if delay_independent_decay:
            gamma_max = db2amp(-60 / sr / T_60 * delays.min())
        else:
            gamma_max = db2amp(-60 / sr / T_60 * delays)

        register_parametrization(self.params, "gamma", MinMax(0, gamma_max))
        register_parametrization(self.params, "U", UniLossLess())

        if not trainable_delay:
            self.register_buffer(
                "delays",
                delays,
            )
        else:
            # Samples -> milliseconds, the unit of the MinMax bounds below.
            self.params["delays"] = nn.Parameter(delays / sr * 1000)
            register_parametrization(self.params, "delays", MinMax(0, self.max_delay))

        self.register_forward_pre_hook(broadcast2stereo)

        self.eq = eq

    def _delays_in_samples(self):
        """Return the delay-line lengths in samples for either storage mode."""
        if hasattr(self, "delays"):
            return self.delays
        # Trainable delays are kept in milliseconds (see __init__); the
        # frequency-domain model below needs them in samples.
        return self.params.delays * 0.001 * self.sr

    def forward(self, x):
        conv1d = F.conv1d if x.size(-1) > 44100 * 20 else fft_conv1d

        c = self.params.c + 0j
        b = self.params.b + 0j

        gamma = self.params.gamma
        delays = self._delays_in_samples()

        if gamma.size(0) > 1:
            gamma = F.interpolate(
                gamma.T.unsqueeze(1),
                size=self.ir_length // 2 + 1,
                align_corners=True,
                mode="linear",
            ).transpose(0, 2)

        if gamma.size(2) == 1:
            # One shared attenuation: longer delay lines attenuate
            # proportionally more per pass.
            gamma = gamma ** (delays / delays.min())

        A = self.params.U * gamma

        freqs = (
            torch.arange(self.ir_length // 2 + 1, device=x.device)
            / self.ir_length
            * 2
            * torch.pi
        )
        invD = torch.exp(1j * freqs[:, None] * delays)

        H = c @ torch.linalg.solve(torch.diag_embed(invD) - A, b)

        h = torch.fft.irfft(H.permute(1, 2, 0), n=self.ir_length)

        if self.eq is not None:
            h = self.eq(h)

        return conv1d(
            F.pad(x, (self.ir_length - 1, 0)),
            h.flip(-1),
        )

    def toJSON(self) -> dict[str, Any]:
        """Return decay times per frequency point and an approximate gain.

        The T60 values use the module's own sample rate, and the frequency
        labels run from 0 Hz to half of it.  Each row of ``gamma`` must hold
        a single value (``delay_independent_decay=True``) for ``.item()`` to
        succeed.
        """
        gamma = self.params.gamma
        shortest_delay = self._delays_in_samples().min()
        return {
            "T60 (s)": {
                f"{f:.2f} Hz": g.item()
                for f, g in zip(
                    torch.linspace(0, self.sr / 2, gamma.numel()),
                    -60 * shortest_delay / amp2db(gamma) / self.sr,
                )
            },
            "Gain (dB, approx)": amp2db(
                torch.linalg.norm(self.params.b) * torch.linalg.norm(self.params.c)
            ).item(),
        }
|
|