Spaces:
Runtime error
Runtime error
feat: add lofi-bytes-api and gradio app
Browse files- app.py +157 -0
- model/loss.py +46 -0
- model/music_transformer.py +200 -0
- model/positional_encoding.py +23 -0
- model/rpr.py +464 -0
- processor.py +266 -0
- requirements.txt +6 -0
- uploaded_midis/am_i_blue_jazz.mid +0 -0
- uploaded_midis/ghibli_castle_in_the_sky.mid +0 -0
- utilities/argument_funcs.py +228 -0
- utilities/constants.py +28 -0
- utilities/device.py +73 -0
- utilities/lr_scheduling.py +65 -0
- utilities/run_model.py +95 -0
app.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
|
| 3 |
+
import gradio as gr
|
| 4 |
+
import torch
|
| 5 |
+
from huggingface_hub import hf_hub_download
|
| 6 |
+
|
| 7 |
+
from model.music_transformer import MusicTransformer
|
| 8 |
+
from processor import decode_midi, encode_midi
|
| 9 |
+
from utilities.constants import TOKEN_END, TOKEN_PAD, TORCH_LABEL_TYPE
|
| 10 |
+
from utilities.device import get_device, use_cuda
|
| 11 |
+
|
| 12 |
+
REPO_ID = "Launchpad/lofi-bytes"
|
| 13 |
+
FILENAME = "weights_maestro_finetuned.pickle"
|
| 14 |
+
|
| 15 |
+
SEQUENCE_START = 0
|
| 16 |
+
OUTPUT_PATH = "./output_midi"
|
| 17 |
+
RPR = True
|
| 18 |
+
# TARGET_SEQ_LENGTH = 1023
|
| 19 |
+
TARGET_SEQ_LENGTH = 512
|
| 20 |
+
NUM_PRIME = 65
|
| 21 |
+
MAX_SEQUENCE = 2048
|
| 22 |
+
N_LAYERS = 6
|
| 23 |
+
NUM_HEADS = 8
|
| 24 |
+
D_MODEL = 512
|
| 25 |
+
DIM_FEEDFORWARD = 1024
|
| 26 |
+
BEAM = 0
|
| 27 |
+
FORCE_CPU = False
|
| 28 |
+
ALLOWED_EXTENSIONS = {'mid'}
|
| 29 |
+
UPLOAD_FOLDER = './uploaded_midis'
|
| 30 |
+
|
| 31 |
+
generated_midi = None
|
| 32 |
+
|
| 33 |
+
use_cuda(True)
|
| 34 |
+
|
| 35 |
+
model = MusicTransformer(
|
| 36 |
+
n_layers=N_LAYERS,
|
| 37 |
+
num_heads=NUM_HEADS,
|
| 38 |
+
d_model=D_MODEL,
|
| 39 |
+
dim_feedforward=DIM_FEEDFORWARD,
|
| 40 |
+
max_sequence=MAX_SEQUENCE,
|
| 41 |
+
rpr=RPR
|
| 42 |
+
).to(get_device())
|
| 43 |
+
|
| 44 |
+
state_dict = torch.load(
|
| 45 |
+
hf_hub_download(repo_id=REPO_ID, filename=FILENAME),
|
| 46 |
+
map_location=get_device()
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
model.load_state_dict(state_dict)
|
| 50 |
+
|
| 51 |
+
def generate(input_midi):
|
| 52 |
+
|
| 53 |
+
raw_mid = encode_midi(input_midi)
|
| 54 |
+
if(len(raw_mid) == 0):
|
| 55 |
+
return
|
| 56 |
+
|
| 57 |
+
primer, _ = process_midi(raw_mid, NUM_PRIME, random_seq=False)
|
| 58 |
+
primer = torch.tensor(primer, dtype=TORCH_LABEL_TYPE, device=get_device())
|
| 59 |
+
|
| 60 |
+
# saves a pretty_midi at file_path
|
| 61 |
+
# decode_midi(primer[:NUM_PRIME].cpu().numpy(), file_path=f_path)
|
| 62 |
+
decode_midi(primer[:NUM_PRIME].cpu().numpy())
|
| 63 |
+
|
| 64 |
+
# GENERATION
|
| 65 |
+
model.eval()
|
| 66 |
+
with torch.set_grad_enabled(False):
|
| 67 |
+
|
| 68 |
+
# NOTE: model.generate() returns a MIDI stored as an ARRAY given a primer
|
| 69 |
+
beam_seq = model.generate(primer[:NUM_PRIME], TARGET_SEQ_LENGTH, beam=BEAM)
|
| 70 |
+
|
| 71 |
+
file_path = "output.mid"
|
| 72 |
+
|
| 73 |
+
# NOTE: function decode_midi() returns an actual MIDI of class pretty_midi.PrettyMIDI
|
| 74 |
+
decoded_midi = decode_midi(beam_seq[0].cpu().numpy(), file_path=file_path)
|
| 75 |
+
|
| 76 |
+
# THIS SHOULD BE EITHER decoded_midi OR beam_seq
|
| 77 |
+
# TODO: decoded_midi is actual pretty_midi MIDI file, beam_seq is just an array representing a MIDI
|
| 78 |
+
# decoded_midi stores more information about instruments and stuff
|
| 79 |
+
return file_path
|
| 80 |
+
|
| 81 |
+
def process_midi(raw_mid, max_seq, random_seq):
|
| 82 |
+
"""
|
| 83 |
+
----------
|
| 84 |
+
Author: Damon Gwinn
|
| 85 |
+
----------
|
| 86 |
+
Takes in pre-processed raw midi and returns the input and target. Can use a random sequence or
|
| 87 |
+
go from the start based on random_seq.
|
| 88 |
+
----------
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
x = torch.full((max_seq, ), TOKEN_PAD, dtype=TORCH_LABEL_TYPE, device=get_device())
|
| 92 |
+
tgt = torch.full((max_seq, ), TOKEN_PAD, dtype=TORCH_LABEL_TYPE, device=get_device())
|
| 93 |
+
|
| 94 |
+
raw_len = len(raw_mid)
|
| 95 |
+
full_seq = max_seq + 1 # Performing seq2seq
|
| 96 |
+
|
| 97 |
+
if(raw_len == 0):
|
| 98 |
+
return x, tgt
|
| 99 |
+
|
| 100 |
+
if(raw_len < full_seq):
|
| 101 |
+
x[:raw_len] = raw_mid
|
| 102 |
+
tgt[:raw_len-1] = raw_mid[1:]
|
| 103 |
+
tgt[raw_len] = TOKEN_END
|
| 104 |
+
else:
|
| 105 |
+
# Randomly selecting a range
|
| 106 |
+
if(random_seq):
|
| 107 |
+
end_range = raw_len - full_seq
|
| 108 |
+
start = random.randint(SEQUENCE_START, end_range)
|
| 109 |
+
|
| 110 |
+
# Always taking from the start to as far as we can
|
| 111 |
+
else:
|
| 112 |
+
start = SEQUENCE_START
|
| 113 |
+
|
| 114 |
+
end = start + full_seq
|
| 115 |
+
|
| 116 |
+
data = raw_mid[start:end]
|
| 117 |
+
|
| 118 |
+
x = data[:max_seq]
|
| 119 |
+
tgt = data[1:full_seq]
|
| 120 |
+
|
| 121 |
+
return x, tgt
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
with gr.Blocks() as demo:
|
| 125 |
+
with gr.Row():
|
| 126 |
+
with gr.Column(scale=1):
|
| 127 |
+
gr.Image(
|
| 128 |
+
"https://www.ocf.berkeley.edu/~launchpad/media/uploads/project_logos/410912267_278779401866686_2517511436172822307_n_0iVwDxI.png",
|
| 129 |
+
elem_id="logo-img",
|
| 130 |
+
show_label=False,
|
| 131 |
+
show_share_button=False,
|
| 132 |
+
show_download_button=False,
|
| 133 |
+
show_fullscreen_button=False,
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
with gr.Column(scale=3):
|
| 137 |
+
gr.Markdown("""lofi-bytes is a [Launchpad](https://launchpad.studentorg.berkeley.edu/) project (Spring 2023) that generates lofi tracks from input MIDI stamples using a MusicTransformer model.
|
| 138 |
+
<br/><br/>
|
| 139 |
+
**Model**: [lofi-bytes](https://huggingface.co/Launchpad/lofi-bytes)
|
| 140 |
+
<br/>
|
| 141 |
+
**Project Leader**: Alicia Wang
|
| 142 |
+
<br/>
|
| 143 |
+
**Members**: Alena Chao, Eric Liu, Zane Mogannam, Chloe Wong, Iris Zhou
|
| 144 |
+
<br/>
|
| 145 |
+
**Advisors**: Vincent Lim, Winston Liu
|
| 146 |
+
<br/>
|
| 147 |
+
"""
|
| 148 |
+
)
|
| 149 |
+
gr.Interface(
|
| 150 |
+
fn=generate,
|
| 151 |
+
inputs=gr.File(),
|
| 152 |
+
outputs=gr.File(),
|
| 153 |
+
examples=["uploaded_midis/ghibli_castle_in_the_sky.mid", "uploaded_midis/am_i_blue_jazz.mid"]
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
if __name__ == '__main__':
|
| 157 |
+
demo.launch(share=True)
|
model/loss.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch.nn.modules.loss import _Loss
|
| 5 |
+
|
| 6 |
+
# Borrowed from https://github.com/jason9693/MusicTransformer-pytorch/blob/5f183374833ff6b7e17f3a24e3594dedd93a5fe5/custom/criterion.py#L28
|
| 7 |
+
class SmoothCrossEntropyLoss(_Loss):
|
| 8 |
+
"""
|
| 9 |
+
https://arxiv.org/abs/1512.00567
|
| 10 |
+
"""
|
| 11 |
+
__constants__ = ['label_smoothing', 'vocab_size', 'ignore_index', 'reduction']
|
| 12 |
+
|
| 13 |
+
def __init__(self, label_smoothing, vocab_size, ignore_index=-100, reduction='mean', is_logits=True):
|
| 14 |
+
assert 0.0 <= label_smoothing <= 1.0
|
| 15 |
+
super().__init__(reduction=reduction)
|
| 16 |
+
|
| 17 |
+
self.label_smoothing = label_smoothing
|
| 18 |
+
self.vocab_size = vocab_size
|
| 19 |
+
self.ignore_index = ignore_index
|
| 20 |
+
self.input_is_logits = is_logits
|
| 21 |
+
|
| 22 |
+
def forward(self, input, target):
|
| 23 |
+
"""
|
| 24 |
+
Args:
|
| 25 |
+
input: [B * T, V]
|
| 26 |
+
target: [B * T]
|
| 27 |
+
Returns:
|
| 28 |
+
cross entropy: [1]
|
| 29 |
+
"""
|
| 30 |
+
mask = (target == self.ignore_index).unsqueeze(-1)
|
| 31 |
+
q = F.one_hot(target.long(), self.vocab_size).type(torch.float32)
|
| 32 |
+
u = 1.0 / self.vocab_size
|
| 33 |
+
q_prime = (1.0 - self.label_smoothing) * q + self.label_smoothing * u
|
| 34 |
+
q_prime = q_prime.masked_fill(mask, 0)
|
| 35 |
+
|
| 36 |
+
ce = self.cross_entropy_with_logits(q_prime, input)
|
| 37 |
+
if self.reduction == 'mean':
|
| 38 |
+
lengths = torch.sum(target != self.ignore_index)
|
| 39 |
+
return ce.sum() / lengths
|
| 40 |
+
elif self.reduction == 'sum':
|
| 41 |
+
return ce.sum()
|
| 42 |
+
else:
|
| 43 |
+
raise NotImplementedError
|
| 44 |
+
|
| 45 |
+
def cross_entropy_with_logits(self, p, q):
|
| 46 |
+
return -torch.sum(p * (q - q.logsumexp(dim=-1, keepdim=True)), dim=-1)
|
model/music_transformer.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.nn.modules.normalization import LayerNorm
|
| 4 |
+
import random
|
| 5 |
+
|
| 6 |
+
from utilities.constants import *
|
| 7 |
+
from utilities.device import get_device
|
| 8 |
+
|
| 9 |
+
from .positional_encoding import PositionalEncoding
|
| 10 |
+
from .rpr import TransformerEncoderRPR, TransformerEncoderLayerRPR
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# MusicTransformer
|
| 14 |
+
class MusicTransformer(nn.Module):
|
| 15 |
+
"""
|
| 16 |
+
----------
|
| 17 |
+
Author: Damon Gwinn
|
| 18 |
+
----------
|
| 19 |
+
Music Transformer reproduction from https://arxiv.org/abs/1809.04281. Arguments allow for
|
| 20 |
+
tweaking the transformer architecture (https://arxiv.org/abs/1706.03762) and the rpr argument
|
| 21 |
+
toggles Relative Position Representations (RPR - https://arxiv.org/abs/1803.02155).
|
| 22 |
+
|
| 23 |
+
Supports training and generation using Pytorch's nn.Transformer class with dummy decoder to
|
| 24 |
+
make a decoder-only transformer architecture
|
| 25 |
+
|
| 26 |
+
For RPR support, there is modified Pytorch 1.2.0 code in rpr.py. Modified source will be
|
| 27 |
+
kept up to date with Pytorch revisions only as necessary.
|
| 28 |
+
----------
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
def __init__(self, n_layers=6, num_heads=8, d_model=512, dim_feedforward=1024,
|
| 32 |
+
dropout=0.1, max_sequence=2048, rpr=False):
|
| 33 |
+
super(MusicTransformer, self).__init__()
|
| 34 |
+
|
| 35 |
+
self.dummy = DummyDecoder()
|
| 36 |
+
|
| 37 |
+
self.nlayers = n_layers
|
| 38 |
+
self.nhead = num_heads
|
| 39 |
+
self.d_model = d_model
|
| 40 |
+
self.d_ff = dim_feedforward
|
| 41 |
+
self.dropout = dropout
|
| 42 |
+
self.max_seq = max_sequence
|
| 43 |
+
self.rpr = rpr
|
| 44 |
+
|
| 45 |
+
# Input embedding
|
| 46 |
+
self.embedding = nn.Embedding(VOCAB_SIZE, self.d_model)
|
| 47 |
+
|
| 48 |
+
# Positional encoding
|
| 49 |
+
self.positional_encoding = PositionalEncoding(self.d_model, self.dropout, self.max_seq)
|
| 50 |
+
|
| 51 |
+
# Base transformer
|
| 52 |
+
if(not self.rpr):
|
| 53 |
+
# To make a decoder-only transformer we need to use masked encoder layers
|
| 54 |
+
# Dummy decoder to essentially just return the encoder output
|
| 55 |
+
self.transformer = nn.Transformer(
|
| 56 |
+
d_model=self.d_model, nhead=self.nhead, num_encoder_layers=self.nlayers,
|
| 57 |
+
num_decoder_layers=0, dropout=self.dropout, # activation=self.ff_activ,
|
| 58 |
+
dim_feedforward=self.d_ff, custom_decoder=self.dummy
|
| 59 |
+
)
|
| 60 |
+
# RPR Transformer
|
| 61 |
+
else:
|
| 62 |
+
encoder_norm = LayerNorm(self.d_model)
|
| 63 |
+
encoder_layer = TransformerEncoderLayerRPR(self.d_model, self.nhead, self.d_ff, self.dropout, er_len=self.max_seq)
|
| 64 |
+
encoder = TransformerEncoderRPR(encoder_layer, self.nlayers, encoder_norm)
|
| 65 |
+
self.transformer = nn.Transformer(
|
| 66 |
+
d_model=self.d_model, nhead=self.nhead, num_encoder_layers=self.nlayers,
|
| 67 |
+
num_decoder_layers=0, dropout=self.dropout, # activation=self.ff_activ,
|
| 68 |
+
dim_feedforward=self.d_ff, custom_decoder=self.dummy, custom_encoder=encoder
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
# Final output is a softmaxed linear layer
|
| 72 |
+
self.Wout = nn.Linear(self.d_model, VOCAB_SIZE)
|
| 73 |
+
self.softmax = nn.Softmax(dim=-1)
|
| 74 |
+
|
| 75 |
+
# forward
|
| 76 |
+
def forward(self, x, mask=True):
|
| 77 |
+
"""
|
| 78 |
+
----------
|
| 79 |
+
Author: Damon Gwinn
|
| 80 |
+
----------
|
| 81 |
+
Takes an input sequence and outputs predictions using a sequence to sequence method.
|
| 82 |
+
|
| 83 |
+
A prediction at one index is the "next" prediction given all information seen previously.
|
| 84 |
+
----------
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
if(mask is True):
|
| 88 |
+
mask = self.transformer.generate_square_subsequent_mask(x.shape[1]).to(get_device())
|
| 89 |
+
else:
|
| 90 |
+
mask = None
|
| 91 |
+
|
| 92 |
+
x = self.embedding(x)
|
| 93 |
+
|
| 94 |
+
# Input shape is (max_seq, batch_size, d_model)
|
| 95 |
+
x = x.permute(1,0,2)
|
| 96 |
+
|
| 97 |
+
x = self.positional_encoding(x)
|
| 98 |
+
|
| 99 |
+
# Since there are no true decoder layers, the tgt is unused
|
| 100 |
+
# Pytorch wants src and tgt to have some equal dims however
|
| 101 |
+
x_out = self.transformer(src=x, tgt=x, src_mask=mask)
|
| 102 |
+
|
| 103 |
+
# Back to (batch_size, max_seq, d_model)
|
| 104 |
+
x_out = x_out.permute(1,0,2)
|
| 105 |
+
|
| 106 |
+
y = self.Wout(x_out)
|
| 107 |
+
# y = self.softmax(y)
|
| 108 |
+
|
| 109 |
+
del mask
|
| 110 |
+
|
| 111 |
+
# They are trained to predict the next note in sequence (we don't need the last one)
|
| 112 |
+
return y
|
| 113 |
+
|
| 114 |
+
# generate
|
| 115 |
+
def generate(self, primer=None, target_seq_length=1024, beam=0, beam_chance=1.0):
|
| 116 |
+
"""
|
| 117 |
+
----------
|
| 118 |
+
Author: Damon Gwinn
|
| 119 |
+
----------
|
| 120 |
+
Generates midi given a primer sample. Music can be generated using a probability distribution over
|
| 121 |
+
the softmax probabilities (recommended) or by using a beam search.
|
| 122 |
+
----------
|
| 123 |
+
"""
|
| 124 |
+
|
| 125 |
+
assert (not self.training), "Cannot generate while in training mode"
|
| 126 |
+
|
| 127 |
+
print("Generating sequence of max length:", target_seq_length)
|
| 128 |
+
|
| 129 |
+
gen_seq = torch.full((1,target_seq_length), TOKEN_PAD, dtype=TORCH_LABEL_TYPE, device=get_device())
|
| 130 |
+
|
| 131 |
+
num_primer = len(primer)
|
| 132 |
+
gen_seq[..., :num_primer] = primer.type(TORCH_LABEL_TYPE).to(get_device())
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# print("primer:",primer)
|
| 136 |
+
# print(gen_seq)
|
| 137 |
+
cur_i = num_primer
|
| 138 |
+
while(cur_i < target_seq_length):
|
| 139 |
+
# gen_seq_batch = gen_seq.clone()
|
| 140 |
+
y = self.softmax(self.forward(gen_seq[..., :cur_i]))[..., :TOKEN_END]
|
| 141 |
+
token_probs = y[:, cur_i-1, :]
|
| 142 |
+
|
| 143 |
+
if(beam == 0):
|
| 144 |
+
beam_ran = 2.0
|
| 145 |
+
else:
|
| 146 |
+
beam_ran = random.uniform(0,1)
|
| 147 |
+
|
| 148 |
+
if(beam_ran <= beam_chance):
|
| 149 |
+
token_probs = token_probs.flatten()
|
| 150 |
+
top_res, top_i = torch.topk(token_probs, beam)
|
| 151 |
+
|
| 152 |
+
beam_rows = top_i // VOCAB_SIZE
|
| 153 |
+
beam_cols = top_i % VOCAB_SIZE
|
| 154 |
+
|
| 155 |
+
gen_seq = gen_seq[beam_rows, :]
|
| 156 |
+
gen_seq[..., cur_i] = beam_cols
|
| 157 |
+
|
| 158 |
+
else:
|
| 159 |
+
distrib = torch.distributions.categorical.Categorical(probs=token_probs)
|
| 160 |
+
next_token = distrib.sample()
|
| 161 |
+
# print("next token:",next_token)
|
| 162 |
+
gen_seq[:, cur_i] = next_token
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
# Let the transformer decide to end if it wants to
|
| 166 |
+
if(next_token == TOKEN_END):
|
| 167 |
+
print("Model called end of sequence at:", cur_i, "/", target_seq_length)
|
| 168 |
+
break
|
| 169 |
+
|
| 170 |
+
cur_i += 1
|
| 171 |
+
if(cur_i % 50 == 0):
|
| 172 |
+
print(cur_i, "/", target_seq_length)
|
| 173 |
+
|
| 174 |
+
return gen_seq[:, :cur_i]
|
| 175 |
+
|
| 176 |
+
# Used as a dummy to nn.Transformer
|
| 177 |
+
# DummyDecoder
|
| 178 |
+
class DummyDecoder(nn.Module):
|
| 179 |
+
"""
|
| 180 |
+
----------
|
| 181 |
+
Author: Damon Gwinn
|
| 182 |
+
----------
|
| 183 |
+
A dummy decoder that returns its input. Used to make the Pytorch transformer into a decoder-only
|
| 184 |
+
architecture (stacked encoders with dummy decoder fits the bill)
|
| 185 |
+
----------
|
| 186 |
+
"""
|
| 187 |
+
|
| 188 |
+
def __init__(self):
|
| 189 |
+
super(DummyDecoder, self).__init__()
|
| 190 |
+
|
| 191 |
+
def forward(self, tgt, memory, tgt_mask, memory_mask,tgt_key_padding_mask,memory_key_padding_mask, **kwargs):
|
| 192 |
+
"""
|
| 193 |
+
----------
|
| 194 |
+
Author: Damon Gwinn
|
| 195 |
+
----------
|
| 196 |
+
Returns the input (memory)
|
| 197 |
+
----------
|
| 198 |
+
"""
|
| 199 |
+
|
| 200 |
+
return memory
|
model/positional_encoding.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
# PositionalEncoding
|
| 6 |
+
# Taken from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
|
| 7 |
+
class PositionalEncoding(nn.Module):
|
| 8 |
+
|
| 9 |
+
def __init__(self, d_model, dropout=0.1, max_len=5000):
|
| 10 |
+
super(PositionalEncoding, self).__init__()
|
| 11 |
+
self.dropout = nn.Dropout(p=dropout)
|
| 12 |
+
|
| 13 |
+
pe = torch.zeros(max_len, d_model)
|
| 14 |
+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
| 15 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
| 16 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
| 17 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
| 18 |
+
pe = pe.unsqueeze(0).transpose(0, 1)
|
| 19 |
+
self.register_buffer('pe', pe)
|
| 20 |
+
|
| 21 |
+
def forward(self, x):
|
| 22 |
+
x = x + self.pe[:x.size(0), :]
|
| 23 |
+
return self.dropout(x)
|
model/rpr.py
ADDED
|
@@ -0,0 +1,464 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from torch.nn import functional as F
|
| 5 |
+
from torch.nn.parameter import Parameter
|
| 6 |
+
from torch.nn import Module
|
| 7 |
+
from torch.nn.modules.transformer import _get_clones
|
| 8 |
+
from torch.nn.modules.linear import Linear
|
| 9 |
+
from torch.nn.modules.dropout import Dropout
|
| 10 |
+
from torch.nn.modules.normalization import LayerNorm
|
| 11 |
+
from torch.nn.init import *
|
| 12 |
+
|
| 13 |
+
from torch.nn.functional import linear, softmax, dropout
|
| 14 |
+
|
| 15 |
+
# TransformerEncoderRPR
|
| 16 |
+
class TransformerEncoderRPR(Module):
|
| 17 |
+
"""
|
| 18 |
+
----------
|
| 19 |
+
Author: Pytorch
|
| 20 |
+
----------
|
| 21 |
+
For Relative Position Representation support (https://arxiv.org/abs/1803.02155)
|
| 22 |
+
https://pytorch.org/docs/1.2.0/_modules/torch/nn/modules/transformer.html#TransformerEncoder
|
| 23 |
+
|
| 24 |
+
No modification. Copied here to ensure continued compatibility with other edits.
|
| 25 |
+
----------
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(self, encoder_layer, num_layers, norm=None):
|
| 29 |
+
super(TransformerEncoderRPR, self).__init__()
|
| 30 |
+
self.layers = _get_clones(encoder_layer, num_layers)
|
| 31 |
+
self.num_layers = num_layers
|
| 32 |
+
self.norm = norm
|
| 33 |
+
|
| 34 |
+
def forward(self, src, mask=None, src_key_padding_mask=None, **kwargs):
|
| 35 |
+
|
| 36 |
+
output = src
|
| 37 |
+
|
| 38 |
+
for i in range(self.num_layers):
|
| 39 |
+
output = self.layers[i](output, src_mask=mask,
|
| 40 |
+
src_key_padding_mask=src_key_padding_mask)
|
| 41 |
+
|
| 42 |
+
if self.norm:
|
| 43 |
+
output = self.norm(output)
|
| 44 |
+
|
| 45 |
+
return output
|
| 46 |
+
|
| 47 |
+
# TransformerEncoderLayerRPR
|
| 48 |
+
class TransformerEncoderLayerRPR(Module):
|
| 49 |
+
"""
|
| 50 |
+
----------
|
| 51 |
+
Author: Pytorch
|
| 52 |
+
Modified: Damon Gwinn
|
| 53 |
+
----------
|
| 54 |
+
For Relative Position Representation support (https://arxiv.org/abs/1803.02155)
|
| 55 |
+
https://pytorch.org/docs/1.2.0/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer
|
| 56 |
+
|
| 57 |
+
Modification to create and call custom MultiheadAttentionRPR
|
| 58 |
+
----------
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, er_len=None):
|
| 62 |
+
super(TransformerEncoderLayerRPR, self).__init__()
|
| 63 |
+
self.self_attn = MultiheadAttentionRPR(d_model, nhead, dropout=dropout, er_len=er_len)
|
| 64 |
+
# Implementation of Feedforward model
|
| 65 |
+
self.linear1 = Linear(d_model, dim_feedforward)
|
| 66 |
+
self.dropout = Dropout(dropout)
|
| 67 |
+
self.linear2 = Linear(dim_feedforward, d_model)
|
| 68 |
+
|
| 69 |
+
self.norm1 = LayerNorm(d_model)
|
| 70 |
+
self.norm2 = LayerNorm(d_model)
|
| 71 |
+
self.dropout1 = Dropout(dropout)
|
| 72 |
+
self.dropout2 = Dropout(dropout)
|
| 73 |
+
|
| 74 |
+
def forward(self, src, src_mask=None, src_key_padding_mask=None):
|
| 75 |
+
src2 = self.self_attn(src, src, src, attn_mask=src_mask,
|
| 76 |
+
key_padding_mask=src_key_padding_mask)[0]
|
| 77 |
+
src = src + self.dropout1(src2)
|
| 78 |
+
src = self.norm1(src)
|
| 79 |
+
src2 = self.linear2(self.dropout(F.relu(self.linear1(src))))
|
| 80 |
+
src = src + self.dropout2(src2)
|
| 81 |
+
src = self.norm2(src)
|
| 82 |
+
return src
|
| 83 |
+
|
| 84 |
+
# MultiheadAttentionRPR
|
| 85 |
+
class MultiheadAttentionRPR(Module):
|
| 86 |
+
"""
|
| 87 |
+
----------
|
| 88 |
+
Author: Pytorch
|
| 89 |
+
Modified: Damon Gwinn
|
| 90 |
+
----------
|
| 91 |
+
For Relative Position Representation support (https://arxiv.org/abs/1803.02155)
|
| 92 |
+
https://pytorch.org/docs/1.2.0/_modules/torch/nn/modules/activation.html#MultiheadAttention
|
| 93 |
+
|
| 94 |
+
Modification to add RPR embedding Er and call custom multi_head_attention_forward_rpr
|
| 95 |
+
----------
|
| 96 |
+
"""
|
| 97 |
+
|
| 98 |
+
def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, er_len=None):
|
| 99 |
+
super(MultiheadAttentionRPR, self).__init__()
|
| 100 |
+
self.embed_dim = embed_dim
|
| 101 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
| 102 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
| 103 |
+
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
| 104 |
+
|
| 105 |
+
self.num_heads = num_heads
|
| 106 |
+
self.dropout = dropout
|
| 107 |
+
self.head_dim = embed_dim // num_heads
|
| 108 |
+
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
| 109 |
+
|
| 110 |
+
self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
|
| 111 |
+
|
| 112 |
+
if self._qkv_same_embed_dim is False:
|
| 113 |
+
self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
|
| 114 |
+
self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
|
| 115 |
+
self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
|
| 116 |
+
|
| 117 |
+
if bias:
|
| 118 |
+
self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
|
| 119 |
+
else:
|
| 120 |
+
self.register_parameter('in_proj_bias', None)
|
| 121 |
+
self.out_proj = Linear(embed_dim, embed_dim, bias=bias)
|
| 122 |
+
|
| 123 |
+
if add_bias_kv:
|
| 124 |
+
self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
|
| 125 |
+
self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
|
| 126 |
+
else:
|
| 127 |
+
self.bias_k = self.bias_v = None
|
| 128 |
+
|
| 129 |
+
self.add_zero_attn = add_zero_attn
|
| 130 |
+
|
| 131 |
+
# Adding RPR embedding matrix
|
| 132 |
+
if(er_len is not None):
|
| 133 |
+
self.Er = Parameter(torch.rand((er_len, self.head_dim), dtype=torch.float32))
|
| 134 |
+
else:
|
| 135 |
+
self.Er = None
|
| 136 |
+
|
| 137 |
+
self._reset_parameters()
|
| 138 |
+
|
| 139 |
+
def _reset_parameters(self):
|
| 140 |
+
if self._qkv_same_embed_dim:
|
| 141 |
+
xavier_uniform_(self.in_proj_weight)
|
| 142 |
+
else:
|
| 143 |
+
xavier_uniform_(self.q_proj_weight)
|
| 144 |
+
xavier_uniform_(self.k_proj_weight)
|
| 145 |
+
xavier_uniform_(self.v_proj_weight)
|
| 146 |
+
|
| 147 |
+
if self.in_proj_bias is not None:
|
| 148 |
+
constant_(self.in_proj_bias, 0.)
|
| 149 |
+
constant_(self.out_proj.bias, 0.)
|
| 150 |
+
if self.bias_k is not None:
|
| 151 |
+
xavier_normal_(self.bias_k)
|
| 152 |
+
if self.bias_v is not None:
|
| 153 |
+
xavier_normal_(self.bias_v)
|
| 154 |
+
|
| 155 |
+
def forward(self, query, key, value, key_padding_mask=None,
|
| 156 |
+
need_weights=True, attn_mask=None):
|
| 157 |
+
|
| 158 |
+
if hasattr(self, '_qkv_same_embed_dim') and self._qkv_same_embed_dim is False:
|
| 159 |
+
# return F.multi_head_attention_forward(
|
| 160 |
+
# query, key, value, self.embed_dim, self.num_heads,
|
| 161 |
+
# self.in_proj_weight, self.in_proj_bias,
|
| 162 |
+
# self.bias_k, self.bias_v, self.add_zero_attn,
|
| 163 |
+
# self.dropout, self.out_proj.weight, self.out_proj.bias,
|
| 164 |
+
# training=self.training,
|
| 165 |
+
# key_padding_mask=key_padding_mask, need_weights=need_weights,
|
| 166 |
+
# attn_mask=attn_mask, use_separate_proj_weight=True,
|
| 167 |
+
# q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
|
| 168 |
+
# v_proj_weight=self.v_proj_weight)
|
| 169 |
+
|
| 170 |
+
return multi_head_attention_forward_rpr(
|
| 171 |
+
query, key, value, self.embed_dim, self.num_heads,
|
| 172 |
+
self.in_proj_weight, self.in_proj_bias,
|
| 173 |
+
self.bias_k, self.bias_v, self.add_zero_attn,
|
| 174 |
+
self.dropout, self.out_proj.weight, self.out_proj.bias,
|
| 175 |
+
training=self.training,
|
| 176 |
+
key_padding_mask=key_padding_mask, need_weights=need_weights,
|
| 177 |
+
attn_mask=attn_mask, use_separate_proj_weight=True,
|
| 178 |
+
q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
|
| 179 |
+
v_proj_weight=self.v_proj_weight, rpr_mat=self.Er)
|
| 180 |
+
else:
|
| 181 |
+
if not hasattr(self, '_qkv_same_embed_dim'):
|
| 182 |
+
warnings.warn('A new version of MultiheadAttention module has been implemented. \
|
| 183 |
+
Please re-train your model with the new module',
|
| 184 |
+
UserWarning)
|
| 185 |
+
|
| 186 |
+
# return F.multi_head_attention_forward(
|
| 187 |
+
# query, key, value, self.embed_dim, self.num_heads,
|
| 188 |
+
# self.in_proj_weight, self.in_proj_bias,
|
| 189 |
+
# self.bias_k, self.bias_v, self.add_zero_attn,
|
| 190 |
+
# self.dropout, self.out_proj.weight, self.out_proj.bias,
|
| 191 |
+
# training=self.training,
|
| 192 |
+
# key_padding_mask=key_padding_mask, need_weights=need_weights,
|
| 193 |
+
# attn_mask=attn_mask)
|
| 194 |
+
|
| 195 |
+
return multi_head_attention_forward_rpr(
|
| 196 |
+
query, key, value, self.embed_dim, self.num_heads,
|
| 197 |
+
self.in_proj_weight, self.in_proj_bias,
|
| 198 |
+
self.bias_k, self.bias_v, self.add_zero_attn,
|
| 199 |
+
self.dropout, self.out_proj.weight, self.out_proj.bias,
|
| 200 |
+
training=self.training,
|
| 201 |
+
key_padding_mask=key_padding_mask, need_weights=need_weights,
|
| 202 |
+
attn_mask=attn_mask, rpr_mat=self.Er)
|
| 203 |
+
|
| 204 |
+
# multi_head_attention_forward_rpr
|
| 205 |
+
def multi_head_attention_forward_rpr(query, # type: Tensor
|
| 206 |
+
key, # type: Tensor
|
| 207 |
+
value, # type: Tensor
|
| 208 |
+
embed_dim_to_check, # type: int
|
| 209 |
+
num_heads, # type: int
|
| 210 |
+
in_proj_weight, # type: Tensor
|
| 211 |
+
in_proj_bias, # type: Tensor
|
| 212 |
+
bias_k, # type: Optional[Tensor]
|
| 213 |
+
bias_v, # type: Optional[Tensor]
|
| 214 |
+
add_zero_attn, # type: bool
|
| 215 |
+
dropout_p, # type: float
|
| 216 |
+
out_proj_weight, # type: Tensor
|
| 217 |
+
out_proj_bias, # type: Tensor
|
| 218 |
+
training=True, # type: bool
|
| 219 |
+
key_padding_mask=None, # type: Optional[Tensor]
|
| 220 |
+
need_weights=True, # type: bool
|
| 221 |
+
attn_mask=None, # type: Optional[Tensor]
|
| 222 |
+
use_separate_proj_weight=False, # type: bool
|
| 223 |
+
q_proj_weight=None, # type: Optional[Tensor]
|
| 224 |
+
k_proj_weight=None, # type: Optional[Tensor]
|
| 225 |
+
v_proj_weight=None, # type: Optional[Tensor]
|
| 226 |
+
static_k=None, # type: Optional[Tensor]
|
| 227 |
+
static_v=None, # type: Optional[Tensor]
|
| 228 |
+
rpr_mat=None
|
| 229 |
+
):
|
| 230 |
+
"""
|
| 231 |
+
----------
|
| 232 |
+
Author: Pytorch
|
| 233 |
+
Modified: Damon Gwinn
|
| 234 |
+
----------
|
| 235 |
+
For Relative Position Representation support (https://arxiv.org/abs/1803.02155)
|
| 236 |
+
https://pytorch.org/docs/1.2.0/_modules/torch/nn/functional.html
|
| 237 |
+
|
| 238 |
+
Modification to take RPR embedding matrix and perform skew optimized RPR (https://arxiv.org/abs/1809.04281)
|
| 239 |
+
----------
|
| 240 |
+
"""
|
| 241 |
+
|
| 242 |
+
# type: (...) -> Tuple[Tensor, Optional[Tensor]]
|
| 243 |
+
|
| 244 |
+
qkv_same = torch.equal(query, key) and torch.equal(key, value)
|
| 245 |
+
kv_same = torch.equal(key, value)
|
| 246 |
+
|
| 247 |
+
tgt_len, bsz, embed_dim = query.size()
|
| 248 |
+
assert embed_dim == embed_dim_to_check
|
| 249 |
+
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
| 250 |
+
assert key.size() == value.size()
|
| 251 |
+
|
| 252 |
+
head_dim = embed_dim // num_heads
|
| 253 |
+
assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
|
| 254 |
+
scaling = float(head_dim) ** -0.5
|
| 255 |
+
|
| 256 |
+
if use_separate_proj_weight is not True:
|
| 257 |
+
if qkv_same:
|
| 258 |
+
# self-attention
|
| 259 |
+
q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
|
| 260 |
+
|
| 261 |
+
elif kv_same:
|
| 262 |
+
# encoder-decoder attention
|
| 263 |
+
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
| 264 |
+
_b = in_proj_bias
|
| 265 |
+
_start = 0
|
| 266 |
+
_end = embed_dim
|
| 267 |
+
_w = in_proj_weight[_start:_end, :]
|
| 268 |
+
if _b is not None:
|
| 269 |
+
_b = _b[_start:_end]
|
| 270 |
+
q = linear(query, _w, _b)
|
| 271 |
+
|
| 272 |
+
if key is None:
|
| 273 |
+
assert value is None
|
| 274 |
+
k = None
|
| 275 |
+
v = None
|
| 276 |
+
else:
|
| 277 |
+
|
| 278 |
+
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
| 279 |
+
_b = in_proj_bias
|
| 280 |
+
_start = embed_dim
|
| 281 |
+
_end = None
|
| 282 |
+
_w = in_proj_weight[_start:, :]
|
| 283 |
+
if _b is not None:
|
| 284 |
+
_b = _b[_start:]
|
| 285 |
+
k, v = linear(key, _w, _b).chunk(2, dim=-1)
|
| 286 |
+
|
| 287 |
+
else:
|
| 288 |
+
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
| 289 |
+
_b = in_proj_bias
|
| 290 |
+
_start = 0
|
| 291 |
+
_end = embed_dim
|
| 292 |
+
_w = in_proj_weight[_start:_end, :]
|
| 293 |
+
if _b is not None:
|
| 294 |
+
_b = _b[_start:_end]
|
| 295 |
+
q = linear(query, _w, _b)
|
| 296 |
+
|
| 297 |
+
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
| 298 |
+
_b = in_proj_bias
|
| 299 |
+
_start = embed_dim
|
| 300 |
+
_end = embed_dim * 2
|
| 301 |
+
_w = in_proj_weight[_start:_end, :]
|
| 302 |
+
if _b is not None:
|
| 303 |
+
_b = _b[_start:_end]
|
| 304 |
+
k = linear(key, _w, _b)
|
| 305 |
+
|
| 306 |
+
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
| 307 |
+
_b = in_proj_bias
|
| 308 |
+
_start = embed_dim * 2
|
| 309 |
+
_end = None
|
| 310 |
+
_w = in_proj_weight[_start:, :]
|
| 311 |
+
if _b is not None:
|
| 312 |
+
_b = _b[_start:]
|
| 313 |
+
v = linear(value, _w, _b)
|
| 314 |
+
else:
|
| 315 |
+
q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
|
| 316 |
+
len1, len2 = q_proj_weight_non_opt.size()
|
| 317 |
+
assert len1 == embed_dim and len2 == query.size(-1)
|
| 318 |
+
|
| 319 |
+
k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
|
| 320 |
+
len1, len2 = k_proj_weight_non_opt.size()
|
| 321 |
+
assert len1 == embed_dim and len2 == key.size(-1)
|
| 322 |
+
|
| 323 |
+
v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
|
| 324 |
+
len1, len2 = v_proj_weight_non_opt.size()
|
| 325 |
+
assert len1 == embed_dim and len2 == value.size(-1)
|
| 326 |
+
|
| 327 |
+
if in_proj_bias is not None:
|
| 328 |
+
q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
|
| 329 |
+
k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)])
|
| 330 |
+
v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):])
|
| 331 |
+
else:
|
| 332 |
+
q = linear(query, q_proj_weight_non_opt, in_proj_bias)
|
| 333 |
+
k = linear(key, k_proj_weight_non_opt, in_proj_bias)
|
| 334 |
+
v = linear(value, v_proj_weight_non_opt, in_proj_bias)
|
| 335 |
+
q = q * scaling
|
| 336 |
+
|
| 337 |
+
if bias_k is not None and bias_v is not None:
|
| 338 |
+
if static_k is None and static_v is None:
|
| 339 |
+
k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
|
| 340 |
+
v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
|
| 341 |
+
if attn_mask is not None:
|
| 342 |
+
attn_mask = torch.cat([attn_mask,
|
| 343 |
+
torch.zeros((attn_mask.size(0), 1),
|
| 344 |
+
dtype=attn_mask.dtype,
|
| 345 |
+
device=attn_mask.device)], dim=1)
|
| 346 |
+
if key_padding_mask is not None:
|
| 347 |
+
key_padding_mask = torch.cat(
|
| 348 |
+
[key_padding_mask, torch.zeros((key_padding_mask.size(0), 1),
|
| 349 |
+
dtype=key_padding_mask.dtype,
|
| 350 |
+
device=key_padding_mask.device)], dim=1)
|
| 351 |
+
else:
|
| 352 |
+
assert static_k is None, "bias cannot be added to static key."
|
| 353 |
+
assert static_v is None, "bias cannot be added to static value."
|
| 354 |
+
else:
|
| 355 |
+
assert bias_k is None
|
| 356 |
+
assert bias_v is None
|
| 357 |
+
|
| 358 |
+
q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
|
| 359 |
+
if k is not None:
|
| 360 |
+
k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
|
| 361 |
+
if v is not None:
|
| 362 |
+
v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
|
| 363 |
+
|
| 364 |
+
if static_k is not None:
|
| 365 |
+
assert static_k.size(0) == bsz * num_heads
|
| 366 |
+
assert static_k.size(2) == head_dim
|
| 367 |
+
k = static_k
|
| 368 |
+
|
| 369 |
+
if static_v is not None:
|
| 370 |
+
assert static_v.size(0) == bsz * num_heads
|
| 371 |
+
assert static_v.size(2) == head_dim
|
| 372 |
+
v = static_v
|
| 373 |
+
|
| 374 |
+
src_len = k.size(1)
|
| 375 |
+
|
| 376 |
+
if key_padding_mask is not None:
|
| 377 |
+
assert key_padding_mask.size(0) == bsz
|
| 378 |
+
assert key_padding_mask.size(1) == src_len
|
| 379 |
+
|
| 380 |
+
if add_zero_attn:
|
| 381 |
+
src_len += 1
|
| 382 |
+
k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
|
| 383 |
+
v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
|
| 384 |
+
if attn_mask is not None:
|
| 385 |
+
attn_mask = torch.cat([attn_mask, torch.zeros((attn_mask.size(0), 1),
|
| 386 |
+
dtype=attn_mask.dtype,
|
| 387 |
+
device=attn_mask.device)], dim=1)
|
| 388 |
+
if key_padding_mask is not None:
|
| 389 |
+
key_padding_mask = torch.cat(
|
| 390 |
+
[key_padding_mask, torch.zeros((key_padding_mask.size(0), 1),
|
| 391 |
+
dtype=key_padding_mask.dtype,
|
| 392 |
+
device=key_padding_mask.device)], dim=1)
|
| 393 |
+
|
| 394 |
+
attn_output_weights = torch.bmm(q, k.transpose(1, 2))
|
| 395 |
+
assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
|
| 396 |
+
|
| 397 |
+
######### ADDITION OF RPR ###########
|
| 398 |
+
if(rpr_mat is not None):
|
| 399 |
+
rpr_mat = _get_valid_embedding(rpr_mat, q.shape[1], k.shape[1])
|
| 400 |
+
qe = torch.einsum("hld,md->hlm", q, rpr_mat)
|
| 401 |
+
srel = _skew(qe)
|
| 402 |
+
|
| 403 |
+
attn_output_weights += srel
|
| 404 |
+
|
| 405 |
+
if attn_mask is not None:
|
| 406 |
+
attn_mask = attn_mask.unsqueeze(0)
|
| 407 |
+
attn_output_weights += attn_mask
|
| 408 |
+
|
| 409 |
+
if key_padding_mask is not None:
|
| 410 |
+
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
|
| 411 |
+
attn_output_weights = attn_output_weights.masked_fill(
|
| 412 |
+
key_padding_mask.unsqueeze(1).unsqueeze(2),
|
| 413 |
+
float('-inf'),
|
| 414 |
+
)
|
| 415 |
+
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
|
| 416 |
+
|
| 417 |
+
attn_output_weights = softmax(
|
| 418 |
+
attn_output_weights, dim=-1)
|
| 419 |
+
|
| 420 |
+
attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)
|
| 421 |
+
|
| 422 |
+
attn_output = torch.bmm(attn_output_weights, v)
|
| 423 |
+
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
|
| 424 |
+
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
| 425 |
+
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
| 426 |
+
|
| 427 |
+
if need_weights:
|
| 428 |
+
# average attention weights over heads
|
| 429 |
+
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
|
| 430 |
+
return attn_output, attn_output_weights.sum(dim=1) / num_heads
|
| 431 |
+
else:
|
| 432 |
+
return attn_output, None
|
| 433 |
+
|
| 434 |
+
def _get_valid_embedding(Er, len_q, len_k):
|
| 435 |
+
"""
|
| 436 |
+
----------
|
| 437 |
+
Author: Damon Gwinn
|
| 438 |
+
----------
|
| 439 |
+
Gets valid embeddings based on max length of RPR attention
|
| 440 |
+
----------
|
| 441 |
+
"""
|
| 442 |
+
|
| 443 |
+
len_e = Er.shape[0]
|
| 444 |
+
start = max(0, len_e - len_q)
|
| 445 |
+
return Er[start:, :]
|
| 446 |
+
|
| 447 |
+
def _skew(qe):
|
| 448 |
+
"""
|
| 449 |
+
----------
|
| 450 |
+
Author: Damon Gwinn
|
| 451 |
+
----------
|
| 452 |
+
Performs the skew optimized RPR computation (https://arxiv.org/abs/1809.04281)
|
| 453 |
+
----------
|
| 454 |
+
"""
|
| 455 |
+
|
| 456 |
+
sz = qe.shape[1]
|
| 457 |
+
mask = (torch.triu(torch.ones(sz, sz).to(qe.device)) == 1).float().flip(0)
|
| 458 |
+
|
| 459 |
+
qe = mask * qe
|
| 460 |
+
qe = F.pad(qe, (1,0, 0,0, 0,0))
|
| 461 |
+
qe = torch.reshape(qe, (qe.shape[0], qe.shape[2], qe.shape[1]))
|
| 462 |
+
|
| 463 |
+
srel = qe[:, 1:, :]
|
| 464 |
+
return srel
|
processor.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pretty_midi
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
RANGE_NOTE_ON = 128
|
| 5 |
+
RANGE_NOTE_OFF = 128
|
| 6 |
+
RANGE_VEL = 32
|
| 7 |
+
RANGE_TIME_SHIFT = 100
|
| 8 |
+
|
| 9 |
+
START_IDX = {
|
| 10 |
+
'note_on': 0,
|
| 11 |
+
'note_off': RANGE_NOTE_ON,
|
| 12 |
+
'time_shift': RANGE_NOTE_ON + RANGE_NOTE_OFF,
|
| 13 |
+
'velocity': RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_TIME_SHIFT
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class SustainAdapter:
|
| 18 |
+
def __init__(self, time, type):
|
| 19 |
+
self.start = time
|
| 20 |
+
self.type = type
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class SustainDownManager:
|
| 24 |
+
def __init__(self, start, end):
|
| 25 |
+
self.start = start
|
| 26 |
+
self.end = end
|
| 27 |
+
self.managed_notes = []
|
| 28 |
+
self._note_dict = {} # key: pitch, value: note.start
|
| 29 |
+
|
| 30 |
+
def add_managed_note(self, note: pretty_midi.Note):
|
| 31 |
+
self.managed_notes.append(note)
|
| 32 |
+
|
| 33 |
+
def transposition_notes(self):
|
| 34 |
+
for note in reversed(self.managed_notes):
|
| 35 |
+
try:
|
| 36 |
+
note.end = self._note_dict[note.pitch]
|
| 37 |
+
except KeyError:
|
| 38 |
+
note.end = max(self.end, note.end)
|
| 39 |
+
self._note_dict[note.pitch] = note.start
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# Divided note by note_on, note_off
|
| 43 |
+
class SplitNote:
|
| 44 |
+
def __init__(self, type, time, value, velocity):
|
| 45 |
+
## type: note_on, note_off
|
| 46 |
+
self.type = type
|
| 47 |
+
self.time = time
|
| 48 |
+
self.velocity = velocity
|
| 49 |
+
self.value = value
|
| 50 |
+
|
| 51 |
+
def __repr__(self):
|
| 52 |
+
return '<[SNote] time: {} type: {}, value: {}, velocity: {}>'\
|
| 53 |
+
.format(self.time, self.type, self.value, self.velocity)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class Event:
|
| 57 |
+
def __init__(self, event_type, value):
|
| 58 |
+
self.type = event_type
|
| 59 |
+
self.value = value
|
| 60 |
+
|
| 61 |
+
def __repr__(self):
|
| 62 |
+
return '<Event type: {}, value: {}>'.format(self.type, self.value)
|
| 63 |
+
|
| 64 |
+
def to_int(self):
|
| 65 |
+
return START_IDX[self.type] + self.value
|
| 66 |
+
|
| 67 |
+
@staticmethod
|
| 68 |
+
def from_int(int_value):
|
| 69 |
+
info = Event._type_check(int_value)
|
| 70 |
+
return Event(info['type'], info['value'])
|
| 71 |
+
|
| 72 |
+
@staticmethod
|
| 73 |
+
def _type_check(int_value):
|
| 74 |
+
range_note_on = range(0, RANGE_NOTE_ON)
|
| 75 |
+
range_note_off = range(RANGE_NOTE_ON, RANGE_NOTE_ON+RANGE_NOTE_OFF)
|
| 76 |
+
range_time_shift = range(RANGE_NOTE_ON+RANGE_NOTE_OFF,RANGE_NOTE_ON+RANGE_NOTE_OFF+RANGE_TIME_SHIFT)
|
| 77 |
+
|
| 78 |
+
valid_value = int_value
|
| 79 |
+
|
| 80 |
+
if int_value in range_note_on:
|
| 81 |
+
return {'type': 'note_on', 'value': valid_value}
|
| 82 |
+
elif int_value in range_note_off:
|
| 83 |
+
valid_value -= RANGE_NOTE_ON
|
| 84 |
+
return {'type': 'note_off', 'value': valid_value}
|
| 85 |
+
elif int_value in range_time_shift:
|
| 86 |
+
valid_value -= (RANGE_NOTE_ON + RANGE_NOTE_OFF)
|
| 87 |
+
return {'type': 'time_shift', 'value': valid_value}
|
| 88 |
+
else:
|
| 89 |
+
valid_value -= (RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_TIME_SHIFT)
|
| 90 |
+
return {'type': 'velocity', 'value': valid_value}
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _divide_note(notes):
|
| 94 |
+
result_array = []
|
| 95 |
+
notes.sort(key=lambda x: x.start)
|
| 96 |
+
|
| 97 |
+
for note in notes:
|
| 98 |
+
on = SplitNote('note_on', note.start, note.pitch, note.velocity)
|
| 99 |
+
off = SplitNote('note_off', note.end, note.pitch, None)
|
| 100 |
+
result_array += [on, off]
|
| 101 |
+
return result_array
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _merge_note(snote_sequence):
|
| 105 |
+
note_on_dict = {}
|
| 106 |
+
result_array = []
|
| 107 |
+
|
| 108 |
+
for snote in snote_sequence:
|
| 109 |
+
# print(note_on_dict)
|
| 110 |
+
if snote.type == 'note_on':
|
| 111 |
+
note_on_dict[snote.value] = snote
|
| 112 |
+
elif snote.type == 'note_off':
|
| 113 |
+
try:
|
| 114 |
+
on = note_on_dict[snote.value]
|
| 115 |
+
off = snote
|
| 116 |
+
if off.time - on.time == 0:
|
| 117 |
+
continue
|
| 118 |
+
result = pretty_midi.Note(on.velocity, snote.value, on.time, off.time)
|
| 119 |
+
result_array.append(result)
|
| 120 |
+
except:
|
| 121 |
+
print('info removed pitch: {}'.format(snote.value))
|
| 122 |
+
return result_array
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _snote2events(snote: SplitNote, prev_vel: int):
|
| 126 |
+
result = []
|
| 127 |
+
if snote.velocity is not None:
|
| 128 |
+
modified_velocity = snote.velocity // 4
|
| 129 |
+
if prev_vel != modified_velocity:
|
| 130 |
+
result.append(Event(event_type='velocity', value=modified_velocity))
|
| 131 |
+
result.append(Event(event_type=snote.type, value=snote.value))
|
| 132 |
+
return result
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def _event_seq2snote_seq(event_sequence):
|
| 136 |
+
timeline = 0
|
| 137 |
+
velocity = 0
|
| 138 |
+
snote_seq = []
|
| 139 |
+
|
| 140 |
+
for event in event_sequence:
|
| 141 |
+
if event.type == 'time_shift':
|
| 142 |
+
timeline += ((event.value+1) / 100)
|
| 143 |
+
if event.type == 'velocity':
|
| 144 |
+
velocity = event.value * 4
|
| 145 |
+
else:
|
| 146 |
+
snote = SplitNote(event.type, timeline, event.value, velocity)
|
| 147 |
+
snote_seq.append(snote)
|
| 148 |
+
return snote_seq
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def _make_time_sift_events(prev_time, post_time):
|
| 152 |
+
time_interval = int(round((post_time - prev_time) * 100))
|
| 153 |
+
results = []
|
| 154 |
+
while time_interval >= RANGE_TIME_SHIFT:
|
| 155 |
+
results.append(Event(event_type='time_shift', value=RANGE_TIME_SHIFT-1))
|
| 156 |
+
time_interval -= RANGE_TIME_SHIFT
|
| 157 |
+
if time_interval == 0:
|
| 158 |
+
return results
|
| 159 |
+
else:
|
| 160 |
+
return results + [Event(event_type='time_shift', value=time_interval-1)]
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def _control_preprocess(ctrl_changes):
|
| 164 |
+
sustains = []
|
| 165 |
+
|
| 166 |
+
manager = None
|
| 167 |
+
for ctrl in ctrl_changes:
|
| 168 |
+
if ctrl.value >= 64 and manager is None:
|
| 169 |
+
# sustain down
|
| 170 |
+
manager = SustainDownManager(start=ctrl.time, end=None)
|
| 171 |
+
elif ctrl.value < 64 and manager is not None:
|
| 172 |
+
# sustain up
|
| 173 |
+
manager.end = ctrl.time
|
| 174 |
+
sustains.append(manager)
|
| 175 |
+
manager = None
|
| 176 |
+
elif ctrl.value < 64 and len(sustains) > 0:
|
| 177 |
+
sustains[-1].end = ctrl.time
|
| 178 |
+
return sustains
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def _note_preprocess(susteins, notes):
|
| 182 |
+
note_stream = []
|
| 183 |
+
|
| 184 |
+
if susteins: # if the midi file has sustain controls
|
| 185 |
+
for sustain in susteins:
|
| 186 |
+
for note_idx, note in enumerate(notes):
|
| 187 |
+
if note.start < sustain.start:
|
| 188 |
+
note_stream.append(note)
|
| 189 |
+
elif note.start > sustain.end:
|
| 190 |
+
notes = notes[note_idx:]
|
| 191 |
+
sustain.transposition_notes()
|
| 192 |
+
break
|
| 193 |
+
else:
|
| 194 |
+
sustain.add_managed_note(note)
|
| 195 |
+
|
| 196 |
+
for sustain in susteins:
|
| 197 |
+
note_stream += sustain.managed_notes
|
| 198 |
+
|
| 199 |
+
else: # else, just push everything into note stream
|
| 200 |
+
for note_idx, note in enumerate(notes):
|
| 201 |
+
note_stream.append(note)
|
| 202 |
+
|
| 203 |
+
note_stream.sort(key= lambda x: x.start)
|
| 204 |
+
return note_stream
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def encode_midi(file_path):
|
| 208 |
+
events = []
|
| 209 |
+
notes = []
|
| 210 |
+
mid = pretty_midi.PrettyMIDI(midi_file=file_path)
|
| 211 |
+
|
| 212 |
+
for inst in mid.instruments:
|
| 213 |
+
inst_notes = inst.notes
|
| 214 |
+
# ctrl.number is the number of sustain control. If you want to know abour the number type of control,
|
| 215 |
+
# see https://www.midi.org/specifications-old/item/table-3-control-change-messages-data-bytes-2
|
| 216 |
+
ctrls = _control_preprocess([ctrl for ctrl in inst.control_changes if ctrl.number == 64])
|
| 217 |
+
notes += _note_preprocess(ctrls, inst_notes)
|
| 218 |
+
|
| 219 |
+
dnotes = _divide_note(notes)
|
| 220 |
+
|
| 221 |
+
# print(dnotes)
|
| 222 |
+
dnotes.sort(key=lambda x: x.time)
|
| 223 |
+
# print('sorted:')
|
| 224 |
+
# print(dnotes)
|
| 225 |
+
cur_time = 0
|
| 226 |
+
cur_vel = 0
|
| 227 |
+
for snote in dnotes:
|
| 228 |
+
events += _make_time_sift_events(prev_time=cur_time, post_time=snote.time)
|
| 229 |
+
events += _snote2events(snote=snote, prev_vel=cur_vel)
|
| 230 |
+
# events += _make_time_sift_events(prev_time=cur_time, post_time=snote.time)
|
| 231 |
+
|
| 232 |
+
cur_time = snote.time
|
| 233 |
+
cur_vel = snote.velocity
|
| 234 |
+
|
| 235 |
+
return [e.to_int() for e in events]
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def decode_midi(idx_array, file_path=None):
|
| 239 |
+
event_sequence = [Event.from_int(idx) for idx in idx_array]
|
| 240 |
+
# print(event_sequence)
|
| 241 |
+
snote_seq = _event_seq2snote_seq(event_sequence)
|
| 242 |
+
note_seq = _merge_note(snote_seq)
|
| 243 |
+
note_seq.sort(key=lambda x:x.start)
|
| 244 |
+
|
| 245 |
+
mid = pretty_midi.PrettyMIDI()
|
| 246 |
+
# if want to change instument, see https://www.midi.org/specifications/item/gm-level-1-sound-set
|
| 247 |
+
instument = pretty_midi.Instrument(0, False, "Composed by Super Piano Music Transformer AI")
|
| 248 |
+
instument.notes = note_seq
|
| 249 |
+
|
| 250 |
+
mid.instruments.append(instument)
|
| 251 |
+
if file_path is not None:
|
| 252 |
+
mid.write(file_path)
|
| 253 |
+
return mid
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
if __name__ == '__main__':
|
| 257 |
+
encoded = encode_midi('bin/ADIG04.mid')
|
| 258 |
+
print(encoded)
|
| 259 |
+
decided = decode_midi(encoded,file_path='bin/test.mid')
|
| 260 |
+
|
| 261 |
+
ins = pretty_midi.PrettyMIDI('bin/ADIG04.mid')
|
| 262 |
+
print(ins)
|
| 263 |
+
print(ins.instruments[0])
|
| 264 |
+
for i in ins.instruments:
|
| 265 |
+
print(i.control_changes)
|
| 266 |
+
print(i.notes)
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
huggingface_hub
|
| 3 |
+
pretty_midi
|
| 4 |
+
setuptools
|
| 5 |
+
spaces
|
| 6 |
+
torch
|
uploaded_midis/am_i_blue_jazz.mid
ADDED
|
Binary file (21.5 kB). View file
|
|
|
uploaded_midis/ghibli_castle_in_the_sky.mid
ADDED
|
Binary file (2.81 kB). View file
|
|
|
utilities/argument_funcs.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
from .constants import SEPERATOR
|
| 4 |
+
|
| 5 |
+
# parse_train_args
|
| 6 |
+
def parse_train_args():
|
| 7 |
+
"""
|
| 8 |
+
----------
|
| 9 |
+
Author: Damon Gwinn
|
| 10 |
+
----------
|
| 11 |
+
Argparse arguments for training a model
|
| 12 |
+
----------
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
parser = argparse.ArgumentParser()
|
| 16 |
+
|
| 17 |
+
parser.add_argument("-input_dir", type=str, default="./dataset/e_piano", help="Folder of preprocessed and pickled midi files")
|
| 18 |
+
parser.add_argument("-output_dir", type=str, default="./saved_models", help="Folder to save model weights. Saves one every epoch")
|
| 19 |
+
parser.add_argument("-weight_modulus", type=int, default=1, help="How often to save epoch weights (ex: value of 10 means save every 10 epochs)")
|
| 20 |
+
parser.add_argument("-print_modulus", type=int, default=1, help="How often to print train results for a batch (batch loss, learn rate, etc.)")
|
| 21 |
+
|
| 22 |
+
parser.add_argument("-n_workers", type=int, default=1, help="Number of threads for the dataloader")
|
| 23 |
+
parser.add_argument("--force_cpu", action="store_true", help="Forces model to run on a cpu even when gpu is available")
|
| 24 |
+
parser.add_argument("--no_tensorboard", action="store_true", help="Turns off tensorboard result reporting")
|
| 25 |
+
|
| 26 |
+
parser.add_argument("-continue_weights", type=str, default=None, help="Model weights to continue training based on")
|
| 27 |
+
parser.add_argument("-continue_epoch", type=int, default=None, help="Epoch the continue_weights model was at")
|
| 28 |
+
|
| 29 |
+
parser.add_argument("-lr", type=float, default=None, help="Constant learn rate. Leave as None for a custom scheduler.")
|
| 30 |
+
parser.add_argument("-ce_smoothing", type=float, default=None, help="Smoothing parameter for smoothed cross entropy loss (defaults to no smoothing)")
|
| 31 |
+
parser.add_argument("-batch_size", type=int, default=2, help="Batch size to use")
|
| 32 |
+
parser.add_argument("-epochs", type=int, default=100, help="Number of epochs to use")
|
| 33 |
+
|
| 34 |
+
parser.add_argument("--rpr", action="store_true", help="Use a modified Transformer for Relative Position Representations")
|
| 35 |
+
parser.add_argument("-max_sequence", type=int, default=2048, help="Maximum midi sequence to consider")
|
| 36 |
+
parser.add_argument("-n_layers", type=int, default=6, help="Number of decoder layers to use")
|
| 37 |
+
parser.add_argument("-num_heads", type=int, default=8, help="Number of heads to use for multi-head attention")
|
| 38 |
+
parser.add_argument("-d_model", type=int, default=512, help="Dimension of the model (output dim of embedding layers, etc.)")
|
| 39 |
+
|
| 40 |
+
parser.add_argument("-dim_feedforward", type=int, default=1024, help="Dimension of the feedforward layer")
|
| 41 |
+
|
| 42 |
+
parser.add_argument("-dropout", type=float, default=0.1, help="Dropout rate")
|
| 43 |
+
|
| 44 |
+
return parser.parse_args()
|
| 45 |
+
|
| 46 |
+
# print_train_args
|
| 47 |
+
def print_train_args(args):
|
| 48 |
+
"""
|
| 49 |
+
----------
|
| 50 |
+
Author: Damon Gwinn
|
| 51 |
+
----------
|
| 52 |
+
Prints training arguments
|
| 53 |
+
----------
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
print(SEPERATOR)
|
| 57 |
+
print("input_dir:", args.input_dir)
|
| 58 |
+
print("output_dir:", args.output_dir)
|
| 59 |
+
print("weight_modulus:", args.weight_modulus)
|
| 60 |
+
print("print_modulus:", args.print_modulus)
|
| 61 |
+
print("")
|
| 62 |
+
print("n_workers:", args.n_workers)
|
| 63 |
+
print("force_cpu:", args.force_cpu)
|
| 64 |
+
print("tensorboard:", not args.no_tensorboard)
|
| 65 |
+
print("")
|
| 66 |
+
print("continue_weights:", args.continue_weights)
|
| 67 |
+
print("continue_epoch:", args.continue_epoch)
|
| 68 |
+
print("")
|
| 69 |
+
print("lr:", args.lr)
|
| 70 |
+
print("ce_smoothing:", args.ce_smoothing)
|
| 71 |
+
print("batch_size:", args.batch_size)
|
| 72 |
+
print("epochs:", args.epochs)
|
| 73 |
+
print("")
|
| 74 |
+
print("rpr:", args.rpr)
|
| 75 |
+
print("max_sequence:", args.max_sequence)
|
| 76 |
+
print("n_layers:", args.n_layers)
|
| 77 |
+
print("num_heads:", args.num_heads)
|
| 78 |
+
print("d_model:", args.d_model)
|
| 79 |
+
print("")
|
| 80 |
+
print("dim_feedforward:", args.dim_feedforward)
|
| 81 |
+
print("dropout:", args.dropout)
|
| 82 |
+
print(SEPERATOR)
|
| 83 |
+
print("")
|
| 84 |
+
|
| 85 |
+
# parse_eval_args
|
| 86 |
+
def parse_eval_args():
|
| 87 |
+
"""
|
| 88 |
+
----------
|
| 89 |
+
Author: Damon Gwinn
|
| 90 |
+
----------
|
| 91 |
+
Argparse arguments for evaluating a model
|
| 92 |
+
----------
|
| 93 |
+
"""
|
| 94 |
+
|
| 95 |
+
parser = argparse.ArgumentParser()
|
| 96 |
+
|
| 97 |
+
parser.add_argument("-dataset_dir", type=str, default="./dataset/e_piano", help="Folder of preprocessed and pickled midi files")
|
| 98 |
+
parser.add_argument("-model_weights", type=str, default="./saved_models/model.pickle", help="Pickled model weights file saved with torch.save and model.state_dict()")
|
| 99 |
+
parser.add_argument("-n_workers", type=int, default=1, help="Number of threads for the dataloader")
|
| 100 |
+
parser.add_argument("--force_cpu", action="store_true", help="Forces model to run on a cpu even when gpu is available")
|
| 101 |
+
|
| 102 |
+
parser.add_argument("-batch_size", type=int, default=2, help="Batch size to use")
|
| 103 |
+
|
| 104 |
+
parser.add_argument("--rpr", action="store_true", help="Use a modified Transformer for Relative Position Representations")
|
| 105 |
+
parser.add_argument("-max_sequence", type=int, default=2048, help="Maximum midi sequence to consider in the model")
|
| 106 |
+
parser.add_argument("-n_layers", type=int, default=6, help="Number of decoder layers to use")
|
| 107 |
+
parser.add_argument("-num_heads", type=int, default=8, help="Number of heads to use for multi-head attention")
|
| 108 |
+
parser.add_argument("-d_model", type=int, default=512, help="Dimension of the model (output dim of embedding layers, etc.)")
|
| 109 |
+
|
| 110 |
+
parser.add_argument("-dim_feedforward", type=int, default=1024, help="Dimension of the feedforward layer")
|
| 111 |
+
|
| 112 |
+
return parser.parse_args()
|
| 113 |
+
|
| 114 |
+
# print_eval_args
|
| 115 |
+
def print_eval_args(args):
|
| 116 |
+
"""
|
| 117 |
+
----------
|
| 118 |
+
Author: Damon Gwinn
|
| 119 |
+
----------
|
| 120 |
+
Prints evaluation arguments
|
| 121 |
+
----------
|
| 122 |
+
"""
|
| 123 |
+
|
| 124 |
+
print(SEPERATOR)
|
| 125 |
+
print("dataset_dir:", args.dataset_dir)
|
| 126 |
+
print("model_weights:", args.model_weights)
|
| 127 |
+
print("n_workers:", args.n_workers)
|
| 128 |
+
print("force_cpu:", args.force_cpu)
|
| 129 |
+
print("")
|
| 130 |
+
print("batch_size:", args.batch_size)
|
| 131 |
+
print("")
|
| 132 |
+
print("rpr:", args.rpr)
|
| 133 |
+
print("max_sequence:", args.max_sequence)
|
| 134 |
+
print("n_layers:", args.n_layers)
|
| 135 |
+
print("num_heads:", args.num_heads)
|
| 136 |
+
print("d_model:", args.d_model)
|
| 137 |
+
print("")
|
| 138 |
+
print("dim_feedforward:", args.dim_feedforward)
|
| 139 |
+
print(SEPERATOR)
|
| 140 |
+
print("")
|
| 141 |
+
|
| 142 |
+
# parse_generate_args
|
| 143 |
+
def parse_generate_args():
|
| 144 |
+
"""
|
| 145 |
+
----------
|
| 146 |
+
Author: Damon Gwinn
|
| 147 |
+
----------
|
| 148 |
+
Argparse arguments for generation
|
| 149 |
+
----------
|
| 150 |
+
"""
|
| 151 |
+
|
| 152 |
+
parser = argparse.ArgumentParser()
|
| 153 |
+
|
| 154 |
+
parser.add_argument("-midi_root", type=str, default="./dataset/e_piano/", help="Midi file to prime the generator with")
|
| 155 |
+
parser.add_argument("-output_dir", type=str, default="./gen", help="Folder to write generated midi to")
|
| 156 |
+
parser.add_argument("-primer_file", type=str, default=None, help="File path or integer index to the evaluation dataset. Default is to select a random index.")
|
| 157 |
+
parser.add_argument("--force_cpu", action="store_true", help="Forces model to run on a cpu even when gpu is available")
|
| 158 |
+
|
| 159 |
+
parser.add_argument("-target_seq_length", type=int, default=1024, help="Target length you'd like the midi to be")
|
| 160 |
+
parser.add_argument("-num_prime", type=int, default=256, help="Amount of messages to prime the generator with")
|
| 161 |
+
parser.add_argument("-model_weights", type=str, default="./saved_models/model.pickle", help="Pickled model weights file saved with torch.save and model.state_dict()")
|
| 162 |
+
parser.add_argument("-beam", type=int, default=0, help="Beam search k. 0 for random probability sample and 1 for greedy")
|
| 163 |
+
|
| 164 |
+
parser.add_argument("--rpr", action="store_true", help="Use a modified Transformer for Relative Position Representations")
|
| 165 |
+
parser.add_argument("-max_sequence", type=int, default=2048, help="Maximum midi sequence to consider")
|
| 166 |
+
parser.add_argument("-n_layers", type=int, default=6, help="Number of decoder layers to use")
|
| 167 |
+
parser.add_argument("-num_heads", type=int, default=8, help="Number of heads to use for multi-head attention")
|
| 168 |
+
parser.add_argument("-d_model", type=int, default=512, help="Dimension of the model (output dim of embedding layers, etc.)")
|
| 169 |
+
|
| 170 |
+
parser.add_argument("-dim_feedforward", type=int, default=1024, help="Dimension of the feedforward layer")
|
| 171 |
+
|
| 172 |
+
return parser.parse_args()
|
| 173 |
+
|
| 174 |
+
# print_generate_args
|
| 175 |
+
def print_generate_args(args):
|
| 176 |
+
"""
|
| 177 |
+
----------
|
| 178 |
+
Author: Damon Gwinn
|
| 179 |
+
----------
|
| 180 |
+
Prints generation arguments
|
| 181 |
+
----------
|
| 182 |
+
"""
|
| 183 |
+
|
| 184 |
+
print(SEPERATOR)
|
| 185 |
+
print("midi_root:", args.midi_root)
|
| 186 |
+
print("output_dir:", args.output_dir)
|
| 187 |
+
print("primer_file:", args.primer_file)
|
| 188 |
+
print("force_cpu:", args.force_cpu)
|
| 189 |
+
print("")
|
| 190 |
+
print("target_seq_length:", args.target_seq_length)
|
| 191 |
+
print("num_prime:", args.num_prime)
|
| 192 |
+
print("model_weights:", args.model_weights)
|
| 193 |
+
print("beam:", args.beam)
|
| 194 |
+
print("")
|
| 195 |
+
print("rpr:", args.rpr)
|
| 196 |
+
print("max_sequence:", args.max_sequence)
|
| 197 |
+
print("n_layers:", args.n_layers)
|
| 198 |
+
print("num_heads:", args.num_heads)
|
| 199 |
+
print("d_model:", args.d_model)
|
| 200 |
+
print("")
|
| 201 |
+
print("dim_feedforward:", args.dim_feedforward)
|
| 202 |
+
print(SEPERATOR)
|
| 203 |
+
print("")
|
| 204 |
+
|
| 205 |
+
# write_model_params
|
| 206 |
+
def write_model_params(args, output_file):
|
| 207 |
+
"""
|
| 208 |
+
----------
|
| 209 |
+
Author: Damon Gwinn
|
| 210 |
+
----------
|
| 211 |
+
Writes given training parameters to text file
|
| 212 |
+
----------
|
| 213 |
+
"""
|
| 214 |
+
|
| 215 |
+
o_stream = open(output_file, "w")
|
| 216 |
+
|
| 217 |
+
o_stream.write("rpr: " + str(args.rpr) + "\n")
|
| 218 |
+
o_stream.write("lr: " + str(args.lr) + "\n")
|
| 219 |
+
o_stream.write("ce_smoothing: " + str(args.ce_smoothing) + "\n")
|
| 220 |
+
o_stream.write("batch_size: " + str(args.batch_size) + "\n")
|
| 221 |
+
o_stream.write("max_sequence: " + str(args.max_sequence) + "\n")
|
| 222 |
+
o_stream.write("n_layers: " + str(args.n_layers) + "\n")
|
| 223 |
+
o_stream.write("num_heads: " + str(args.num_heads) + "\n")
|
| 224 |
+
o_stream.write("d_model: " + str(args.d_model) + "\n")
|
| 225 |
+
o_stream.write("dim_feedforward: " + str(args.dim_feedforward) + "\n")
|
| 226 |
+
o_stream.write("dropout: " + str(args.dropout) + "\n")
|
| 227 |
+
|
| 228 |
+
o_stream.close()
|
utilities/constants.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from processor import RANGE_NOTE_ON, RANGE_NOTE_OFF, RANGE_VEL, RANGE_TIME_SHIFT
|
| 4 |
+
|
| 5 |
+
SEPERATOR = "========================="
|
| 6 |
+
|
| 7 |
+
# Taken from the paper
|
| 8 |
+
ADAM_BETA_1 = 0.9
|
| 9 |
+
ADAM_BETA_2 = 0.98
|
| 10 |
+
ADAM_EPSILON = 10e-9
|
| 11 |
+
|
| 12 |
+
LR_DEFAULT_START = 1.0
|
| 13 |
+
SCHEDULER_WARMUP_STEPS = 4000
|
| 14 |
+
# LABEL_SMOOTHING_E = 0.1
|
| 15 |
+
|
| 16 |
+
# DROPOUT_P = 0.1
|
| 17 |
+
|
| 18 |
+
TOKEN_END = RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_VEL + RANGE_TIME_SHIFT
|
| 19 |
+
TOKEN_PAD = TOKEN_END + 1
|
| 20 |
+
|
| 21 |
+
VOCAB_SIZE = TOKEN_PAD + 1
|
| 22 |
+
|
| 23 |
+
TORCH_FLOAT = torch.float32
|
| 24 |
+
TORCH_INT = torch.int32
|
| 25 |
+
|
| 26 |
+
TORCH_LABEL_TYPE = torch.long
|
| 27 |
+
|
| 28 |
+
PREPEND_ZEROS_WIDTH = 4
|
utilities/device.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# For all things related to devices
|
| 2 |
+
#### ONLY USE PROVIDED FUNCTIONS, DO NOT USE GLOBAL CONSTANTS ####
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
# change cuda devices to ones that are available after running nvidia-smi.
|
| 8 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = '3,4,5'
|
| 9 |
+
|
| 10 |
+
TORCH_CPU_DEVICE = torch.device("cpu")
|
| 11 |
+
|
| 12 |
+
# If we have a GPU available, we'll set our device to GPU. We'll use this device variable later in our code.
|
| 13 |
+
if(torch.cuda.device_count() > 0):
|
| 14 |
+
TORCH_CUDA_DEVICE = torch.device("cuda")
|
| 15 |
+
|
| 16 |
+
else:
|
| 17 |
+
print("----- WARNING: CUDA devices not detected. This will cause the model to run very slow! -----")
|
| 18 |
+
print("")
|
| 19 |
+
TORCH_CUDA_DEVICE = None
|
| 20 |
+
|
| 21 |
+
USE_CUDA = True
|
| 22 |
+
|
| 23 |
+
# use_cuda
|
| 24 |
+
def use_cuda(cuda_bool):
|
| 25 |
+
"""
|
| 26 |
+
----------
|
| 27 |
+
Author: Damon Gwinn
|
| 28 |
+
----------
|
| 29 |
+
Sets whether to use CUDA (if available), or use the CPU (not recommended)
|
| 30 |
+
----------
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
global USE_CUDA
|
| 34 |
+
USE_CUDA = cuda_bool
|
| 35 |
+
|
| 36 |
+
# get_device
|
| 37 |
+
def get_device():
|
| 38 |
+
"""
|
| 39 |
+
----------
|
| 40 |
+
Author: Damon Gwinn
|
| 41 |
+
----------
|
| 42 |
+
Grabs the default device. Default device is CUDA if available and use_cuda is not False, CPU otherwise.
|
| 43 |
+
----------
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
if((not USE_CUDA) or (TORCH_CUDA_DEVICE is None)):
|
| 47 |
+
return TORCH_CPU_DEVICE
|
| 48 |
+
else:
|
| 49 |
+
return TORCH_CUDA_DEVICE
|
| 50 |
+
|
| 51 |
+
# cuda_device
|
| 52 |
+
def cuda_device():
|
| 53 |
+
"""
|
| 54 |
+
----------
|
| 55 |
+
Author: Damon Gwinn
|
| 56 |
+
----------
|
| 57 |
+
Grabs the cuda device (may be None if CUDA is not available)
|
| 58 |
+
----------
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
return TORCH_CUDA_DEVICE
|
| 62 |
+
|
| 63 |
+
# cpu_device
|
| 64 |
+
def cpu_device():
|
| 65 |
+
"""
|
| 66 |
+
----------
|
| 67 |
+
Author: Damon Gwinn
|
| 68 |
+
----------
|
| 69 |
+
Grabs the cpu device
|
| 70 |
+
----------
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
return TORCH_CPU_DEVICE
|
utilities/lr_scheduling.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#Library Imports
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
#Using Adam optimizer with
|
| 5 |
+
#Beta_1=0.9, Beta_2=0.98, and Epsilon=10^-9
|
| 6 |
+
|
| 7 |
+
#Learning rate varies over course of training
|
| 8 |
+
#lrate = sqrt(d_model)*min((1/sqrt(step_num)), step_num*(1/warmup_steps*sqrt(warmup_steps)))
|
| 9 |
+
|
| 10 |
+
# LrStepTracker
|
| 11 |
+
class LrStepTracker:
|
| 12 |
+
"""
|
| 13 |
+
----------
|
| 14 |
+
Author: Ryan Marshall
|
| 15 |
+
Modified: Damon Gwinn
|
| 16 |
+
----------
|
| 17 |
+
Class for custom learn rate scheduler (to be used by torch.optim.lr_scheduler.LambdaLR).
|
| 18 |
+
|
| 19 |
+
Learn rate for each step (batch) given the warmup steps is:
|
| 20 |
+
lr = [ 1/sqrt(d_model) ] * min[ 1/sqrt(step) , step * (warmup_steps)^-1.5 ]
|
| 21 |
+
|
| 22 |
+
This is from Attention is All you Need (https://arxiv.org/abs/1706.03762)
|
| 23 |
+
----------
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(self, model_dim=512, warmup_steps=4000, init_steps=0):
|
| 27 |
+
# Store Values
|
| 28 |
+
self.warmup_steps = warmup_steps
|
| 29 |
+
self.model_dim = model_dim
|
| 30 |
+
self.init_steps = init_steps
|
| 31 |
+
|
| 32 |
+
# Begin Calculations
|
| 33 |
+
self.invsqrt_dim = (1 / math.sqrt(model_dim))
|
| 34 |
+
self.invsqrt_warmup = (1 / (warmup_steps * math.sqrt(warmup_steps)))
|
| 35 |
+
|
| 36 |
+
# step
|
| 37 |
+
def step(self, step):
|
| 38 |
+
"""
|
| 39 |
+
----------
|
| 40 |
+
Author: Ryan Marshall
|
| 41 |
+
Modified: Damon Gwinn
|
| 42 |
+
----------
|
| 43 |
+
Method to pass to LambdaLR. Increments the step and computes the new learn rate.
|
| 44 |
+
----------
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
step += self.init_steps
|
| 48 |
+
if(step <= self.warmup_steps):
|
| 49 |
+
return self.invsqrt_dim * self.invsqrt_warmup * step
|
| 50 |
+
else:
|
| 51 |
+
invsqrt_step = (1 / math.sqrt(step))
|
| 52 |
+
return self.invsqrt_dim * invsqrt_step
|
| 53 |
+
|
| 54 |
+
# get_lr
|
| 55 |
+
def get_lr(optimizer):
|
| 56 |
+
"""
|
| 57 |
+
----------
|
| 58 |
+
Author: Damon Gwinn
|
| 59 |
+
----------
|
| 60 |
+
Hack to get the current learn rate of the model
|
| 61 |
+
----------
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
for param_group in optimizer.param_groups:
|
| 65 |
+
return param_group['lr']
|
utilities/run_model.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
from .constants import *
|
| 5 |
+
from utilities.device import get_device
|
| 6 |
+
from .lr_scheduling import get_lr
|
| 7 |
+
|
| 8 |
+
from dataset.e_piano import compute_epiano_accuracy
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# train_epoch
|
| 12 |
+
def train_epoch(cur_epoch, model, dataloader, loss, opt, lr_scheduler=None, print_modulus=1):
|
| 13 |
+
"""
|
| 14 |
+
----------
|
| 15 |
+
Author: Damon Gwinn
|
| 16 |
+
----------
|
| 17 |
+
Trains a single model epoch
|
| 18 |
+
----------
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
out = -1
|
| 22 |
+
model.train()
|
| 23 |
+
for batch_num, batch in enumerate(dataloader):
|
| 24 |
+
time_before = time.time()
|
| 25 |
+
|
| 26 |
+
opt.zero_grad()
|
| 27 |
+
|
| 28 |
+
x = batch[0].to(get_device())
|
| 29 |
+
tgt = batch[1].to(get_device())
|
| 30 |
+
|
| 31 |
+
y = model(x)
|
| 32 |
+
|
| 33 |
+
y = y.reshape(y.shape[0] * y.shape[1], -1)
|
| 34 |
+
tgt = tgt.flatten()
|
| 35 |
+
|
| 36 |
+
out = loss.forward(y, tgt)
|
| 37 |
+
|
| 38 |
+
out.backward()
|
| 39 |
+
opt.step()
|
| 40 |
+
|
| 41 |
+
if(lr_scheduler is not None):
|
| 42 |
+
lr_scheduler.step()
|
| 43 |
+
|
| 44 |
+
time_after = time.time()
|
| 45 |
+
time_took = time_after - time_before
|
| 46 |
+
|
| 47 |
+
if((batch_num+1) % print_modulus == 0):
|
| 48 |
+
print(SEPERATOR)
|
| 49 |
+
print("Epoch", cur_epoch, " Batch", batch_num+1, "/", len(dataloader))
|
| 50 |
+
print("LR:", get_lr(opt))
|
| 51 |
+
print("Train loss:", float(out))
|
| 52 |
+
print("")
|
| 53 |
+
print("Time (s):", time_took)
|
| 54 |
+
print(SEPERATOR)
|
| 55 |
+
print("")
|
| 56 |
+
|
| 57 |
+
return
|
| 58 |
+
|
| 59 |
+
# eval_model
|
| 60 |
+
def eval_model(model, dataloader, loss):
|
| 61 |
+
"""
|
| 62 |
+
----------
|
| 63 |
+
Author: Damon Gwinn
|
| 64 |
+
----------
|
| 65 |
+
Evaluates the model and prints the average loss and accuracy
|
| 66 |
+
----------
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
model.eval()
|
| 70 |
+
|
| 71 |
+
avg_acc = -1
|
| 72 |
+
avg_loss = -1
|
| 73 |
+
with torch.set_grad_enabled(False):
|
| 74 |
+
n_test = len(dataloader)
|
| 75 |
+
sum_loss = 0.0
|
| 76 |
+
sum_acc = 0.0
|
| 77 |
+
for batch in dataloader:
|
| 78 |
+
x = batch[0].to(get_device())
|
| 79 |
+
tgt = batch[1].to(get_device())
|
| 80 |
+
|
| 81 |
+
y = model(x)
|
| 82 |
+
|
| 83 |
+
sum_acc += float(compute_epiano_accuracy(y, tgt))
|
| 84 |
+
|
| 85 |
+
y = y.reshape(y.shape[0] * y.shape[1], -1)
|
| 86 |
+
tgt = tgt.flatten()
|
| 87 |
+
|
| 88 |
+
out = loss.forward(y, tgt)
|
| 89 |
+
|
| 90 |
+
sum_loss += float(out)
|
| 91 |
+
|
| 92 |
+
avg_loss = sum_loss / n_test
|
| 93 |
+
avg_acc = sum_acc / n_test
|
| 94 |
+
|
| 95 |
+
return avg_loss, avg_acc
|