File size: 4,873 Bytes
7116323
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2bde8bd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import random

import gradio as gr
import torch
from huggingface_hub import hf_hub_download

from model.music_transformer import MusicTransformer
from processor import decode_midi, encode_midi
from utilities.constants import TOKEN_END, TOKEN_PAD, TORCH_LABEL_TYPE
from utilities.device import get_device, use_cuda

# Hugging Face Hub location of the fine-tuned model weights.
REPO_ID = "Launchpad/lofi-bytes"
FILENAME = "weights_maestro_finetuned.pickle"

# Index where a non-random primer slice begins (see process_midi).
SEQUENCE_START = 0
OUTPUT_PATH = "./output_midi"
# RPR = relative positional representation variant of MusicTransformer.
RPR = True
# TARGET_SEQ_LENGTH = 1023
TARGET_SEQ_LENGTH = 512   # total token length of a generated sequence
NUM_PRIME = 65            # number of input tokens used to prime generation
MAX_SEQUENCE = 2048       # maximum sequence length the model was built for
N_LAYERS = 6
NUM_HEADS = 8
D_MODEL = 512
DIM_FEEDFORWARD = 1024
BEAM = 0                  # 0 => no beam search (sampling/greedy per model.generate)
FORCE_CPU = False
ALLOWED_EXTENSIONS = {'mid'}
UPLOAD_FOLDER = './uploaded_midis'

# NOTE(review): appears unused in this file -- possibly leftover state.
generated_midi = None

# Enable CUDA if available; get_device() below reflects this choice.
use_cuda(True)

# Build the MusicTransformer with the same hyperparameters the checkpoint
# was trained with -- these must match FILENAME's weights exactly.
model = MusicTransformer(
    n_layers=N_LAYERS,
    num_heads=NUM_HEADS,
    d_model=D_MODEL,
    dim_feedforward=DIM_FEEDFORWARD,
    max_sequence=MAX_SEQUENCE,
    rpr=RPR
).to(get_device())

# Download (or reuse cached) weights from the Hub and load them onto the
# active device.
state_dict = torch.load(
    hf_hub_download(repo_id=REPO_ID, filename=FILENAME),
    map_location=get_device()
)

model.load_state_dict(state_dict)

def generate(input_midi):
    """Generate a lofi MIDI continuation from an uploaded MIDI file.

    Encodes the input MIDI into the model's token vocabulary, takes the
    first NUM_PRIME tokens as a primer, and asks the MusicTransformer to
    continue the sequence up to TARGET_SEQ_LENGTH tokens.

    Parameters
    ----------
    input_midi : str
        Path to the uploaded MIDI file (as provided by gr.File).

    Returns
    -------
    str | None
        Path to the generated MIDI file ("output.mid"), or None when the
        input encodes to zero tokens.
    """
    raw_mid = encode_midi(input_midi)
    if not raw_mid:
        # Nothing to prime with -- Gradio shows an empty output for None.
        return None

    primer, _ = process_midi(raw_mid, NUM_PRIME, random_seq=False)
    primer = torch.tensor(primer, dtype=TORCH_LABEL_TYPE, device=get_device())

    # NOTE(review): return value discarded and no file_path given -- looks
    # like leftover debugging. Kept to preserve original behavior in case
    # decode_midi has side effects; confirm and remove if not.
    decode_midi(primer[:NUM_PRIME].cpu().numpy())

    # GENERATION
    model.eval()
    with torch.no_grad():  # inference only -- skip autograd bookkeeping
        # model.generate() returns the continued token sequence as an
        # array-like batch; beam_seq[0] is the single generated sequence.
        beam_seq = model.generate(primer[:NUM_PRIME], TARGET_SEQ_LENGTH, beam=BEAM)

        file_path = "output.mid"

        # decode_midi() converts tokens back to a pretty_midi.PrettyMIDI
        # and writes it to file_path as a side effect.
        decode_midi(beam_seq[0].cpu().numpy(), file_path=file_path)

        return file_path

def process_midi(raw_mid, max_seq, random_seq):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Takes in pre-processed raw midi and returns the input and target. Can use a random sequence or
    go from the start based on random_seq.

    Parameters:
        raw_mid (sequence of int): encoded MIDI token ids.
        max_seq (int): length of the returned input/target tensors.
        random_seq (bool): if True and the input is long enough, slice a
            random window; otherwise always slice from SEQUENCE_START.

    Returns:
        (x, tgt): input tokens and the same tokens shifted left by one
        (seq2seq teacher forcing), both padded with TOKEN_PAD to max_seq.
    ----------
    """

    x   = torch.full((max_seq, ), TOKEN_PAD, dtype=TORCH_LABEL_TYPE, device=get_device())
    tgt = torch.full((max_seq, ), TOKEN_PAD, dtype=TORCH_LABEL_TYPE, device=get_device())

    raw_len     = len(raw_mid)
    full_seq    = max_seq + 1 # Performing seq2seq: need one extra token for the shift

    if(raw_len == 0):
        return x, tgt

    if(raw_len < full_seq):
        x[:raw_len]         = raw_mid
        tgt[:raw_len-1]     = raw_mid[1:]
        # BUGFIX: END marker belongs immediately after the last shifted
        # token. The original wrote tgt[raw_len], which left a PAD gap at
        # index raw_len-1 and raised IndexError when raw_len == max_seq
        # (tgt only has max_seq slots).
        tgt[raw_len-1]      = TOKEN_END
    else:
        # Randomly selecting a range
        if(random_seq):
            end_range = raw_len - full_seq
            start = random.randint(SEQUENCE_START, end_range)

        # Always taking from the start to as far as we can
        else:
            start = SEQUENCE_START

        end = start + full_seq

        data = raw_mid[start:end]

        # Input is the first max_seq tokens; target is the same window
        # shifted left by one token.
        x = data[:max_seq]
        tgt = data[1:full_seq]

    return x, tgt


# Gradio UI: logo + project blurb on top, then a simple file-in/file-out
# interface wired to generate().
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Image(
                "https://www.ocf.berkeley.edu/~launchpad/media/uploads/project_logos/410912267_278779401866686_2517511436172822307_n_0iVwDxI.png",
                elem_id="logo-img",
                show_label=False,
                show_share_button=False,
                show_download_button=False,
                show_fullscreen_button=False,
            )

        with gr.Column(scale=3):
            # BUGFIX: "stamples" -> "samples" in the user-facing blurb.
            gr.Markdown("""lofi-bytes is a [Launchpad](https://launchpad.studentorg.berkeley.edu/) project (Spring 2023) that generates lofi tracks from input MIDI samples using a MusicTransformer model.
                        <br/><br/>
                        **Model**: [lofi-bytes](https://huggingface.co/Launchpad/lofi-bytes)
                        <br/>
                        **Project Leader**: Alicia Wang
                        <br/>
                        **Members**: Alena Chao, Eric Liu, Zane Mogannam, Chloe Wong, Iris Zhou
                        <br/>
                        **Advisors**: Vincent Lim, Winston Liu
                        <br/>
                        """
                        )
    gr.Interface(
        fn=generate,
        inputs=gr.File(),
        outputs=gr.File(),
        examples=["uploaded_midis/ghibli_castle_in_the_sky.mid", "uploaded_midis/am_i_blue_jazz.mid"]
    )

if __name__ == '__main__':
    demo.launch(share=True, show_error=True)