Spaces:
Runtime error
Runtime error
Hugo Flores
commited on
Commit
·
5a0a80a
1
Parent(s):
91f8638
beat tracker bugfixes
Browse files- requirements.txt +2 -1
- vampnet/beats.py +2 -5
- vampnet/interface.py +41 -10
- vampnet/modules/base.py +1 -2
requirements.txt
CHANGED
|
@@ -2,7 +2,8 @@ argbind>=0.3.1
|
|
| 2 |
pytorch-ignite
|
| 3 |
rich
|
| 4 |
audiotools @ git+https://github.com/descriptinc/lyrebird-audiotools.git@hf/backup-info
|
| 5 |
-
lac @ git+https://github.com/descriptinc/lyrebird-audio-codec.git@
|
|
|
|
| 6 |
tqdm
|
| 7 |
tensorboard
|
| 8 |
google-cloud-logging==2.2.0
|
|
|
|
| 2 |
pytorch-ignite
|
| 3 |
rich
|
| 4 |
audiotools @ git+https://github.com/descriptinc/lyrebird-audiotools.git@hf/backup-info
|
| 5 |
+
lac @ git+https://github.com/descriptinc/lyrebird-audio-codec.git@hf/vampnet-temp
|
| 6 |
+
wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat.git
|
| 7 |
tqdm
|
| 8 |
tensorboard
|
| 9 |
google-cloud-logging==2.2.0
|
vampnet/beats.py
CHANGED
|
@@ -200,13 +200,10 @@ class BeatTracker:
|
|
| 200 |
|
| 201 |
|
| 202 |
class WaveBeat(BeatTracker):
|
| 203 |
-
def __init__(self,
|
| 204 |
from wavebeat.dstcn import dsTCNModel
|
| 205 |
|
| 206 |
-
|
| 207 |
-
assert len(ckpts) > 0, f"no checkpoints found for wavebeat in {ckpt_dir}"
|
| 208 |
-
|
| 209 |
-
model = dsTCNModel.load_from_checkpoint(ckpts[-1])
|
| 210 |
model.eval()
|
| 211 |
|
| 212 |
self.device = device
|
|
|
|
| 200 |
|
| 201 |
|
| 202 |
class WaveBeat(BeatTracker):
|
| 203 |
+
def __init__(self, ckpt_path: str = "checkpoints/wavebeat", device: str = "cpu"):
|
| 204 |
from wavebeat.dstcn import dsTCNModel
|
| 205 |
|
| 206 |
+
model = dsTCNModel.load_from_checkpoint(ckpt_path)
|
|
|
|
|
|
|
|
|
|
| 207 |
model.eval()
|
| 208 |
|
| 209 |
self.device = device
|
vampnet/interface.py
CHANGED
|
@@ -3,6 +3,7 @@ from pathlib import Path
|
|
| 3 |
import math
|
| 4 |
|
| 5 |
import torch
|
|
|
|
| 6 |
from audiotools import AudioSignal
|
| 7 |
import tqdm
|
| 8 |
|
|
@@ -50,7 +51,10 @@ class Interface:
|
|
| 50 |
|
| 51 |
def s2t(self, seconds: float):
|
| 52 |
"""seconds to tokens"""
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
def s2t2s(self, seconds: float):
|
| 56 |
"""seconds to tokens to seconds"""
|
|
@@ -94,11 +98,12 @@ class Interface:
|
|
| 94 |
signal: AudioSignal,
|
| 95 |
before_beat_s: float = 0.1,
|
| 96 |
after_beat_s: float = 0.1,
|
| 97 |
-
mask_downbeats:
|
| 98 |
-
mask_upbeats:
|
| 99 |
downbeat_downsample_factor: int = None,
|
| 100 |
beat_downsample_factor: int = None,
|
| 101 |
-
|
|
|
|
| 102 |
):
|
| 103 |
"""make a beat synced mask. that is, make a mask that
|
| 104 |
places 1s at and around the beat, and 0s everywhere else.
|
|
@@ -112,7 +117,9 @@ class Interface:
|
|
| 112 |
beats_z, downbeats_z = self.s2t(beats), self.s2t(downbeats)
|
| 113 |
|
| 114 |
# remove downbeats from beats
|
| 115 |
-
beats_z = beats_z[~torch.isin(beats_z, downbeats_z)]
|
|
|
|
|
|
|
| 116 |
|
| 117 |
# make the mask
|
| 118 |
seq_len = self.s2t(signal.duration)
|
|
@@ -138,16 +145,26 @@ class Interface:
|
|
| 138 |
|
| 139 |
if mask_upbeats:
|
| 140 |
for beat_idx in beats_z:
|
| 141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
if mask_downbeats:
|
| 144 |
for downbeat_idx in downbeats_z:
|
| 145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
if invert:
|
| 148 |
mask = 1 - mask
|
| 149 |
|
| 150 |
-
return mask
|
| 151 |
|
| 152 |
def coarse_to_fine(
|
| 153 |
self,
|
|
@@ -293,6 +310,7 @@ class Interface:
|
|
| 293 |
debug=False,
|
| 294 |
swap_prefix_suffix=False,
|
| 295 |
ext_mask=None,
|
|
|
|
| 296 |
**kwargs
|
| 297 |
):
|
| 298 |
z = self.encode(signal)
|
|
@@ -319,7 +337,8 @@ class Interface:
|
|
| 319 |
|
| 320 |
_cz = cz.clone()
|
| 321 |
cz_mask = None
|
| 322 |
-
|
|
|
|
| 323 |
# add noise
|
| 324 |
cz_masked, cz_mask = self.coarse.add_noise(
|
| 325 |
_cz, r=1.0-intensity,
|
|
@@ -428,8 +447,9 @@ class Interface:
|
|
| 428 |
def variation(
|
| 429 |
self,
|
| 430 |
signal: AudioSignal,
|
| 431 |
-
overlap_hop_ratio: float = 1.0, # TODO: should this be fixed to 1.0? or should we overlap and replace instead of overlap add
|
| 432 |
verbose: bool = False,
|
|
|
|
|
|
|
| 433 |
**kwargs
|
| 434 |
):
|
| 435 |
signal = signal.clone()
|
|
@@ -442,6 +462,9 @@ class Interface:
|
|
| 442 |
math.ceil(signal.duration / self.coarse.chunk_size_s)
|
| 443 |
* self.coarse.chunk_size_s
|
| 444 |
)
|
|
|
|
|
|
|
|
|
|
| 445 |
hop_duration = self.coarse.chunk_size_s * overlap_hop_ratio
|
| 446 |
original_length = signal.length
|
| 447 |
|
|
@@ -460,10 +483,18 @@ class Interface:
|
|
| 460 |
signal.samples[i,...], signal.sample_rate
|
| 461 |
)
|
| 462 |
sig.to(self.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 463 |
out_z = self.coarse_vamp_v2(
|
| 464 |
sig,
|
| 465 |
num_vamps=1,
|
| 466 |
swap_prefix_suffix=False,
|
|
|
|
|
|
|
| 467 |
**kwargs
|
| 468 |
)
|
| 469 |
if self.c2f is not None:
|
|
|
|
| 3 |
import math
|
| 4 |
|
| 5 |
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
from audiotools import AudioSignal
|
| 8 |
import tqdm
|
| 9 |
|
|
|
|
| 51 |
|
| 52 |
def s2t(self, seconds: float):
|
| 53 |
"""seconds to tokens"""
|
| 54 |
+
if isinstance(seconds, np.ndarray):
|
| 55 |
+
return np.ceil(seconds * self.codec.sample_rate / self.codec.hop_length)
|
| 56 |
+
else:
|
| 57 |
+
return math.ceil(seconds * self.codec.sample_rate / self.codec.hop_length)
|
| 58 |
|
| 59 |
def s2t2s(self, seconds: float):
|
| 60 |
"""seconds to tokens to seconds"""
|
|
|
|
| 98 |
signal: AudioSignal,
|
| 99 |
before_beat_s: float = 0.1,
|
| 100 |
after_beat_s: float = 0.1,
|
| 101 |
+
mask_downbeats: bool = True,
|
| 102 |
+
mask_upbeats: bool = True,
|
| 103 |
downbeat_downsample_factor: int = None,
|
| 104 |
beat_downsample_factor: int = None,
|
| 105 |
+
dropout: float = 0.7,
|
| 106 |
+
invert: bool = True,
|
| 107 |
):
|
| 108 |
"""make a beat synced mask. that is, make a mask that
|
| 109 |
places 1s at and around the beat, and 0s everywhere else.
|
|
|
|
| 117 |
beats_z, downbeats_z = self.s2t(beats), self.s2t(downbeats)
|
| 118 |
|
| 119 |
# remove downbeats from beats
|
| 120 |
+
beats_z = torch.tensor(beats_z)[~torch.isin(torch.tensor(beats_z), torch.tensor(downbeats_z))]
|
| 121 |
+
beats_z = beats_z.tolist()
|
| 122 |
+
downbeats_z = downbeats_z.tolist()
|
| 123 |
|
| 124 |
# make the mask
|
| 125 |
seq_len = self.s2t(signal.duration)
|
|
|
|
| 145 |
|
| 146 |
if mask_upbeats:
|
| 147 |
for beat_idx in beats_z:
|
| 148 |
+
_slice = int(beat_idx - mask_b4), int(beat_idx + mask_after)
|
| 149 |
+
num_steps = mask[_slice[0]:_slice[1]].shape[0]
|
| 150 |
+
_m = torch.ones(num_steps, device=self.device)
|
| 151 |
+
_m = torch.nn.functional.dropout(_m, p=dropout)
|
| 152 |
+
|
| 153 |
+
mask[_slice[0]:_slice[1]] = _m
|
| 154 |
|
| 155 |
if mask_downbeats:
|
| 156 |
for downbeat_idx in downbeats_z:
|
| 157 |
+
_slice = int(downbeat_idx - mask_b4), int(downbeat_idx + mask_after)
|
| 158 |
+
num_steps = mask[_slice[0]:_slice[1]].shape[0]
|
| 159 |
+
_m = torch.ones(num_steps, device=self.device)
|
| 160 |
+
_m = torch.nn.functional.dropout(_m, p=dropout)
|
| 161 |
+
|
| 162 |
+
mask[_slice[0]:_slice[1]] = _m
|
| 163 |
|
| 164 |
if invert:
|
| 165 |
mask = 1 - mask
|
| 166 |
|
| 167 |
+
return mask[None, None, :].bool().long()
|
| 168 |
|
| 169 |
def coarse_to_fine(
|
| 170 |
self,
|
|
|
|
| 310 |
debug=False,
|
| 311 |
swap_prefix_suffix=False,
|
| 312 |
ext_mask=None,
|
| 313 |
+
verbose=False,
|
| 314 |
**kwargs
|
| 315 |
):
|
| 316 |
z = self.encode(signal)
|
|
|
|
| 337 |
|
| 338 |
_cz = cz.clone()
|
| 339 |
cz_mask = None
|
| 340 |
+
range_fn = tqdm.trange if verbose else range
|
| 341 |
+
for _ in range_fn(num_vamps):
|
| 342 |
# add noise
|
| 343 |
cz_masked, cz_mask = self.coarse.add_noise(
|
| 344 |
_cz, r=1.0-intensity,
|
|
|
|
| 447 |
def variation(
|
| 448 |
self,
|
| 449 |
signal: AudioSignal,
|
|
|
|
| 450 |
verbose: bool = False,
|
| 451 |
+
beat_mask: bool = False,
|
| 452 |
+
beat_mask_kwargs: dict = {},
|
| 453 |
**kwargs
|
| 454 |
):
|
| 455 |
signal = signal.clone()
|
|
|
|
| 462 |
math.ceil(signal.duration / self.coarse.chunk_size_s)
|
| 463 |
* self.coarse.chunk_size_s
|
| 464 |
)
|
| 465 |
+
# eventually we DO want overlap, but we want overlap-replace not
|
| 466 |
+
# overlap-add
|
| 467 |
+
overlap_hop_ratio = 1.0
|
| 468 |
hop_duration = self.coarse.chunk_size_s * overlap_hop_ratio
|
| 469 |
original_length = signal.length
|
| 470 |
|
|
|
|
| 483 |
signal.samples[i,...], signal.sample_rate
|
| 484 |
)
|
| 485 |
sig.to(self.device)
|
| 486 |
+
|
| 487 |
+
if beat_mask:
|
| 488 |
+
ext_mask = self.make_beat_mask(sig, **beat_mask_kwargs)
|
| 489 |
+
else:
|
| 490 |
+
ext_mask = None
|
| 491 |
+
|
| 492 |
out_z = self.coarse_vamp_v2(
|
| 493 |
sig,
|
| 494 |
num_vamps=1,
|
| 495 |
swap_prefix_suffix=False,
|
| 496 |
+
ext_mask=ext_mask,
|
| 497 |
+
verbose=verbose,
|
| 498 |
**kwargs
|
| 499 |
)
|
| 500 |
if self.c2f is not None:
|
vampnet/modules/base.py
CHANGED
|
@@ -103,8 +103,7 @@ class VampBase(at.ml.BaseModel):
|
|
| 103 |
# add the external mask if we were given one
|
| 104 |
if ext_mask is not None:
|
| 105 |
assert ext_mask.ndim == 3, "mask must be (batch, n_codebooks, seq)"
|
| 106 |
-
|
| 107 |
-
mask = (mask + ext_mask).bool().long()
|
| 108 |
|
| 109 |
x = x * (1 - mask) + random_x * mask
|
| 110 |
return x, mask
|
|
|
|
| 103 |
# add the external mask if we were given one
|
| 104 |
if ext_mask is not None:
|
| 105 |
assert ext_mask.ndim == 3, "mask must be (batch, n_codebooks, seq)"
|
| 106 |
+
mask = (mask * ext_mask).bool().long()
|
|
|
|
| 107 |
|
| 108 |
x = x * (1 - mask) + random_x * mask
|
| 109 |
return x, mask
|