Hugo Flores Garcia committed · commit 3346920 · parent 3445a71

more sampling fixes
Browse files:
- sample.py +70 -0
- scripts/{utils/vamp_folder.py → exp/experiment.py} +6 -7
- scripts/utils/parallel-gpu.sh +0 -23
- vampnet/modules/transformer.py +23 -34
sample.py ADDED

@@ -0,0 +1,70 @@
+import yaml
+import argbind
+
+import audiotools as at
+
+from vampnet.interface import Interface
+import logging
+
+logger = logging.getLogger()
+logger.setLevel(logging.DEBUG)
+
+Interface = argbind.bind(Interface)
+
+with open("conf/interface/spotdl.yml") as f:
+    conf = yaml.safe_load(f)
+
+
+with argbind.scope(conf):
+    interface = Interface()
+    interface.to("cuda")
+
+loader = at.data.datasets.AudioLoader(sources=[
+    "input.wav",
+])
+
+dataset = at.data.datasets.AudioDataset(
+    loader,
+    sample_rate=interface.codec.sample_rate,
+    duration=interface.coarse.chunk_size_s,
+    n_examples=200,
+    without_replacement=True,
+)
+
+import numpy as np
+def load_random_audio():
+    index = np.random.randint(0, len(dataset))
+    sig = dataset[index]["signal"]
+    sig = interface.preprocess(sig)
+
+    return sig
+
+
+sig = load_random_audio()
+z = interface.encode(sig)
+
+sig.write('input.wav')
+
+from vampnet import mask as pmask
+
+# build the mask
+mask = pmask.linear_random(z, 1.0)
+
+print("coarse")
+zv, mask_z = interface.coarse_vamp(
+    z,
+    mask=mask,
+    sampling_steps=36,
+    temperature=8.0,
+    return_mask=True,
+    typical_filtering=False,
+    # typical_mass=data[typical_mass],
+    # typical_min_tokens=data[typical_min_tokens],
+    gen_fn=interface.coarse.generate,
+)
+
+print("coarse2fine")
+zv = interface.coarse_to_fine(zv, temperature=0.8)
+
+sig = interface.to_signal(zv).cpu()
+sig.write('output-t=8.wav')
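Note on the masking step: pmask.linear_random(z, 1.0) is what drives the vamp above. A minimal sketch of the semantics assumed here (per-position random masking at rate r; the real helper lives in vampnet.mask and may differ in detail):

    import torch

    def linear_random_sketch(z: torch.Tensor, r: float) -> torch.Tensor:
        # z: (batch, n_codebooks, seq_len) codec tokens
        # returns a mask of the same shape; 1 = regenerate, 0 = keep
        return (torch.rand(z.shape) < r).int()

    z = torch.randint(0, 1024, (1, 4, 100))
    mask = linear_random_sketch(z, 1.0)
    assert bool(mask.all())  # r=1.0 masks every token: a full resynthesis

With r=1.0 every coarse token is regenerated, so the script resynthesizes the whole clip; temperature=8.0 over 36 steps favors diverse samples in the coarse pass, and the lower-temperature (0.8) coarse-to-fine pass then fills in detail.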
scripts/{utils/vamp_folder.py → exp/experiment.py} RENAMED

@@ -119,13 +119,15 @@ def beat_mask(ctx_time):
     def wrapper(sig, interface):
         beat_mask = interface.make_beat_mask(
             sig,
-            before_beat_s=ctx_time,
-            after_beat_s=ctx_time,
+            before_beat_s=ctx_time/2,
+            after_beat_s=ctx_time/2,
             invert=True
         )
+
         z = interface.encode(sig)
+
         zv = interface.coarse_vamp(
-            z, beat_mask
+            z, beat_mask
         )

         zv = interface.coarse_to_fine(zv)

@@ -185,9 +187,6 @@ EXP_REGISTRY["sampling-steps"] = {


 EXP_REGISTRY["musical-sampling"] = {
-    "baseline": baseline,
-    "codec": reconstructed,
-    **{f"downsample_{x}x": CoarseCond(4, downsample_factor=x) for x in [16, 32]},
     **{f"beat_mask_{t}": beat_mask(t) for t in [0.075]},
     **{f"inpaint_{t}": inpaint(t) for t in [0.5, 1.0,]},  # multiply these by 2 (they go left and right)
 }

@@ -195,7 +194,7 @@ EXP_REGISTRY["musical-sampling"] = {
 @argbind.bind(without_prefix=True)
 def main(
     sources=[
-        "/media/CHONK/hugo/spotdl/
+        "/media/CHONK/hugo/spotdl/val",
     ],
     output_dir: str = "./samples",
     max_excerpts: int = 2000,
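Two things change under the rename. The musical-sampling registry drops the baseline, codec, and downsample conditions, leaving only the beat-mask and inpaint experiments; and the beat mask now splits ctx_time evenly across the two sides of each beat, so ctx_time budgets the total preserved window per beat rather than the width on each side. A quick arithmetic check under that reading, with the registry's ctx_time of 0.075:

    ctx_time = 0.075
    before_s = after_s = ctx_time / 2
    print(f"preserved per beat: {before_s + after_s:.3f}s")  # 0.075s total per beat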
scripts/utils/parallel-gpu.sh DELETED

@@ -1,23 +0,0 @@
-#!/bin/bash
-
-# Get the command to execute from the user
-command_to_execute="$1"
-
-# Get the maximum number of GPUs to use from the user
-max_gpus="$2"
-
-# Get the number of instances to start per GPU from the user
-instances_per_gpu="$3"
-
-# Set the CUDA_VISIBLE_DEVICES flag for each GPU
-for gpu_id in $(seq 0 $(($max_gpus - 1))); do
-    export CUDA_VISIBLE_DEVICES="$gpu_id"
-    # Start the specified number of instances for this GPU
-    for i in $(seq 1 "$instances_per_gpu"); do
-        # Run the command in the background
-        $command_to_execute &
-    done
-done
-
-# Wait for all instances to finish
-wait
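The deleted launcher fanned one command out across GPUs by exporting CUDA_VISIBLE_DEVICES per worker and backgrounding N instances per device. If the pattern is still needed, a hypothetical Python equivalent (not part of this commit) would be:

    import os
    import shlex
    import subprocess
    import sys

    def launch(command: str, max_gpus: int, instances_per_gpu: int) -> None:
        # pin each worker to one GPU; run several instances per GPU in parallel
        procs = []
        for gpu_id in range(max_gpus):
            env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpu_id))
            for _ in range(instances_per_gpu):
                procs.append(subprocess.Popen(shlex.split(command), env=env))
        for p in procs:
            p.wait()  # wait for all instances to finish

    if __name__ == "__main__":
        launch(sys.argv[1], int(sys.argv[2]), int(sys.argv[3]))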
vampnet/modules/transformer.py CHANGED
@@ -581,7 +581,7 @@ class VampNet(at.ml.BaseModel):
         sampling_steps: int = 24,
         start_tokens: Optional[torch.Tensor] = None,
         mask: Optional[torch.Tensor] = None,
-        temperature: Union[float, Tuple[float, float]] =
+        temperature: Union[float, Tuple[float, float]] = 2.5,
         typical_filtering=False,
         typical_mass=0.2,
         typical_min_tokens=1,
@@ -592,15 +592,7 @@ class VampNet(at.ml.BaseModel):
         #####################
         # resolve temperature #
         #####################
-        if isinstance(temperature, float):
-            temperature = torch.tensor(temperature).repeat(sampling_steps)
-        elif isinstance(temperature, tuple):
-            assert len(temperature) == 2
-            l, h = temperature
-            temperature = torch.linspace(l, h, sampling_steps)
-        else:
-            raise TypeError(f"invalid type for temperature")
-
+        assert isinstance(temperature, float)
         logging.debug(f"temperature: {temperature}")
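The per-step temperature schedule is gone: a (low, high) tuple used to be expanded with torch.linspace across the sampling steps, and now only a plain float survives the assert. A small demonstration of the narrowed contract:

    for temperature in (8.0, (0.8, 2.5)):
        try:
            assert isinstance(temperature, float)
            print(f"{temperature}: accepted")
        except AssertionError:
            print(f"{temperature}: rejected; tuple schedules are no longer resolved")

The annealing a schedule used to provide moves into the mask update further down, where the confidence noise is scaled by temperature * (1 - r).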
@@ -642,10 +634,6 @@ class VampNet(at.ml.BaseModel):
         num_mask_tokens_at_start = (z_masked == self.mask_token).sum()
         logging.debug(f"num mask tokens at start: {num_mask_tokens_at_start}")

-        # our r steps
-        r_steps = torch.linspace(1e-10, 1, sampling_steps+1)[1:].to(self.device)
-        logging.debug(f"r steps: {r_steps}")
-
         # how many codebooks are we inferring vs conditioning on?
         n_infer_codebooks = self.n_codebooks - self.n_conditioning_codebooks
         logging.debug(f"n infer codebooks: {n_infer_codebooks}")
@@ -658,11 +646,13 @@ class VampNet(at.ml.BaseModel):
             logging.debug(f"step {i} of {sampling_steps}")

             # our current temperature
-            tmpt = temperature[i]
-            logging.debug(f"temperature: {tmpt}")
+            logging.debug(f"temperature: {temperature}")

             # our current schedule step
-            r = r_steps[i]
+            r = scalar_to_batch_tensor(
+                (i + 1) / sampling_steps,
+                z.shape[0]
+            ).to(z.device)
             logging.debug(f"r: {r}")

             # get latents
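The precomputed r_steps buffer is replaced by an inline computation; the values are identical, but r is now a per-batch tensor. scalar_to_batch_tensor (a vampnet utility; signature assumed from the call site) just repeats one scalar across the batch. A self-contained check of the schedule equivalence, with a stand-in for the helper:

    import torch

    sampling_steps = 4
    old = torch.linspace(1e-10, 1, sampling_steps + 1)[1:]  # the removed buffer
    new = torch.tensor([(i + 1) / sampling_steps for i in range(sampling_steps)])
    print(torch.allclose(old, new))  # True: both sweep r over (0, 1] evenly

    def scalar_to_batch_tensor_sketch(x, batch_size):
        # stand-in: broadcast one scalar to a (batch_size,) tensor
        return torch.full((batch_size,), x)

    print(scalar_to_batch_tensor_sketch(new[0].item(), 3))  # tensor([0.2500, 0.2500, 0.2500])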
@@ -699,11 +689,18 @@ class VampNet(at.ml.BaseModel):
             probs = rearrange(probs, "(b seq) prob -> b seq prob", b=b)
             logging.debug(f"sampled z with shape: {sampled_z.shape}")

+            # get the confidences: which tokens did we sample?
+            selected_probs = (
+                torch.take_along_dim(
+                    probs, sampled_z.long().unsqueeze(-1),
+                    dim=-1
+                ).squeeze(-1)
+            )

             # flatten z_masked and mask, so we can deal with the sampling logic
             # we'll unflatten them at the end of the loop for the next forward pass
             # remove conditioning codebooks, we'll add them back at the end
-            z_masked = codebook_flatten(z_masked[:, self.n_conditioning_codebooks:, :])
+            z_masked = codebook_flatten(z_masked[:, self.n_conditioning_codebooks:, :])

             mask = (z_masked == self.mask_token).int()
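The confidence gather moves up to this point, plausibly so that it indexes probs with the sampled_z that is still aligned with it; in the old location (removed below) it ran after sampled_z had the conditioning codebooks stitched back in. What the gather computes, on a toy tensor:

    import torch

    # batch=1, seq=2, vocab=3
    probs = torch.tensor([[[0.7, 0.2, 0.1],
                           [0.1, 0.8, 0.1]]])
    sampled_z = torch.tensor([[0, 1]])  # the token ids that were sampled
    selected_probs = torch.take_along_dim(
        probs, sampled_z.long().unsqueeze(-1), dim=-1
    ).squeeze(-1)
    print(selected_probs)  # tensor([[0.7000, 0.8000]]): model confidence per position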
@@ -715,15 +712,6 @@ class VampNet(at.ml.BaseModel):
             )
             logging.debug(f"added z back into sampled z with shape: {sampled_z.shape}")

-
-            # get the confidences: which tokens did we sample?
-            selected_probs = (
-                torch.take_along_dim(
-                    probs, sampled_z.long().unsqueeze(-1),
-                    dim=-1
-                ).squeeze(-1)
-            )
-
             # ignore any tokens that weren't masked
             selected_probs = torch.where(
                 mask.bool(), selected_probs, torch.inf
@@ -733,18 +721,19 @@ class VampNet(at.ml.BaseModel):
             num_to_mask = torch.floor(_gamma(r) * num_mask_tokens_at_start).unsqueeze(1).long()
             logging.debug(f"num to mask: {num_to_mask}")

-            num_to_mask = torch.maximum(
-                torch.tensor(1),
-                torch.minimum(
-                    mask.sum(dim=-1, keepdim=True) - 1,
-                    num_to_mask
-                )
-            )
+            if i != (sampling_steps - 1):
+                num_to_mask = torch.maximum(
+                    torch.tensor(1),
+                    torch.minimum(
+                        mask.sum(dim=-1, keepdim=True) - 1,
+                        num_to_mask
+                    )
+                )


             # get our new mask
             mask = mask_by_random_topk(
-                num_to_mask, selected_probs,
+                num_to_mask, selected_probs, temperature * (1-r)
             )

             # update the mask
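Two behaviors land in this hunk. The clamp on num_to_mask now skips the final step: on every other step it is pinned to [1, currently-masked - 1], so each iteration commits at least one token and leaves at least one for later, while the last step is unclamped so the mask can empty out. And the third argument to mask_by_random_topk anneals the confidence noise as temperature * (1 - r), so early steps explore and late steps trust the model's confidences (assuming the MaskGIT-style noisy top-k that the function name suggests). Tracing the noise scale with the values sample.py passes:

    sampling_steps, temperature = 36, 8.0
    for i in range(sampling_steps):
        r = (i + 1) / sampling_steps
        if i in (0, 17, sampling_steps - 1):
            print(f"step {i:2d}: noise scale = {temperature * (1 - r):.2f}")
    # step  0: noise scale = 7.78
    # step 17: noise scale = 4.00
    # step 35: noise scale = 0.00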