Spaces:
Runtime error
Runtime error
| import random | |
| import gradio as gr | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| from model.music_transformer import MusicTransformer | |
| from processor import decode_midi, encode_midi | |
| from utilities.constants import TOKEN_END, TOKEN_PAD, TORCH_LABEL_TYPE | |
| from utilities.device import get_device, use_cuda | |
# --- Hugging Face Hub checkpoint location ---
REPO_ID = "Launchpad/lofi-bytes"
FILENAME = "weights_maestro_finetuned.pickle"

# --- Generation / preprocessing settings ---
SEQUENCE_START = 0            # index process_midi() starts from when not sampling randomly
OUTPUT_PATH = "./output_midi"
RPR = True                    # relative position representation variant of MusicTransformer
# TARGET_SEQ_LENGTH = 1023
TARGET_SEQ_LENGTH = 512       # total tokens to generate (primer included)
NUM_PRIME = 65                # number of primer tokens taken from the uploaded MIDI
MAX_SEQUENCE = 2048           # max sequence length the transformer was trained with

# --- Model hyperparameters (must match the checkpoint) ---
N_LAYERS = 6
NUM_HEADS = 8
D_MODEL = 512
DIM_FEEDFORWARD = 1024
BEAM = 0                      # 0 => sampling, >0 => beam search width
FORCE_CPU = False

# --- Upload handling ---
ALLOWED_EXTENSIONS = {'mid'}
UPLOAD_FOLDER = './uploaded_midis'

generated_midi = None

# Enable CUDA if available; project helper — falls back per utilities.device logic.
use_cuda(True)

# Build the transformer and load the fine-tuned MAESTRO weights from the Hub.
# map_location keeps the load valid whether or not a GPU is present.
model = MusicTransformer(
    n_layers=N_LAYERS,
    num_heads=NUM_HEADS,
    d_model=D_MODEL,
    dim_feedforward=DIM_FEEDFORWARD,
    max_sequence=MAX_SEQUENCE,
    rpr=RPR
).to(get_device())
state_dict = torch.load(
    hf_hub_download(repo_id=REPO_ID, filename=FILENAME),
    map_location=get_device()
)
model.load_state_dict(state_dict)
def generate(input_midi):
    """Generate a lofi continuation from an uploaded primer MIDI file.

    Encodes the uploaded MIDI into event tokens, takes the first NUM_PRIME
    tokens as a primer, runs MusicTransformer.generate(), and writes the
    decoded result to "output.mid".

    Args:
        input_midi: path-like handle to the uploaded MIDI file
            (what gr.File passes to the callback).

    Returns:
        str | None: path of the generated MIDI file, or None when the
        upload produced no tokens (Gradio renders an empty output).
    """
    raw_mid = encode_midi(input_midi)
    if not raw_mid:
        # Nothing to prime on — bail out instead of crashing downstream.
        return None

    primer, _ = process_midi(raw_mid, NUM_PRIME, random_seq=False)
    primer = torch.tensor(primer, dtype=TORCH_LABEL_TYPE, device=get_device())

    # Round-trips the primer through the decoder (kept from the original;
    # without file_path this only exercises decode, it saves nothing).
    decode_midi(primer[:NUM_PRIME].cpu().numpy())

    # GENERATION — inference only, so disable autograd entirely.
    model.eval()
    with torch.no_grad():
        # model.generate() returns the full token array for the new track.
        beam_seq = model.generate(primer[:NUM_PRIME], TARGET_SEQ_LENGTH, beam=BEAM)

    file_path = "output.mid"
    # decode_midi() writes a pretty_midi.PrettyMIDI to file_path as a side
    # effect; the written file is what the Gradio output component serves.
    decode_midi(beam_seq[0].cpu().numpy(), file_path=file_path)
    return file_path
def process_midi(raw_mid, max_seq, random_seq):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Takes in pre-processed raw midi and returns the input and target
    sequences. Can use a random window or always read from the start,
    controlled by random_seq.

    Args:
        raw_mid: sequence of encoded MIDI event tokens.
        max_seq: length of the returned input/target sequences.
        random_seq: if True (and enough tokens exist), sample a random window.

    Returns:
        (x, tgt): input tokens and the one-step-shifted target tokens.
        NOTE(review): in the short-input branch these are padded tensors; in
        the long-input branch they are slices of raw_mid (same type as the
        caller's encoding) — callers re-wrap with torch.tensor anyway.
    ----------
    """
    x = torch.full((max_seq, ), TOKEN_PAD, dtype=TORCH_LABEL_TYPE, device=get_device())
    tgt = torch.full((max_seq, ), TOKEN_PAD, dtype=TORCH_LABEL_TYPE, device=get_device())

    raw_len = len(raw_mid)
    full_seq = max_seq + 1  # +1 so the target can be shifted one step ahead

    if raw_len == 0:
        return x, tgt

    if raw_len < full_seq:
        x[:raw_len] = raw_mid
        tgt[:raw_len - 1] = raw_mid[1:]
        # FIX: was tgt[raw_len] = TOKEN_END — an off-by-one that raised
        # IndexError when raw_len == max_seq (allowed, since the branch only
        # requires raw_len < max_seq + 1) and otherwise left tgt[raw_len-1]
        # as PAD. END belongs right after the last shifted token.
        tgt[raw_len - 1] = TOKEN_END
    else:
        # Enough tokens: take a full_seq window, either random or from the start.
        if random_seq:
            end_range = raw_len - full_seq
            start = random.randint(SEQUENCE_START, end_range)
        else:
            start = SEQUENCE_START

        end = start + full_seq
        data = raw_mid[start:end]
        x = data[:max_seq]
        tgt = data[1:full_seq]

    return x, tgt
# Gradio UI: project banner + blurb on top, the generate interface below.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            # Launchpad project logo (display only — all chrome disabled).
            gr.Image(
                "https://www.ocf.berkeley.edu/~launchpad/media/uploads/project_logos/410912267_278779401866686_2517511436172822307_n_0iVwDxI.png",
                elem_id="logo-img",
                show_label=False,
                show_share_button=False,
                show_download_button=False,
                show_fullscreen_button=False,
            )
        with gr.Column(scale=3):
            # FIX: "stamples" -> "samples" (user-facing typo).
            gr.Markdown("""lofi-bytes is a [Launchpad](https://launchpad.studentorg.berkeley.edu/) project (Spring 2023) that generates lofi tracks from input MIDI samples using a MusicTransformer model.
            <br/><br/>
            **Model**: [lofi-bytes](https://huggingface.co/Launchpad/lofi-bytes)
            <br/>
            **Project Leader**: Alicia Wang
            <br/>
            **Members**: Alena Chao, Eric Liu, Zane Mogannam, Chloe Wong, Iris Zhou
            <br/>
            **Advisors**: Vincent Lim, Winston Liu
            <br/>
            """
            )

    # File-in / file-out wrapper around generate(), with two example MIDIs.
    gr.Interface(
        fn=generate,
        inputs=gr.File(),
        outputs=gr.File(),
        examples=["uploaded_midis/ghibli_castle_in_the_sky.mid", "uploaded_midis/am_i_blue_jazz.mid"]
    )

if __name__ == '__main__':
    demo.launch(share=True, show_error=True)