Paul Triana committed
Commit 6229e10 · 1 Parent(s): 61c7027

initial commit

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. .gitignore +87 -0
  2. .gitmodules +6 -0
  3. CMakeLists.txt +87 -0
  4. CMakeSettings.json +28 -0
  5. README.md +176 -0
  6. create_dataset_compute_canada.sh +77 -0
  7. create_python_library.sh +232 -0
  8. data_split.sh +16 -0
  9. include/dataset_creation/compression/lz4.h +764 -0
  10. include/dataset_creation/dataset_manipulation/bytes_to_file.h +27 -0
  11. libraries/midifile +1 -0
  12. libraries/protobuf/CMakeLists.txt +28 -0
  13. libraries/protobuf/CMakeSettings.json +29 -0
  14. libraries/protobuf/include/proto_library.h +5 -0
  15. libraries/protobuf/src/enum.proto +389 -0
  16. libraries/protobuf/src/feature_extraction.proto +92 -0
  17. libraries/protobuf/src/midi.proto +429 -0
  18. libraries/protobuf/src/midi_internal.proto +125 -0
  19. libraries/protobuf/src/track_type.proto +12 -0
  20. libraries/pybind11 +1 -0
  21. libraries/torch/CMakeLists.txt +26 -0
  22. libraries/torch/CMakeSettings.json +29 -0
  23. libraries/torch/include/torch_library.h +5 -0
  24. libraries/torch/src/torch_library.cpp +10 -0
  25. midigpt_setup_helper.sh +182 -0
  26. models/model.zip +3 -0
  27. pip_requirements/common_requirements.txt +1 -0
  28. pip_requirements/create_dataset_requirements.txt +4 -0
  29. pip_requirements/inference_requirements.txt +1 -0
  30. pip_requirements/train_requirements.txt +135 -0
  31. python_scripts/config/bert.json +7 -0
  32. python_scripts/config/bert_tiny.json +7 -0
  33. python_scripts/config/gpt2.json +7 -0
  34. python_scripts/config/gpt2_tiny.json +7 -0
  35. python_scripts/convert.py +204 -0
  36. python_scripts/create_dataset.py +226 -0
  37. python_scripts/custom_models.py +121 -0
  38. python_scripts/data_split.py +54 -0
  39. python_scripts/losses.py +37 -0
  40. python_scripts/train.py +224 -0
  41. python_scripts/train_dataset.py +96 -0
  42. python_scripts/utils.py +41 -0
  43. python_scripts_for_testing/midigpt_gen.mid +0 -0
  44. python_scripts_for_testing/mtest.mid +0 -0
  45. python_scripts_for_testing/pythoninferencetest.py +67 -0
  46. setup.py +67 -0
  47. src/common/data_structures/encoder_config.h +95 -0
  48. src/common/data_structures/token_sequence.h +38 -0
  49. src/common/data_structures/track_type.h +20 -0
  50. src/common/data_structures/train_config.cpp +46 -0
.gitignore ADDED
@@ -0,0 +1,87 @@
+ # .gitignore
+
+ build.sh
+
+ # feature_extraction stuff
+ python_experiment_scripts/**/*.json
+ python_experiment_scripts/**/*.csv
+ python_experiment_scripts/**/*.pdf
+ python_experiment_scripts/**/*.numbers
+ python_experiment_scripts/**/*.txt
+ python_experiment_scripts/examples/**/*.mid
+
+ python_scripts_for_testing/examples/*.mid
+ python_scripts_for_testing/examples/
+
+ # Ignore compilation folders / dependencies
+ #midifile/
+ libtorch/
+ libtorch_cpu/
+ #pybind11/
+ build/
+ python_lib/
+ CMakeFiles/
+ CMakeScripts/
+ CMakeCache.txt
+ cmake_install.cmake
+ Debug/
+
+ # Ignore the libtorch directory
+ /libraries/libtorch
+ /libraries/libtorch_cpu
+ #/libraries/pybind11
+ #/libraries/midifile
+
+ /python_lib/
+ /python_scripts/training_files
+
+ # ignore dataset files
+ *.arr
+ *.arr.header
+
+ # ignore dstore
+ .DS_Store
+
+ # Ignore the model.pt file
+ **/*.pt
+ *.pt
+
+ #ignore visual studio metadata
+ .vs/
+ .vscode
+
+ #ignore build directory
+ /out/
+
+ *.out
+
+ #ignore random files
+ 2539bytes.txt
+ CMakeLists_CPP.txt
+
+ /libraries/protobuf/.vs/
+ /libraries/protobuf/out/
+
+ /libraries/torch/.vs/
+ /libraries/torch/out/
+
+ # exceptions
+ !python_experiment_scripts/onset_density_test.pdf
+ !python_experiment_scripts/onset_density_test.json
+ !python_experiment_scripts/onset_polyphony_test.json
+ !python_experiment_scripts/musicmap_genre_data.json
+
+ !python_experiment_scripts/onset_polyphony_test_mono.json
+ !python_experiment_scripts/onset_polyphony_test_poly.json
+ !python_experiment_scripts/onset_polyphony_v_original_test_mono.json
+ !python_experiment_scripts/onset_polyphony_v_original_test_poly.json
+
+ # pycache folders
+ */__pycache__/*
+
+ #training outputs
+ */checkpoints/*
+ */logs/*
+ */midigpt.cpython-38-x86_64-linux-gnu.so
+
+ notes/*
.gitmodules ADDED
@@ -0,0 +1,6 @@
+ [submodule "libraries/pybind11"]
+ path = libraries/pybind11
+ url = https://github.com/pybind/pybind11.git
+ [submodule "libraries/midifile"]
+ path = libraries/midifile
+ url = https://github.com/craigsapp/midifile.git
CMakeLists.txt ADDED
@@ -0,0 +1,87 @@
+ cmake_minimum_required(VERSION 3.8)
+
+ SET(CMAKE_CXX_STANDARD 20)
+ SET(CMAKE_CXX_STANDARD_REQUIRED ON)
+ SET(CMAKE_POSITION_INDEPENDENT_CODE ON)
+ #we add the following line to fix a linkage issue between torch and midifile
+ #https://stackoverflow.com/questions/68922557/c-linker-error-undefined-reference-when-linking-package-libtorch-and-shared
+ #add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
+
+ project(midigpt)
+
+ option(compute_canada "Build for Compute Canada" OFF)
+ option(mac_os "Build for Mac OS" OFF)
+ option(no_torch "No Torch" OFF)
+ option(no_pybind "No Pybind" OFF)
+ option(trace "Trace" OFF)
+
+ #Find the necessary packages to be able to link the libraries correctly
+ find_package(Protobuf REQUIRED)
+ include_directories(${Protobuf_INCLUDE_DIRS})
+
+ if(no_torch)
+ add_definitions(-DNO_TORCH)
+ endif()
+
+ if(NOT no_torch)
+ if(mac_os)
+ message("USING PYTHON PYTORCH INSTEAD")
+ else()
+ set(CMAKE_PREFIX_PATH "${CMAKE_CURRENT_SOURCE_DIR}/libraries/libtorch/")
+ endif()
+ find_package(Torch REQUIRED)
+
+ # This is necessary to avoid a symbol linkage error https://github.com/pytorch/pytorch/issues/38122
+ # https://github.com/DeepVAC/libdeepvac/blob/master/python/CMakeLists.txt
+ find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib")
+ endif()
+
+ if(compute_canada)
+ include_directories("/cvmfs/soft.computecanada.ca/easybuild/software/2020/avx512/Core/python/3.8.2/include/python3.8")
+ endif()
+
+ #Add the directories of libraries so the project can CMake them too
+ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/libraries/protobuf)
+ if(NOT no_torch)
+ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/libraries/torch)
+ endif()
+ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/libraries/pybind11)
+ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/libraries/midifile)
+
+ #https://stackoverflow.com/questions/8934295/add-source-in-a-subdirectory-to-a-cmake-project/54285898#54285898
+ #https://crascit.com/2016/01/31/enhanced-source-file-handling-with-target_sources/
+
+
+ set(SRCS
+ src/common/data_structures/train_config.cpp
+ src/dataset_creation/compression/lz4.c
+ src/dataset_creation/dataset_manipulation/bytes_to_file.cpp
+ src/common/encoder/encoder_all.h
+ src/lib.cpp
+ )
+ PYBIND11_ADD_MODULE(midigpt ${SRCS})
+
+ #Adding include folders of libraries to our target so we can reference them with #include
+ #Add subdirectory adds those to main project so they can be CMAKEd. Include dirs allows us to reference functions in main.
+ target_include_directories(midigpt PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/libraries/protobuf/include)
+ if (NOT no_torch)
+ target_include_directories(midigpt PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/libraries/torch/include)
+ endif()
+ target_include_directories(midigpt PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/libraries/midifile/include)
+
+ #Linking all the libraries
+ target_link_libraries(midigpt PRIVATE midigpt_proto) #Our protobuf custom library
+ target_link_libraries(midigpt PRIVATE midifile)
+ if (NOT no_torch)
+ target_link_libraries(midigpt PRIVATE midigpt_torch) #Our torch custom library
+ #This is necessary to avoid a symbol linkage error https://github.com/pytorch/pytorch/issues/38122
+ target_link_libraries(midigpt PRIVATE "${TORCH_LIBRARIES}" ${TORCH_PYTHON_LIBRARY})
+ endif()
+
+ if (trace)
+ add_library(tracer STATIC src/trace.cpp)
+ target_link_libraries(midigpt PRIVATE tracer)
+ target_compile_options(midigpt PRIVATE -Wall -Wextra -Wpedantic -finstrument-functions)
+ elseif(NOT WIN32)
+ target_compile_options(midigpt PRIVATE -Wall -Wextra -Wpedantic)
+ endif()
CMakeSettings.json ADDED
@@ -0,0 +1,28 @@
+ {
+ "configurations": [
+ {
+ "name": "x64-Debug",
+ "generator": "Ninja",
+ "configurationType": "Debug",
+ "inheritEnvironments": [ "msvc_x64_x64" ],
+ "buildRoot": "${projectDir}\\out\\build\\${name}",
+ "installRoot": "${projectDir}\\out\\install\\${name}",
+ "cmakeCommandArgs": "",
+ "buildCommandArgs": "",
+ "ctestCommandArgs": ""
+ },
+ {
+ "name": "WSL-GCC-Debug",
+ "generator": "Ninja",
+ "configurationType": "Debug",
+ "buildRoot": "${projectDir}\\out\\build\\${name}",
+ "installRoot": "${projectDir}\\out\\install\\${name}",
+ "cmakeExecutable": "cmake",
+ "cmakeCommandArgs": "",
+ "buildCommandArgs": "",
+ "ctestCommandArgs": "",
+ "inheritEnvironments": [ "linux_x64" ],
+ "wslPath": "${defaultWSLPath}"
+ }
+ ]
+ }
README.md CHANGED
@@ -9,3 +9,179 @@ short_description: MIDI-GPT-inference-docker
  ---
 
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+
+ [![N|Solid](https://drive.google.com/uc?export=view&id=1u4xiWN3s0PAii8zn3-qxJ7wn35tBOypY)](https://metacreation.net/category/projects/)
+
+ # MIDI-GPT Guide
+
+ This is the repository for MIDI-GPT, a generative system based on the Transformer architecture, designed for computer-assisted music composition workflows. This work was presented at the 39th Annual AAAI Conference in Philadelphia, USA, in this [paper](https://arxiv.org/abs/2501.17011).
+
+ # Using MIDI-GPT
+
+ ## Installation
+
+ To install the midigpt Python library, use the script ```midigpt_setup_helper.sh```. You can download this script on its own and run it; it will clone the repository and build the library. Below is an example of its usage:
+
+ ```sh
+ bash midigpt_setup_helper.sh -d midigpt_dir
+ ```
+
+ >**Note:** Python 3.8 is required for the library.
+ >**Note:** If you're building on Mac, use the ```-m``` argument.
+
+ ## Inference
+
+ Once downloaded, MIDI-GPT is ready to use. ```python_scripts_for_testing/pythoninferencetest.py``` is an example of using MIDI-GPT. In summary, three objects need to be created before sampling:
+ - Piece: the input MIDI file loaded into a JSON representation of the piece
+ - Status: a dict indicating the desired sampling process (which tracks, continuation/resampling/infilling, etc.) as well as attribute control values
+ - Param: a dict indicating sampling parameters such as the temperature or the number of generated bars per step
+
+ You must provide an input MIDI file, the checkpoint model file, and optionally an output MIDI file. Our model is provided in the ```models/model.zip``` file.
+
+ Then, using the ```midigpt``` Python API, call the sample function with these objects as arguments. After sampling, the result can be converted and saved to a MIDI file, as sketched below.
+
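+ A minimal sketch of this flow follows. The function and key names here are assumptions made for illustration and may differ from the actual ```midigpt``` bindings; treat ```python_scripts_for_testing/pythoninferencetest.py``` as the authoritative example.
+
+ ```python
+ import json
+ import midigpt  # the library built by create_python_library.sh
+
+ # Piece: JSON representation of the input MIDI file
+ piece = midigpt.midi_to_json("python_scripts_for_testing/mtest.mid")  # assumed helper name
+
+ # Status: which tracks to act on and how; keys are illustrative only
+ status = {"tracks": [{"track_id": 0,
+                       "mode": "infilling",           # or continuation / resampling
+                       "selected_bars": [True] * 4}]}
+
+ # Param: sampling parameters; keys are illustrative only
+ param = {"temperature": 1.0,
+          "bars_per_step": 2,
+          "ckpt": "model.pt"}  # checkpoint extracted from models/model.zip
+
+ # Sample, then convert the result back into a MIDI file (assumed names)
+ result = midigpt.sample(piece, json.dumps(status), json.dumps(param))
+ midigpt.json_to_midi(result, "midigpt_gen.mid")
+ ```
+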
+ # Training MIDI-GPT
+
+ Training was done on Compute Canada computing clusters, so the training scripts are tailored to that platform, but they can easily be adapted to similar ones. Training used the GigaMIDI dataset, first serialized into a compressed file with ```create_dataset_compute_canada.sh``` and ```python_scripts/create_dataset.py```. Training itself was run with ```python_scripts/train.py```. Finally, the model weights file is converted from the training checkpoint using ```convert.py```.
+
+ If you're unfamiliar with Compute Canada, make sure to check the introductory .md [here]().
+
+ ## Installation - Cedar and Niagara
+ 0. You might want to allocate an interactive session with salloc:
+
+ >**Note:** You DON'T need to do this on Niagara.
+
+ ```sh
+ salloc --time=3:0:0 --nodes 1 --cpus-per-task 32 --mem=128000 --account=user
+ ```
+
+ 1. First, clone the MIDI-GPT repository into a folder on your CC machine:
+ ```sh
+ git clone https://github.com/Metacreation-Lab/MIDI-GPT
+ ```
+ 2. Then load the standard environments and some dependencies:
+
+ >**Note:** If you're building on Niagara, load this first:
+ ```sh
+ module load CCEnv arch/avx512
+ ```
+ Then proceed to load the rest (if you're on Cedar, start from here):
+ ```sh
+ module load StdEnv/2020
+ module load cmake/3.23.1
+ module load gcc/11.3.0
+ module load protobuf/3.12.3
+ module load python/3.8.2
+ ```
+ 3. Then create an environment and activate it:
+ ```sh
+ virtualenv --no-download ./ENV # ENV is the name of the environment
+ source ./ENV/bin/activate
+ pip install --no-index --upgrade pip
+
+ # For training only
+ pip install torch==1.13.0+computecanada
+ pip install transformers==4.26.1+computecanada
+ ```
+ 4. Finally, call the bash script with the correct arguments:
+ ```sh
+ bash create_python_library.sh --test_build --compute_canada
+ ```
+ Or, if you only plan to train the model, add the argument that excludes the torch library (which is required only for inference):
+ ```sh
+ bash create_python_library.sh --no_torch --compute_canada
+ ```
+ 5. To test the library imports for training, run the train.py script by importing it:
+ ```sh
+ cd python_scripts
+ python3 -c "import train"
+ ```
+ > **Note:** The helper script ```midigpt_setup_helper.sh``` does all these steps automatically (for training or inference). Download it individually and run it where you wish to clone the repository.
+ > **Note:** If you run the script without the --test_build flag, it will still compile and create the Python library, but it won't test it with the current model in production.
+ > **Note:** The --compute_canada flag is necessary to build the code properly.
+
+ That's it!
+
+ Everything should now be installed in your Python environment. If you log out of and back in to CC, make sure to activate the environment in which you installed the API.
+
+ ## Training
+
+ ### Dataset Building
+
+ In order to train a new model, you must first build a dataset. You can upload the files you need using Globus (check the CC [guide]()).
+
+ > **Note**: Remember that to copy from the shared folder to your own folders you must use absolute paths.
+
+ The data should be organized so that all MIDI files are contained within three folders: ```train```, ```test```, and ```valid```. Further directories can be used to organize the MIDI files, as long as they are within these three directories.
+
+ If your dataset is a single folder containing all the MIDI files, we provide a helper script that automatically splits the dataset 80%-10%-10%. Simply modify ```data_split.sh``` to match your case and run it; a sketch of the underlying split is shown below.
+
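+ For reference, a minimal sketch of such a split (```python_scripts/data_split.py``` is the authoritative version; this only illustrates the 80%-10%-10% idea and assumes unique file names):
+
+ ```python
+ import random
+ import shutil
+ from pathlib import Path
+
+ def split_dataset(root: str, new_root: str, seed: int = 0) -> None:
+     """Copy MIDI files from root into train/test/valid subfolders (80/10/10)."""
+     files = sorted(Path(root).rglob("*.mid"))
+     random.Random(seed).shuffle(files)  # deterministic shuffle before splitting
+     n = len(files)
+     buckets = {"train": files[:int(0.8 * n)],
+                "test": files[int(0.8 * n):int(0.9 * n)],
+                "valid": files[int(0.9 * n):]}
+     for name, subset in buckets.items():
+         out = Path(new_root) / name
+         out.mkdir(parents=True, exist_ok=True)
+         for f in subset:
+             shutil.copy2(f, out / f.name)
+ ```
+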
+ Once you have the folder with the data, run the following command
+ ```sh
+ sh create_dataset_compute_canada.sh --root_dir=<root_dir> --encoding=<encoding> --data_dir=<data_dir> --output=<output>
+ ```
+ where:
+ - ```<root_dir>``` is the root folder where the midigpt repository folder is located
+ - ```<encoding>``` is the encoder to use. We suggest using ```EXPRESSIVE_ENCODER```
+ - ```<data_dir>``` is the dataset folder containing the three ```train```, ```test```, and ```valid``` folders
+ - ```<output>``` is the location of the output ```.arr``` file. The resulting file will be ```<output>_NUM_BARS=<num_bars>_RESOLUTION_<resolution>.arr```
+ >**Note:** If you are on Compute Canada, we suggest you run these commands through an sbatch job, as they can take some time.
+
+ ### Training a Model
+
+ To train a model, run the train.py file. Different lab members have set the paths differently; what works for me is to use absolute paths. An example would be:
+ ```sh
+ python train.py --arch gpt2 --config /home/user/scratch/TRAINING-master/config/gpt2_tiny.json --encoding EXPRESSIVE_ENCODER --ngpu 4 --dataset /home/user/scratch/test_NUM_BARS=4_OPZ_False.arr --batch_size 32 --label DELETE_ME
+ ```
+
+ ### Running Jobs
+
+ To read the CC documentation, click [here](https://docs.alliancecan.ca/wiki/Running_jobs). You can run small snippets of code to test things out without allocating any resources. However, to train a model or perform any time- or resource-consuming task, you must schedule a job. A list of different types of job scheduling will be added here.
+
+ #### Interactive Jobs
+ You can start an interactive session on a compute node with salloc.
+ ```sh
+ salloc --time=3:0:0 --nodes 1 --cpus-per-task 32 --mem=128000 --account=user
+ ```
+
+ #### Scheduled jobs (use this for training)
+ For time-expensive tasks it is better to create a bash file and submit a job with sbatch:
+ ```sh
+ sbatch simple_job.sh
+ ```
+
+ Here is an example of the contents of a bash file to submit a midigpt training job:
+ ```sh
+ #!/bin/bash
+ #SBATCH --gres=gpu:v100l:4
+ #SBATCH --cpus-per-task=32
+ #SBATCH --exclusive
+ #SBATCH --mem=0
+ #SBATCH --time=2-23:00
+ #SBATCH --account=user
+ #SBATCH --mail-user [email protected]   # <-- make sure to put your email
+ #SBATCH --mail-type ALL
+ #SBATCH --output=CCLOG/FILENAME.out   # <-- make sure to change the file name
+
+ source $SCRATCH/PY_3610/bin/activate   # <-- the env where you have the midigpt API installed
+ cd $SCRATCH/MMM_TRAINING-master
+ module load StdEnv/2020 protobuf python/3.6.10
+ source $SCRATCH/PY_3610/bin/activate   # <-- same here, make sure the directory is correct
+ python train.py --arch reformer --config /home/user/scratch/MMM_TRAINING-master/config/reformer.json --encoding EXPRESSIVE_ENCODER --ngpu 4 --dataset /home/user/scratch/dataset_NUM_BARS=4.arr --batch_size 32 --label DELETE_ME
+ ```
+
+ In this case we are using 4 v100l GPUs (**gres** argument) and we're asking for 2 days and 23 hours to run the job (**time** argument).
+
+ #### Check jobs and eliminate sessions
+ To show all the users:
+ ```sh
+ who -u
+ ```
+
+ To kill all your sessions:
+ ```sh
+ pkill -u username
+ ```
+
+
create_dataset_compute_canada.sh ADDED
@@ -0,0 +1,77 @@
+ #!/bin/bash
+ #SBATCH --cpus-per-task=32
+ #SBATCH --time=10:00:00
+
+ root_dir="" # Model directory to replace MODEL_NAME
+ data_dir="" # Data directory to replace DATA_DIR
+ encoding="" # Encoding to replace ENCODING
+ output="" # Output to replace OUTPUT
+ metadata=""
+ test="no"
+ zip="no"
+ res="12"
+ max="-1"
+
+ # Parse arguments
+ for arg in "$@"
+ do
+ case $arg in
+ --root_dir=*)
+ root_dir="${arg#*=}"
+ shift # Remove --root_dir= from processing
+ ;;
+ --metadata=*)
+ metadata="${arg#*=}"
+ shift # Remove --metadata= from processing
+ ;;
+ --res=*)
+ res="${arg#*=}"
+ shift # Remove --res= from processing
+ ;;
+ --test=*)
+ test="${arg#*=}"
+ shift # Remove --test= from processing
+ ;;
+ --data_dir=*)
+ data_dir="${arg#*=}"
+ shift # Remove --data_dir= from processing
+ ;;
+ --encoding=*)
+ encoding="${arg#*=}"
+ shift # Remove --encoding= from processing
+ ;;
+ --output=*)
+ output="${arg#*=}"
+ shift # Remove --output= from processing
+ ;;
+ --zip=*)
+ zip="${arg#*=}"
+ shift # Remove --zip= from processing
+ ;;
+ --max=*)
+ max="${arg#*=}"
+ shift # Remove --max= from processing
+ ;;
+ esac
+ done
+
+ module load CCEnv arch/avx512
+ module load StdEnv/2020
+ module load cmake/3.23.1
+ module load gcc/11.3.0
+ module load protobuf/3.12.3
+ module load python/3.8.2
+
+ mkdir -p $root_dir/CCLOG
+ source $root_dir/venv/bin/activate
+
+ cp $root_dir/MIDI-GPT/python_lib/midigpt.cpython-38-x86_64-linux-gnu.so $root_dir/MIDI-GPT/python_scripts
+
+ python3 $root_dir/MIDI-GPT/python_scripts/create_dataset.py --nthreads 40 --max_size $max --data_dir $data_dir --encoding $encoding --output $output --metadata $metadata --test $test --expressive --resolution $res
+
+ if [[ "$zip" == "yes" ]]
+ then
+ cd $output
+ cd ../
+ zip -r EXPRESSIVE_GIGAMIDI_24_1920.zip $output
+ fi
create_python_library.sh ADDED
@@ -0,0 +1,232 @@
+ #!/bin/bash
+
+ compute_canada=false
+ mac_os=false
+ cpu=false
+ no_torch=false
+ niagara=false
+ env_name="venv"
+ parent_dir="midigpt_workspace" # default parent directory name
+ cuda=false
+ trace=false
+
+
+ # Parse arguments
+ for ((i=1;i<=$#;i++)); do
+ case ${!i} in
+ --trace)
+ trace=true
+ ;;
+ --compute_canada)
+ compute_canada=true
+ ;;
+ --cpu)
+ cpu=true
+ ;;
+ --mac_os)
+ mac_os=true
+ ;;
+ --no_torch)
+ no_torch=true
+ ;;
+ --niagara)
+ niagara=true
+ ;;
+ --env_name)
+ i=$((i+1))
+ env_name=${!i}
+ ;;
+ -n=*|--name=*) # new parent directory name option
+ parent_dir="${!i#*=}"
+ shift
+ ;;
+ --cuda)
+ if $no_torch; then
+ echo "Cannot use --cuda and --no_torch at the same time."
+ exit 1
+ fi
+ cuda=true
+ ;;
+ *)
+ echo "Unknown option ${!i}"
+ exit 1
+ ;;
+ esac
+ done
+
+
+ # Get the current directory name
+ dir_name=$(basename `pwd`)
+
+ # Get the parent directory name
+ parent_dir_name=$(basename $(dirname `pwd`))
+
+ if [ "$parent_dir_name" != "$parent_dir" ]; then
+
+ # Go to the parent directory
+ cd ..
+
+ # Create the new parent directory
+ mkdir $parent_dir
+
+ # Move the old directory into the new parent directory
+ mv $dir_name $parent_dir/
+
+ # Change to the old directory, which is now inside the parent directory
+ cd $parent_dir/$dir_name
+ fi
+
+ # Load modules if we are on compute_canada and niagara
+ if $compute_canada; then
+ if $niagara; then
+ module load CCEnv arch/avx512
+ fi
+ module load StdEnv/2020
+ module load cmake/3.23.1
+ module load gcc/11.3.0
+ module load protobuf/3.12.3
+ module load python/3.8.2
+
+ mkdir ../CCLOG
+ fi
+
+ # Environment creation
+ if [[ -n ../$env_name ]]; then
+ if [[ -d ../$env_name ]]; then
+ echo "Environment $env_name already exists, activating it..."
+ else
+ echo "Environment $env_name does not exist, creating it..."
+ if $compute_canada; then
+ virtualenv ../$env_name
+ else
+ python3 -m venv ../$env_name
+ fi
+ fi
+ fi
+
+ source ../$env_name/bin/activate
+
+ # Install requirements
+ pip install -r pip_requirements/common_requirements.txt
+
+ if $compute_canada; then
+ pip install -r pip_requirements/create_dataset_requirements.txt
+ fi
+ if $mac_os; then
+ pip install -r pip_requirements/inference_requirements.txt
+ fi
+
+ if $compute_canada && ! $niagara; then # and if no torch
+ pip install -r pip_requirements/train_requirements.txt
+ fi
+
+ #deactivate
+
+ # Set CMake flags based on command line arguments
+ cmake_flags=""
+ if $compute_canada; then
+ cmake_flags="$cmake_flags -Dcompute_canada=ON"
+ fi
+
+ if $no_torch; then
+ cmake_flags="$cmake_flags -Dno_torch=ON"
+ fi
+
+ if $trace; then
+ cmake_flags="$cmake_flags -Dtrace=ON"
+ fi
+
+ # Code to check if libtorch and pybind11 are already downloaded
+ if ! $no_torch; then
+ libtorch_path="libraries/libtorch"
+ libtorch_url="https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.0.0%2Bcpu.zip"
+ if $cuda; then
+ libtorch_url="https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.0.0%2Bcu118.zip"
+ fi
+ fi
+
+ pybind11_path="libraries/pybind11"
+ midifile_path="libraries/midifile"
+
+
+ pybind11_url="https://github.com/pybind/pybind11.git"
+ midifile_url="https://github.com/craigsapp/midifile"
+
+ if ! $no_torch; then
+ if $mac_os; then
+ libtorch_url="https://download.pytorch.org/libtorch/cpu/libtorch-macos-2.0.1.zip"
+ fi
+
+ if $cpu; then
+ libtorch_url="https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.0.0%2Bcpu.zip"
+ fi
+
+ # Check if libtorch folder exists and is not empty
+ if [ ! -d "$libtorch_path" ] || [ -z "$(ls -A "$libtorch_path")" ]; then
+ echo "libtorch folder does not exist or is empty. Downloading and extracting..."
+ mkdir -p "$libtorch_path"
+ curl -L "$libtorch_url" -o libtorch.zip
+ unzip -q libtorch.zip -d libraries/
+ rm libtorch.zip
+ echo "libtorch downloaded and extracted."
+ else
+ echo "libtorch folder exists and is not empty. No need to download."
+ fi
+ fi
+
+ # Check if pybind11 folder exists and is not empty
+ if [ ! -d "$pybind11_path" ] || [ -z "$(ls -A "$pybind11_path")" ]; then
+ echo "pybind11 folder does not exist or is empty. Cloning the repository..."
+ mkdir -p libraries
+ git clone "$pybind11_url" "$pybind11_path"
+ echo "pybind11 downloaded."
+ cd libraries/pybind11
+ git reset --hard 5ccb9e4
+ cd ../../
+ echo "pybind11 reset to working build"
+ else
+ echo "pybind11 folder exists and is not empty. No need to download."
+ fi
+
+ # Check if midifile folder exists and is not empty
+ if [ ! -d "$midifile_path" ] || [ -z "$(ls -A "$midifile_path")" ]; then
+ echo "midifile folder does not exist or is empty. Cloning the repository..."
+ mkdir -p libraries
+ git clone "$midifile_url" "$midifile_path"
+ echo "midifile downloaded."
+ cd libraries/midifile
+ git reset --hard 838c62c
+ cd ../../
+ echo "midifile reset to working build"
+ else
+ echo "midifile folder exists and is not empty. No need to download."
+ fi
+
+ # Middle section of the script to build the python library
+ rm -rf ./python_lib
+ mkdir ./python_lib
+ rm -rf ./libraries/protobuf/build
+ mkdir ./libraries/protobuf/build
+
+ cd ./libraries/protobuf/src
+ protoc --cpp_out ../build *.proto
+ cd ../../..
+
+ cd ./python_lib
+
+ if $mac_os; then
+ cmake $cmake_flags .. -Dmac_os=ON -DCMAKE_PREFIX_PATH=$(python3 -c 'import torch;print(torch.utils.cmake_prefix_path)')
+ else
+ cmake $cmake_flags ..
+ fi
+ make
+ python3 -c "import midigpt; print('midigpt python library built successfully')"
+
+ cd ..
+ if $compute_canada; then
+ dos2unix create_dataset_compute_canada.sh
+ dos2unix train_dataset.sh
+ fi
+ cd ./python_lib
+
+ cd ..
data_split.sh ADDED
@@ -0,0 +1,16 @@
+ #!/bin/bash
+
+ root=$SCRATCH/datasets/GigaMIDI_Cleaned/Cleaned_Ver_EP_Class-GigaMIDI/Cleaned_GigaMIDI
+ new_root=$SCRATCH/datasets/GigaMIDI_Cleaned/Cleaned_Ver_EP_Class-GigaMIDI/Cleaned_GigaMIDI_Split
+ parent=$SCRATCH/workspace_train/parent_dir
+
+ mkdir -p $new_root
+ cd $new_root
+ mkdir -p train
+ mkdir -p test
+ mkdir -p valid
+
+ cd $SCRATCH
+
+ source $parent/venv/bin/activate
+ python $parent/MIDI-GPT/python_scripts/data_split.py $root $new_root
include/dataset_creation/compression/lz4.h ADDED
@@ -0,0 +1,764 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * LZ4 - Fast LZ compression algorithm
3
+ * Header File
4
+ * Copyright (C) 2011-present, Yann Collet.
5
+
6
+ BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php)
7
+
8
+ Redistribution and use in source and binary forms, with or without
9
+ modification, are permitted provided that the following conditions are
10
+ met:
11
+
12
+ * Redistributions of source code must retain the above copyright
13
+ notice, this list of conditions and the following disclaimer.
14
+ * Redistributions in binary form must reproduce the above
15
+ copyright notice, this list of conditions and the following disclaimer
16
+ in the documentation and/or other materials provided with the
17
+ distribution.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23
+ OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+
31
+ You can contact the author at :
32
+ - LZ4 homepage : http://www.lz4.org
33
+ - LZ4 source repository : https://github.com/lz4/lz4
34
+ */
35
+ #if defined (__cplusplus)
36
+ extern "C" {
37
+ #endif
38
+
39
+ #ifndef LZ4_H_2983827168210
40
+ #define LZ4_H_2983827168210
41
+
42
+ /* --- Dependency --- */
43
+ #include <stddef.h> /* size_t */
44
+
45
+
46
+ /**
47
+ Introduction
48
+
49
+ LZ4 is lossless compression algorithm, providing compression speed >500 MB/s per core,
50
+ scalable with multi-cores CPU. It features an extremely fast decoder, with speed in
51
+ multiple GB/s per core, typically reaching RAM speed limits on multi-core systems.
52
+
53
+ The LZ4 compression library provides in-memory compression and decompression functions.
54
+ It gives full buffer control to user.
55
+ Compression can be done in:
56
+ - a single step (described as Simple Functions)
57
+ - a single step, reusing a context (described in Advanced Functions)
58
+ - unbounded multiple steps (described as Streaming compression)
59
+
60
+ lz4.h generates and decodes LZ4-compressed blocks (doc/lz4_Block_format.md).
61
+ Decompressing such a compressed block requires additional metadata.
62
+ Exact metadata depends on exact decompression function.
63
+ For the typical case of LZ4_decompress_safe(),
64
+ metadata includes block's compressed size, and maximum bound of decompressed size.
65
+ Each application is free to encode and pass such metadata in whichever way it wants.
66
+
67
+ lz4.h only handle blocks, it can not generate Frames.
68
+
69
+ Blocks are different from Frames (doc/lz4_Frame_format.md).
70
+ Frames bundle both blocks and metadata in a specified manner.
71
+ Embedding metadata is required for compressed data to be self-contained and portable.
72
+ Frame format is delivered through a companion API, declared in lz4frame.h.
73
+ The `lz4` CLI can only manage frames.
74
+ */
75
+
76
+ /*^***************************************************************
77
+ * Export parameters
78
+ *****************************************************************/
79
+ /*
80
+ * LZ4_DLL_EXPORT :
81
+ * Enable exporting of functions when building a Windows DLL
82
+ * LZ4LIB_VISIBILITY :
83
+ * Control library symbols visibility.
84
+ */
85
+ #ifndef LZ4LIB_VISIBILITY
86
+ # if defined(__GNUC__) && (__GNUC__ >= 4)
87
+ # define LZ4LIB_VISIBILITY __attribute__ ((visibility ("default")))
88
+ # else
89
+ # define LZ4LIB_VISIBILITY
90
+ # endif
91
+ #endif
92
+ #if defined(LZ4_DLL_EXPORT) && (LZ4_DLL_EXPORT==1)
93
+ # define LZ4LIB_API __declspec(dllexport) LZ4LIB_VISIBILITY
94
+ #elif defined(LZ4_DLL_IMPORT) && (LZ4_DLL_IMPORT==1)
95
+ # define LZ4LIB_API __declspec(dllimport) LZ4LIB_VISIBILITY /* It isn't required but allows to generate better code, saving a function pointer load from the IAT and an indirect jump.*/
96
+ #else
97
+ # define LZ4LIB_API LZ4LIB_VISIBILITY
98
+ #endif
99
+
100
+ /*------ Version ------*/
101
+ #define LZ4_VERSION_MAJOR 1 /* for breaking interface changes */
102
+ #define LZ4_VERSION_MINOR 9 /* for new (non-breaking) interface capabilities */
103
+ #define LZ4_VERSION_RELEASE 2 /* for tweaks, bug-fixes, or development */
104
+
105
+ #define LZ4_VERSION_NUMBER (LZ4_VERSION_MAJOR *100*100 + LZ4_VERSION_MINOR *100 + LZ4_VERSION_RELEASE)
106
+
107
+ #define LZ4_LIB_VERSION LZ4_VERSION_MAJOR.LZ4_VERSION_MINOR.LZ4_VERSION_RELEASE
108
+ #define LZ4_QUOTE(str) #str
109
+ #define LZ4_EXPAND_AND_QUOTE(str) LZ4_QUOTE(str)
110
+ #define LZ4_VERSION_STRING LZ4_EXPAND_AND_QUOTE(LZ4_LIB_VERSION)
111
+
112
+ LZ4LIB_API int LZ4_versionNumber (void); /**< library version number; useful to check dll version */
113
+ LZ4LIB_API const char* LZ4_versionString (void); /**< library version std::string; useful to check dll version */
114
+
115
+
116
+ /*-************************************
117
+ * Tuning parameter
118
+ **************************************/
119
+ /*!
120
+ * LZ4_MEMORY_USAGE :
121
+ * Memory usage formula : N->2^N Bytes (examples : 10 -> 1KB; 12 -> 4KB ; 16 -> 64KB; 20 -> 1MB; etc.)
122
+ * Increasing memory usage improves compression ratio.
123
+ * Reduced memory usage may improve speed, thanks to better cache locality.
124
+ * Default value is 14, for 16KB, which nicely fits into Intel x86 L1 cache
125
+ */
126
+ #ifndef LZ4_MEMORY_USAGE
127
+ # define LZ4_MEMORY_USAGE 14
128
+ #endif
129
+
130
+
131
+ /*-************************************
132
+ * Simple Functions
133
+ **************************************/
134
+ /*! LZ4_compress_default() :
135
+ * Compresses 'srcSize' bytes from buffer 'src'
136
+ * into already allocated 'dst' buffer of size 'dstCapacity'.
137
+ * Compression is guaranteed to succeed if 'dstCapacity' >= LZ4_compressBound(srcSize).
138
+ * It also runs faster, so it's a recommended setting.
139
+ * If the function cannot compress 'src' into a more limited 'dst' budget,
140
+ * compression stops *immediately*, and the function result is zero.
141
+ * In which case, 'dst' content is undefined (invalid).
142
+ * srcSize : max supported value is LZ4_MAX_INPUT_SIZE.
143
+ * dstCapacity : size of buffer 'dst' (which must be already allocated)
144
+ * @return : the number of bytes written into buffer 'dst' (necessarily <= dstCapacity)
145
+ * or 0 if compression fails
146
+ * Note : This function is protected against buffer overflow scenarios (never writes outside 'dst' buffer, nor read outside 'source' buffer).
147
+ */
148
+ LZ4LIB_API int LZ4_compress_default(const char* src, char* dst, int srcSize, int dstCapacity);
149
+
150
+ /*! LZ4_decompress_safe() :
151
+ * compressedSize : is the exact complete size of the compressed block.
152
+ * dstCapacity : is the size of destination buffer (which must be already allocated), presumed an upper bound of decompressed size.
153
+ * @return : the number of bytes decompressed into destination buffer (necessarily <= dstCapacity)
154
+ * If destination buffer is not large enough, decoding will stop and output an error code (negative value).
155
+ * If the source stream is detected malformed, the function will stop decoding and return a negative result.
156
+ * Note 1 : This function is protected against malicious data packets :
157
+ * it will never writes outside 'dst' buffer, nor read outside 'source' buffer,
158
+ * even if the compressed block is maliciously modified to order the decoder to do these actions.
159
+ * In such case, the decoder stops immediately, and considers the compressed block malformed.
160
+ * Note 2 : compressedSize and dstCapacity must be provided to the function, the compressed block does not contain them.
161
+ * The implementation is free to send / store / derive this information in whichever way is most beneficial.
162
+ * If there is a need for a different format which bundles together both compressed data and its metadata, consider looking at lz4frame.h instead.
163
+ */
164
+ LZ4LIB_API int LZ4_decompress_safe (const char* src, char* dst, int compressedSize, int dstCapacity);
165
+
166
+
167
+ /*-************************************
168
+ * Advanced Functions
169
+ **************************************/
170
+ #define LZ4_MAX_INPUT_SIZE 0x7E000000 /* 2 113 929 216 bytes */
171
+ #define LZ4_COMPRESSBOUND(isize) ((unsigned)(isize) > (unsigned)LZ4_MAX_INPUT_SIZE ? 0 : (isize) + ((isize)/255) + 16)
172
+
173
+ /*! LZ4_compressBound() :
174
+ Provides the maximum size that LZ4 compression may output in a "worst case" scenario (input data not compressible)
175
+ This function is primarily useful for memory allocation purposes (destination buffer size).
176
+ Macro LZ4_COMPRESSBOUND() is also provided for compilation-time evaluation (stack memory allocation for example).
177
+ Note that LZ4_compress_default() compresses faster when dstCapacity is >= LZ4_compressBound(srcSize)
178
+ inputSize : max supported value is LZ4_MAX_INPUT_SIZE
179
+ return : maximum output size in a "worst case" scenario
180
+ or 0, if input size is incorrect (too large or negative)
181
+ */
182
+ LZ4LIB_API int LZ4_compressBound(int inputSize);
183
+
184
+ /*! LZ4_compress_fast() :
185
+ Same as LZ4_compress_default(), but allows selection of "acceleration" factor.
186
+ The larger the acceleration value, the faster the algorithm, but also the lesser the compression.
187
+ It's a trade-off. It can be fine tuned, with each successive value providing roughly +~3% to speed.
188
+ An acceleration value of "1" is the same as regular LZ4_compress_default()
189
+ Values <= 0 will be replaced by ACCELERATION_DEFAULT (currently == 1, see lz4.c).
190
+ */
191
+ LZ4LIB_API int LZ4_compress_fast (const char* src, char* dst, int srcSize, int dstCapacity, int acceleration);
192
+
193
+
194
+ /*! LZ4_compress_fast_extState() :
195
+ * Same as LZ4_compress_fast(), using an externally allocated memory space for its state.
196
+ * Use LZ4_sizeofState() to know how much memory must be allocated,
197
+ * and allocate it on 8-bytes boundaries (using `malloc()` typically).
198
+ * Then, provide this buffer as `void* state` to compression function.
199
+ */
200
+ LZ4LIB_API int LZ4_sizeofState(void);
201
+ LZ4LIB_API int LZ4_compress_fast_extState (void* state, const char* src, char* dst, int srcSize, int dstCapacity, int acceleration);
202
+
203
+
204
+ /*! LZ4_compress_destSize() :
205
+ * Reverse the logic : compresses as much data as possible from 'src' buffer
206
+ * into already allocated buffer 'dst', of size >= 'targetDestSize'.
207
+ * This function either compresses the entire 'src' content into 'dst' if it's large enough,
208
+ * or fill 'dst' buffer completely with as much data as possible from 'src'.
209
+ * note: acceleration parameter is fixed to "default".
210
+ *
211
+ * *srcSizePtr : will be modified to indicate how many bytes where read from 'src' to fill 'dst'.
212
+ * New value is necessarily <= input value.
213
+ * @return : Nb bytes written into 'dst' (necessarily <= targetDestSize)
214
+ * or 0 if compression fails.
215
+ */
216
+ LZ4LIB_API int LZ4_compress_destSize (const char* src, char* dst, int* srcSizePtr, int targetDstSize);
217
+
218
+
219
+ /*! LZ4_decompress_safe_partial() :
220
+ * Decompress an LZ4 compressed block, of size 'srcSize' at position 'src',
221
+ * into destination buffer 'dst' of size 'dstCapacity'.
222
+ * Up to 'targetOutputSize' bytes will be decoded.
223
+ * The function stops decoding on reaching this objective,
224
+ * which can boost performance when only the beginning of a block is required.
225
+ *
226
+ * @return : the number of bytes decoded in `dst` (necessarily <= dstCapacity)
227
+ * If source stream is detected malformed, function returns a negative result.
228
+ *
229
+ * Note : @return can be < targetOutputSize, if compressed block contains less data.
230
+ *
231
+ * Note 2 : this function features 2 parameters, targetOutputSize and dstCapacity,
232
+ * and expects targetOutputSize <= dstCapacity.
233
+ * It effectively stops decoding on reaching targetOutputSize,
234
+ * so dstCapacity is kind of redundant.
235
+ * This is because in a previous version of this function,
236
+ * decoding operation would not "break" a sequence in the middle.
237
+ * As a consequence, there was no guarantee that decoding would stop at exactly targetOutputSize,
238
+ * it could write more bytes, though only up to dstCapacity.
239
+ * Some "margin" used to be required for this operation to work properly.
240
+ * This is no longer necessary.
241
+ * The function nonetheless keeps its signature, in an effort to not break API.
242
+ */
243
+ LZ4LIB_API int LZ4_decompress_safe_partial (const char* src, char* dst, int srcSize, int targetOutputSize, int dstCapacity);
244
+
245
+
246
+ /*-*********************************************
247
+ * Streaming Compression Functions
248
+ ***********************************************/
249
+ typedef union LZ4_stream_u LZ4_stream_t; /* incomplete type (defined later) */
250
+
251
+ LZ4LIB_API LZ4_stream_t* LZ4_createStream(void);
252
+ LZ4LIB_API int LZ4_freeStream (LZ4_stream_t* streamPtr);
253
+
254
+ /*! LZ4_resetStream_fast() : v1.9.0+
255
+ * Use this to prepare an LZ4_stream_t for a new chain of dependent blocks
256
+ * (e.g., LZ4_compress_fast_continue()).
257
+ *
258
+ * An LZ4_stream_t must be initialized once before usage.
259
+ * This is automatically done when created by LZ4_createStream().
260
+ * However, should the LZ4_stream_t be simply declared on stack (for example),
261
+ * it's necessary to initialize it first, using LZ4_initStream().
262
+ *
263
+ * After init, start any new stream with LZ4_resetStream_fast().
264
+ * A same LZ4_stream_t can be re-used multiple times consecutively
265
+ * and compress multiple streams,
266
+ * provided that it starts each new stream with LZ4_resetStream_fast().
267
+ *
268
+ * LZ4_resetStream_fast() is much faster than LZ4_initStream(),
269
+ * but is not compatible with memory regions containing garbage data.
270
+ *
271
+ * Note: it's only useful to call LZ4_resetStream_fast()
272
+ * in the context of streaming compression.
273
+ * The *extState* functions perform their own resets.
274
+ * Invoking LZ4_resetStream_fast() before is redundant, and even counterproductive.
275
+ */
276
+ LZ4LIB_API void LZ4_resetStream_fast (LZ4_stream_t* streamPtr);
277
+
278
+ /*! LZ4_loadDict() :
279
+ * Use this function to reference a static dictionary into LZ4_stream_t.
280
+ * The dictionary must remain available during compression.
281
+ * LZ4_loadDict() triggers a reset, so any previous data will be forgotten.
282
+ * The same dictionary will have to be loaded on decompression side for successful decoding.
283
+ * Dictionary are useful for better compression of small data (KB range).
284
+ * While LZ4 accept any input as dictionary,
285
+ * results are generally better when using Zstandard's Dictionary Builder.
286
+ * Loading a size of 0 is allowed, and is the same as reset.
287
+ * @return : loaded dictionary size, in bytes (necessarily <= 64 KB)
288
+ */
289
+ LZ4LIB_API int LZ4_loadDict (LZ4_stream_t* streamPtr, const char* dictionary, int dictSize);
290
+
291
+ /*! LZ4_compress_fast_continue() :
292
+ * Compress 'src' content using data from previously compressed blocks, for better compression ratio.
293
+ * 'dst' buffer must be already allocated.
294
+ * If dstCapacity >= LZ4_compressBound(srcSize), compression is guaranteed to succeed, and runs faster.
295
+ *
296
+ * @return : size of compressed block
297
+ * or 0 if there is an error (typically, cannot fit into 'dst').
298
+ *
299
+ * Note 1 : Each invocation to LZ4_compress_fast_continue() generates a new block.
300
+ * Each block has precise boundaries.
301
+ * Each block must be decompressed separately, calling LZ4_decompress_*() with relevant metadata.
302
+ * It's not possible to append blocks together and expect a single invocation of LZ4_decompress_*() to decompress them together.
303
+ *
304
+ * Note 2 : The previous 64KB of source data is __assumed__ to remain present, unmodified, at same address in memory !
305
+ *
306
+ * Note 3 : When input is structured as a double-buffer, each buffer can have any size, including < 64 KB.
307
+ * Make sure that buffers are separated, by at least one byte.
308
+ * This construction ensures that each block only depends on previous block.
309
+ *
310
+ * Note 4 : If input buffer is a ring-buffer, it can have any size, including < 64 KB.
311
+ *
312
+ * Note 5 : After an error, the stream status is undefined (invalid), it can only be reset or freed.
313
+ */
314
+ LZ4LIB_API int LZ4_compress_fast_continue (LZ4_stream_t* streamPtr, const char* src, char* dst, int srcSize, int dstCapacity, int acceleration);
315
+
316
+ /*! LZ4_saveDict() :
317
+ * If last 64KB data cannot be guaranteed to remain available at its current memory location,
318
+ * save it into a safer place (char* safeBuffer).
319
+ * This is schematically equivalent to a memcpy() followed by LZ4_loadDict(),
320
+ * but is much faster, because LZ4_saveDict() doesn't need to rebuild tables.
321
+ * @return : saved dictionary size in bytes (necessarily <= maxDictSize), or 0 if error.
322
+ */
323
+ LZ4LIB_API int LZ4_saveDict (LZ4_stream_t* streamPtr, char* safeBuffer, int maxDictSize);
324
+
325
+
326
+ /*-**********************************************
327
+ * Streaming Decompression Functions
328
+ * Bufferless synchronous API
329
+ ************************************************/
330
+ typedef union LZ4_streamDecode_u LZ4_streamDecode_t; /* tracking context */
331
+
332
+ /*! LZ4_createStreamDecode() and LZ4_freeStreamDecode() :
333
+ * creation / destruction of streaming decompression tracking context.
334
+ * A tracking context can be re-used multiple times.
335
+ */
336
+ LZ4LIB_API LZ4_streamDecode_t* LZ4_createStreamDecode(void);
337
+ LZ4LIB_API int LZ4_freeStreamDecode (LZ4_streamDecode_t* LZ4_stream);
338
+
339
+ /*! LZ4_setStreamDecode() :
340
+ * An LZ4_streamDecode_t context can be allocated once and re-used multiple times.
341
+ * Use this function to start decompression of a new stream of blocks.
342
+ * A dictionary can optionally be set. Use NULL or size 0 for a reset order.
343
+ * Dictionary is presumed stable : it must remain accessible and unmodified during next decompression.
344
+ * @return : 1 if OK, 0 if error
345
+ */
346
+ LZ4LIB_API int LZ4_setStreamDecode (LZ4_streamDecode_t* LZ4_streamDecode, const char* dictionary, int dictSize);
347
+
348
+ /*! LZ4_decoderRingBufferSize() : v1.8.2+
349
+ * Note : in a ring buffer scenario (optional),
350
+ * blocks are presumed decompressed next to each other
351
+ * up to the moment there is not enough remaining space for next block (remainingSize < maxBlockSize),
352
+ * at which stage it resumes from beginning of ring buffer.
353
+ * When setting such a ring buffer for streaming decompression,
354
+ * provides the minimum size of this ring buffer
355
+ * to be compatible with any source respecting maxBlockSize condition.
356
+ * @return : minimum ring buffer size,
357
+ * or 0 if there is an error (invalid maxBlockSize).
358
+ */
359
+ LZ4LIB_API int LZ4_decoderRingBufferSize(int maxBlockSize);
360
+ #define LZ4_DECODER_RING_BUFFER_SIZE(maxBlockSize) (65536 + 14 + (maxBlockSize)) /* for static allocation; maxBlockSize presumed valid */
361
+
362
+ /*! LZ4_decompress_*_continue() :
363
+ * These decoding functions allow decompression of consecutive blocks in "streaming" mode.
364
+ * A block is an unsplittable entity, it must be presented entirely to a decompression function.
365
+ * Decompression functions only accepts one block at a time.
366
+ * The last 64KB of previously decoded data *must* remain available and unmodified at the memory position where they were decoded.
367
+ * If less than 64KB of data has been decoded, all the data must be present.
368
+ *
369
+ * Special : if decompression side sets a ring buffer, it must respect one of the following conditions :
370
+ * - Decompression buffer size is _at least_ LZ4_decoderRingBufferSize(maxBlockSize).
371
+ * maxBlockSize is the maximum size of any single block. It can have any value > 16 bytes.
372
+ * In which case, encoding and decoding buffers do not need to be synchronized.
373
+ * Actually, data can be produced by any source compliant with LZ4 format specification, and respecting maxBlockSize.
374
+ * - Synchronized mode :
375
+ * Decompression buffer size is _exactly_ the same as compression buffer size,
376
+ * and follows exactly same update rule (block boundaries at same positions),
377
+ * and decoding function is provided with exact decompressed size of each block (exception for last block of the stream),
378
+ * _then_ decoding & encoding ring buffer can have any size, including small ones ( < 64 KB).
379
+ * - Decompression buffer is larger than encoding buffer, by a minimum of maxBlockSize more bytes.
380
+ * In which case, encoding and decoding buffers do not need to be synchronized,
381
+ * and encoding ring buffer can have any size, including small ones ( < 64 KB).
382
+ *
383
+ * Whenever these conditions are not possible,
384
+ * save the last 64KB of decoded data into a safe buffer where it can't be modified during decompression,
385
+ * then indicate where this data is saved using LZ4_setStreamDecode(), before decompressing next block.
386
+ */
387
+ LZ4LIB_API int LZ4_decompress_safe_continue (LZ4_streamDecode_t* LZ4_streamDecode, const char* src, char* dst, int srcSize, int dstCapacity);
388
+
389
+
390
+ /*! LZ4_decompress_*_usingDict() :
391
+ * These decoding functions work the same as
392
+ * a combination of LZ4_setStreamDecode() followed by LZ4_decompress_*_continue()
393
+ * They are stand-alone, and don't need an LZ4_streamDecode_t structure.
394
+ * Dictionary is presumed stable : it must remain accessible and unmodified during decompression.
395
+ * Performance tip : Decompression speed can be substantially increased
396
+ * when dst == dictStart + dictSize.
397
+ */
398
+ LZ4LIB_API int LZ4_decompress_safe_usingDict (const char* src, char* dst, int srcSize, int dstCapcity, const char* dictStart, int dictSize);
399
+
400
+ #endif /* LZ4_H_2983827168210 */
401
+
402
+
403
+ /*^*************************************
404
+ * !!!!!! STATIC LINKING ONLY !!!!!!
405
+ ***************************************/
406
+
407
+ /*-****************************************************************************
408
+ * Experimental section
409
+ *
410
+ * Symbols declared in this section must be considered unstable. Their
411
+ * signatures or semantics may change, or they may be removed altogether in the
412
+ * future. They are therefore only safe to depend on when the caller is
413
+ * statically linked against the library.
414
+ *
415
+ * To protect against unsafe usage, not only are the declarations guarded,
416
+ * the definitions are hidden by default
417
+ * when building LZ4 as a shared/dynamic library.
418
+ *
419
+ * In order to access these declarations,
420
+ * define LZ4_STATIC_LINKING_ONLY in your application
421
+ * before including LZ4's headers.
422
+ *
423
+ * In order to make their implementations accessible dynamically, you must
424
+ * define LZ4_PUBLISH_STATIC_FUNCTIONS when building the LZ4 library.
425
+ ******************************************************************************/
426
+
427
+ #ifdef LZ4_STATIC_LINKING_ONLY
428
+
429
+ #ifndef LZ4_STATIC_3504398509
430
+ #define LZ4_STATIC_3504398509
431
+
432
+ #ifdef LZ4_PUBLISH_STATIC_FUNCTIONS
433
+ #define LZ4LIB_STATIC_API LZ4LIB_API
434
+ #else
435
+ #define LZ4LIB_STATIC_API
436
+ #endif
437
+
438
+
439
+ /*! LZ4_compress_fast_extState_fastReset() :
440
+ * A variant of LZ4_compress_fast_extState().
441
+ *
442
+ * Using this variant avoids an expensive initialization step.
443
+ * It is only safe to call if the state buffer is known to be correctly initialized already
444
+ * (see above comment on LZ4_resetStream_fast() for a definition of "correctly initialized").
445
+ * From a high level, the difference is that
446
+ * this function initializes the provided state with a call to something like LZ4_resetStream_fast()
447
+ * while LZ4_compress_fast_extState() starts with a call to LZ4_resetStream().
448
+ */
449
+ LZ4LIB_STATIC_API int LZ4_compress_fast_extState_fastReset (void* state, const char* src, char* dst, int srcSize, int dstCapacity, int acceleration);
450
+
451
+ /*! LZ4_attach_dictionary() :
452
+ * This is an experimental API that allows
453
+ * efficient use of a static dictionary many times.
454
+ *
455
+ * Rather than re-loading the dictionary buffer into a working context before
456
+ * each compression, or copying a pre-loaded dictionary's LZ4_stream_t into a
457
+ * working LZ4_stream_t, this function introduces a no-copy setup mechanism,
458
+ * in which the working stream references the dictionary stream in-place.
459
+ *
460
+ * Several assumptions are made about the state of the dictionary stream.
461
+ * Currently, only streams which have been prepared by LZ4_loadDict() should
462
+ * be expected to work.
463
+ *
464
+ * Alternatively, the provided dictionaryStream may be NULL,
465
+ * in which case any existing dictionary stream is unset.
466
+ *
467
+ * If a dictionary is provided, it replaces any pre-existing stream history.
468
+ * The dictionary contents are the only history that can be referenced and
469
+ * logically immediately precede the data compressed in the first subsequent
470
+ * compression call.
471
+ *
472
+ * The dictionary will only remain attached to the working stream through the
473
+ * first compression call, at the end of which it is cleared. The dictionary
474
+ * stream (and source buffer) must remain in-place / accessible / unchanged
475
+ * through the completion of the first compression call on the stream.
476
+ */
477
+ LZ4LIB_STATIC_API void LZ4_attach_dictionary(LZ4_stream_t* workingStream, const LZ4_stream_t* dictionaryStream);
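A hedged sketch of the no-copy setup (editor's addition; assumes LZ4_STATIC_LINKING_ONLY and a dictionary stream already prepared with LZ4_loadDict(), per the constraints above):

```cpp
#include "lz4.h"   // included with LZ4_STATIC_LINKING_ONLY defined

// Sketch: reuse a pre-loaded dictionary across compressions without copying it.
int compress_with_shared_dict(const LZ4_stream_t* dictStream,  // prepared via LZ4_loadDict()
                              const char* src, int srcSize,
                              char* dst, int dstCapacity) {
    LZ4_stream_t working;
    LZ4_initStream(&working, sizeof(working));
    LZ4_attach_dictionary(&working, dictStream);   // in-place reference, no copy
    // The attachment only survives this one call, per the comment above.
    return LZ4_compress_fast_continue(&working, src, dst, srcSize, dstCapacity, 1);
}
```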
478
+
479
+
480
+ /*! In-place compression and decompression
481
+ *
482
+ * It's possible to have input and output sharing the same buffer,
483
+ * for highly constrained memory environments.
484
+ * In both cases, it requires input to lie at the end of the buffer,
485
+ * and decompression to start at beginning of the buffer.
486
+ * Buffer size must feature some margin, hence be larger than final size.
487
+ *
488
+ * |<------------------------buffer--------------------------------->|
489
+ * |<-----------compressed data--------->|
490
+ * |<-----------decompressed size------------------>|
491
+ * |<----margin---->|
492
+ *
493
+ * This technique is more useful for decompression,
494
+ * since decompressed size is typically larger,
495
+ * and margin is short.
496
+ *
497
+ * In-place decompression will work inside any buffer
498
+ * whose size is >= LZ4_DECOMPRESS_INPLACE_BUFFER_SIZE(decompressedSize).
499
+ * This presumes that decompressedSize > compressedSize.
500
+ * Otherwise, it means compression actually expanded data,
501
+ * and it would be more efficient to store such data with a flag indicating it's not compressed.
502
+ * This can happen when data is not compressible (already compressed, or encrypted).
503
+ *
504
+ * For in-place compression, margin is larger, as it must be able to cope with both
505
+ * history preservation, requiring input data to remain unmodified up to LZ4_DISTANCE_MAX,
506
+ * and data expansion, which can happen when input is not compressible.
507
+ * As a consequence, buffer size requirements are much higher,
508
+ * and memory savings offered by in-place compression are more limited.
509
+ *
510
+ * There are ways to limit this cost for compression :
511
+ * - Reduce history size, by modifying LZ4_DISTANCE_MAX.
512
+ * Note that it is a compile-time constant, so all compressions will apply this limit.
513
+ * Lower values will reduce compression ratio, except when input_size < LZ4_DISTANCE_MAX,
514
+ * so it's a reasonable trick when inputs are known to be small.
515
+ * - Require the compressor to deliver a "maximum compressed size".
516
+ * This is the `dstCapacity` parameter in `LZ4_compress*()`.
517
+ * When this size is < LZ4_COMPRESSBOUND(inputSize), then compression can fail,
518
+ * in which case, the return code will be 0 (zero).
519
+ * The caller must be ready for these cases to happen,
520
+ * and typically design a backup scheme to send data uncompressed.
521
+ * The combination of both techniques can significantly reduce
522
+ * the amount of margin required for in-place compression.
523
+ *
524
+ * In-place compression can work in any buffer
525
+ * whose size is >= (maxCompressedSize)
526
+ * with maxCompressedSize == LZ4_COMPRESSBOUND(srcSize) for guaranteed compression success.
527
+ * LZ4_COMPRESS_INPLACE_BUFFER_SIZE() depends on both maxCompressedSize and LZ4_DISTANCE_MAX,
528
+ * so it's possible to reduce memory requirements by playing with them.
529
+ */
530
+
531
+ #define LZ4_DECOMPRESS_INPLACE_MARGIN(compressedSize) (((compressedSize) >> 8) + 32)
532
+ #define LZ4_DECOMPRESS_INPLACE_BUFFER_SIZE(decompressedSize) ((decompressedSize) + LZ4_DECOMPRESS_INPLACE_MARGIN(decompressedSize)) /**< note: presumes that compressedSize < decompressedSize. note2: margin is overestimated a bit, since it could use compressedSize instead */
533
+
534
+ #ifndef LZ4_DISTANCE_MAX /* history window size; can be user-defined at compile time */
535
+ # define LZ4_DISTANCE_MAX 65535 /* set to maximum value by default */
536
+ #endif
537
+
538
+ #define LZ4_COMPRESS_INPLACE_MARGIN (LZ4_DISTANCE_MAX + 32) /* LZ4_DISTANCE_MAX can be safely replaced by srcSize when it's smaller */
539
+ #define LZ4_COMPRESS_INPLACE_BUFFER_SIZE(maxCompressedSize) ((maxCompressedSize) + LZ4_COMPRESS_INPLACE_MARGIN) /**< maxCompressedSize is generally LZ4_COMPRESSBOUND(inputSize), but can be set to any lower value, with the risk that compression can fail (return code 0(zero)) */
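A hedged sketch of the in-place decompression layout described above (editor's addition; assumes decompressedSize is known and exceeds compressedSize): the compressed input sits at the end of one buffer, and the decompressed output grows from its start.

```cpp
#include <cstdlib>
#include <cstring>
#include "lz4.h"   // with LZ4_STATIC_LINKING_ONLY defined for the macros

// Sketch: one buffer holds input at its end and receives output from its start.
char* decompress_in_place(const char* compressed, int compressedSize,
                          int decompressedSize) {
    const size_t bufSize = LZ4_DECOMPRESS_INPLACE_BUFFER_SIZE(decompressedSize);
    char* buf = static_cast<char*>(std::malloc(bufSize));
    if (buf == nullptr) return nullptr;
    char* srcInBuf = buf + bufSize - compressedSize;   // input lies at the end
    std::memcpy(srcInBuf, compressed, compressedSize);
    if (LZ4_decompress_safe(srcInBuf, buf, compressedSize, decompressedSize)
            != decompressedSize) {
        std::free(buf);
        return nullptr;
    }
    return buf;   // first decompressedSize bytes are the output; caller frees
}
```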
540
+
541
+ #endif /* LZ4_STATIC_3504398509 */
542
+ #endif /* LZ4_STATIC_LINKING_ONLY */
543
+
544
+
545
+
546
+ #ifndef LZ4_H_98237428734687
547
+ #define LZ4_H_98237428734687
548
+
549
+ /*-************************************************************
550
+ * PRIVATE DEFINITIONS
551
+ **************************************************************
552
+ * Do not use these definitions directly.
553
+ * They are only exposed to allow static allocation of `LZ4_stream_t` and `LZ4_streamDecode_t`.
554
+ * Accessing members will expose code to API and/or ABI break in future versions of the library.
555
+ **************************************************************/
556
+ #define LZ4_HASHLOG (LZ4_MEMORY_USAGE-2)
557
+ #define LZ4_HASHTABLESIZE (1 << LZ4_MEMORY_USAGE)
558
+ #define LZ4_HASH_SIZE_U32 (1 << LZ4_HASHLOG) /* required as macro for static allocation */
559
+
560
+ #if defined(__cplusplus) || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */)
561
+ #include <stdint.h>
562
+
563
+ typedef struct LZ4_stream_t_internal LZ4_stream_t_internal;
564
+ struct LZ4_stream_t_internal {
565
+ uint32_t hashTable[LZ4_HASH_SIZE_U32];
566
+ uint32_t currentOffset;
567
+ uint16_t dirty;
568
+ uint16_t tableType;
569
+ const uint8_t* dictionary;
570
+ const LZ4_stream_t_internal* dictCtx;
571
+ uint32_t dictSize;
572
+ };
573
+
574
+ typedef struct {
575
+ const uint8_t* externalDict;
576
+ size_t extDictSize;
577
+ const uint8_t* prefixEnd;
578
+ size_t prefixSize;
579
+ } LZ4_streamDecode_t_internal;
580
+
581
+ #else
582
+
583
+ typedef struct LZ4_stream_t_internal LZ4_stream_t_internal;
584
+ struct LZ4_stream_t_internal {
585
+ unsigned int hashTable[LZ4_HASH_SIZE_U32];
586
+ unsigned int currentOffset;
587
+ unsigned short dirty;
588
+ unsigned short tableType;
589
+ const unsigned char* dictionary;
590
+ const LZ4_stream_t_internal* dictCtx;
591
+ unsigned int dictSize;
592
+ };
593
+
594
+ typedef struct {
595
+ const unsigned char* externalDict;
596
+ const unsigned char* prefixEnd;
597
+ size_t extDictSize;
598
+ size_t prefixSize;
599
+ } LZ4_streamDecode_t_internal;
600
+
601
+ #endif
602
+
603
+ /*! LZ4_stream_t :
604
+ * information structure to track an LZ4 stream.
605
+ * LZ4_stream_t can also be created using LZ4_createStream(), which is recommended.
606
+ * The structure definition can be convenient for static allocation
607
+ * (on stack, or as part of larger structure).
608
+ * Init this structure with LZ4_initStream() before first use.
609
+ * note : only use this definition in association with static linking !
610
+ * this definition is not API/ABI safe, and may change in a future version.
611
+ */
612
+ #define LZ4_STREAMSIZE_U64 ((1 << (LZ4_MEMORY_USAGE-3)) + 4 + ((sizeof(void*)==16) ? 4 : 0) /*AS-400*/ )
613
+ #define LZ4_STREAMSIZE (LZ4_STREAMSIZE_U64 * sizeof(unsigned long long))
614
+ union LZ4_stream_u {
615
+ unsigned long long table[LZ4_STREAMSIZE_U64];
616
+ LZ4_stream_t_internal internal_donotuse;
617
+ } ; /* previously typedef'd to LZ4_stream_t */
618
+
619
+ /*! LZ4_initStream() : v1.9.0+
620
+ * An LZ4_stream_t structure must be initialized at least once.
621
+ * This is automatically done when invoking LZ4_createStream(),
622
+ * but it's not when the structure is simply declared on stack (for example).
623
+ *
624
+ * Use LZ4_initStream() to properly initialize a newly declared LZ4_stream_t.
625
+ * It can also initialize any arbitrary buffer of sufficient size,
626
+ * and will @return a pointer of proper type upon initialization.
627
+ *
628
+ * Note : initialization fails if size and alignment conditions are not respected.
629
+ * In which case, the function will @return NULL.
630
+ * Note2: An LZ4_stream_t structure guarantees correct alignment and size.
631
+ * Note3: Before v1.9.0, use LZ4_resetStream() instead
632
+ */
633
+ LZ4LIB_API LZ4_stream_t* LZ4_initStream (void* buffer, size_t size);
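A minimal sketch of the stack-allocation pattern this comment describes (editor's addition):

```cpp
#include "lz4.h"

// Sketch: a statically declared stream must be initialized before first use.
int compress_one_chunk(const char* chunk, int chunkSize, char* dst, int dstCapacity) {
    LZ4_stream_t stream;                       // declared on the stack (v1.9.0+ pattern)
    LZ4_initStream(&stream, sizeof(stream));   // mandatory; returns NULL on bad size/alignment
    return LZ4_compress_fast_continue(&stream, chunk, dst, chunkSize, dstCapacity, 1);
}
```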
634
+
635
+
636
+ /*! LZ4_streamDecode_t :
637
+ * information structure to track an LZ4 stream during decompression.
638
+ * init this structure using LZ4_setStreamDecode() before first use.
639
+ * note : only use in association with static linking !
640
+ * this definition is not API/ABI safe,
641
+ * and may change in a future version !
642
+ */
643
+ #define LZ4_STREAMDECODESIZE_U64 (4 + ((sizeof(void*)==16) ? 2 : 0) /*AS-400*/ )
644
+ #define LZ4_STREAMDECODESIZE (LZ4_STREAMDECODESIZE_U64 * sizeof(unsigned long long))
645
+ union LZ4_streamDecode_u {
646
+ unsigned long long table[LZ4_STREAMDECODESIZE_U64];
647
+ LZ4_streamDecode_t_internal internal_donotuse;
648
+ } ; /* previously typedef'd to LZ4_streamDecode_t */
649
+
650
+
651
+
652
+ /*-************************************
653
+ * Obsolete Functions
654
+ **************************************/
655
+
656
+ /*! Deprecation warnings
657
+ *
658
+ * Deprecated functions make the compiler generate a warning when invoked.
659
+ * This is meant to invite users to update their source code.
660
+ * Should deprecation warnings be a problem, it is generally possible to disable them,
661
+ * typically with -Wno-deprecated-declarations for gcc
662
+ * or _CRT_SECURE_NO_WARNINGS in Visual.
663
+ *
664
+ * Another method is to define LZ4_DISABLE_DEPRECATE_WARNINGS
665
+ * before including the header file.
666
+ */
667
+ #ifdef LZ4_DISABLE_DEPRECATE_WARNINGS
668
+ # define LZ4_DEPRECATED(message) /* disable deprecation warnings */
669
+ #else
670
+ # define LZ4_GCC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__)
671
+ # if defined (__cplusplus) && (__cplusplus >= 201402) /* C++14 or greater */
672
+ # define LZ4_DEPRECATED(message) [[deprecated(message)]]
673
+ # elif (LZ4_GCC_VERSION >= 405) || defined(__clang__)
674
+ # define LZ4_DEPRECATED(message) __attribute__((deprecated(message)))
675
+ # elif (LZ4_GCC_VERSION >= 301)
676
+ # define LZ4_DEPRECATED(message) __attribute__((deprecated))
677
+ # elif defined(_MSC_VER)
678
+ # define LZ4_DEPRECATED(message) __declspec(deprecated(message))
679
+ # else
680
+ # pragma message("WARNING: You need to implement LZ4_DEPRECATED for this compiler")
681
+ # define LZ4_DEPRECATED(message)
682
+ # endif
683
+ #endif /* LZ4_DISABLE_DEPRECATE_WARNINGS */
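For illustration (editor's addition), the second method mentioned above looks like this at a legacy call site:

```cpp
// Sketch: suppress deprecation warnings before including the header.
#define LZ4_DISABLE_DEPRECATE_WARNINGS
#include "lz4.h"

int legacy_compress(const char* src, char* dst, int srcSize) {
    // Deprecated API still works; dst is assumed large enough (see LZ4_compressBound).
    return LZ4_compress(src, dst, srcSize);
}
```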
684
+
685
+ /* Obsolete compression functions */
686
+ LZ4_DEPRECATED("use LZ4_compress_default() instead") LZ4LIB_API int LZ4_compress (const char* src, char* dest, int srcSize);
687
+ LZ4_DEPRECATED("use LZ4_compress_default() instead") LZ4LIB_API int LZ4_compress_limitedOutput (const char* src, char* dest, int srcSize, int maxOutputSize);
688
+ LZ4_DEPRECATED("use LZ4_compress_fast_extState() instead") LZ4LIB_API int LZ4_compress_withState (void* state, const char* source, char* dest, int inputSize);
689
+ LZ4_DEPRECATED("use LZ4_compress_fast_extState() instead") LZ4LIB_API int LZ4_compress_limitedOutput_withState (void* state, const char* source, char* dest, int inputSize, int maxOutputSize);
690
+ LZ4_DEPRECATED("use LZ4_compress_fast_continue() instead") LZ4LIB_API int LZ4_compress_continue (LZ4_stream_t* LZ4_streamPtr, const char* source, char* dest, int inputSize);
691
+ LZ4_DEPRECATED("use LZ4_compress_fast_continue() instead") LZ4LIB_API int LZ4_compress_limitedOutput_continue (LZ4_stream_t* LZ4_streamPtr, const char* source, char* dest, int inputSize, int maxOutputSize);
692
+
693
+ /* Obsolete decompression functions */
694
+ LZ4_DEPRECATED("use LZ4_decompress_fast() instead") LZ4LIB_API int LZ4_uncompress (const char* source, char* dest, int outputSize);
695
+ LZ4_DEPRECATED("use LZ4_decompress_safe() instead") LZ4LIB_API int LZ4_uncompress_unknownOutputSize (const char* source, char* dest, int isize, int maxOutputSize);
696
+
697
+ /* Obsolete streaming functions; degraded functionality; do not use!
698
+ *
699
+ * In order to perform streaming compression, these functions depended on data
700
+ * that is no longer tracked in the state. They have been preserved as well as
701
+ * possible: using them will still produce a correct output. However, they don't
702
+ * actually retain any history between compression calls. The compression ratio
703
+ * achieved will therefore be no better than compressing each chunk
704
+ * independently.
705
+ */
706
+ LZ4_DEPRECATED("Use LZ4_createStream() instead") LZ4LIB_API void* LZ4_create (char* inputBuffer);
707
+ LZ4_DEPRECATED("Use LZ4_createStream() instead") LZ4LIB_API int LZ4_sizeofStreamState(void);
708
+ LZ4_DEPRECATED("Use LZ4_resetStream() instead") LZ4LIB_API int LZ4_resetStreamState(void* state, char* inputBuffer);
709
+ LZ4_DEPRECATED("Use LZ4_saveDict() instead") LZ4LIB_API char* LZ4_slideInputBuffer (void* state);
710
+
711
+ /* Obsolete streaming decoding functions */
712
+ LZ4_DEPRECATED("use LZ4_decompress_safe_usingDict() instead") LZ4LIB_API int LZ4_decompress_safe_withPrefix64k (const char* src, char* dst, int compressedSize, int maxDstSize);
713
+ LZ4_DEPRECATED("use LZ4_decompress_fast_usingDict() instead") LZ4LIB_API int LZ4_decompress_fast_withPrefix64k (const char* src, char* dst, int originalSize);
714
+
715
+ /*! LZ4_decompress_fast() : **unsafe!**
716
+ * These functions used to be faster than LZ4_decompress_safe(),
717
+ * but that is no longer the case: they are now slower than LZ4_decompress_safe().
718
+ * This is because LZ4_decompress_fast() doesn't know the input size,
719
+ * and therefore must progress more cautiously in the input buffer to not read beyond the end of block.
720
+ * On top of that, `LZ4_decompress_fast()` is not protected against malformed or malicious inputs, making it a security liability.
721
+ * As a consequence, LZ4_decompress_fast() is strongly discouraged, and deprecated.
722
+ *
723
+ * The last remaining LZ4_decompress_fast() specificity is that
724
+ * it can decompress a block without knowing its compressed size.
725
+ * Such functionality could be achieved in a more secure manner,
726
+ * by also providing the maximum size of input buffer,
727
+ * but it would require new prototypes, and adaptation of the implementation to this new use case.
728
+ *
729
+ * Parameters:
730
+ * originalSize : is the uncompressed size to regenerate.
731
+ * `dst` must be already allocated, its size must be >= 'originalSize' bytes.
732
+ * @return : number of bytes read from source buffer (== compressed size).
733
+ * The function expects to finish at block's end exactly.
734
+ * If the source stream is detected malformed, the function stops decoding and returns a negative result.
735
+ * note : LZ4_decompress_fast*() requires originalSize. Thanks to this information, it never writes past the output buffer.
736
+ * However, since it doesn't know its 'src' size, it may read an unknown amount of input, past input buffer bounds.
737
+ * Also, since match offsets are not validated, match reads from 'src' may underflow too.
738
+ * These issues never happen if input (compressed) data is correct.
739
+ * But they may happen if input data is invalid (error or intentional tampering).
740
+ * As a consequence, use these functions in trusted environments with trusted data **only**.
741
+ */
742
+
743
+ LZ4_DEPRECATED("This function is deprecated and unsafe. Consider using LZ4_decompress_safe() instead")
744
+ LZ4LIB_API int LZ4_decompress_fast (const char* src, char* dst, int originalSize);
745
+ LZ4_DEPRECATED("This function is deprecated and unsafe. Consider using LZ4_decompress_safe_continue() instead")
746
+ LZ4LIB_API int LZ4_decompress_fast_continue (LZ4_streamDecode_t* LZ4_streamDecode, const char* src, char* dst, int originalSize);
747
+ LZ4_DEPRECATED("This function is deprecated and unsafe. Consider using LZ4_decompress_safe_usingDict() instead")
748
+ LZ4LIB_API int LZ4_decompress_fast_usingDict (const char* src, char* dst, int originalSize, const char* dictStart, int dictSize);
749
+
750
+ /*! LZ4_resetStream() :
751
+ * An LZ4_stream_t structure must be initialized at least once.
752
+ * This is done with LZ4_initStream(), or LZ4_resetStream().
753
+ * Consider switching to LZ4_initStream();
754
+ * invoking LZ4_resetStream() will trigger deprecation warnings in the future.
755
+ */
756
+ LZ4LIB_API void LZ4_resetStream (LZ4_stream_t* streamPtr);
757
+
758
+
759
+ #endif /* LZ4_H_98237428734687 */
760
+
761
+
762
+ #if defined (__cplusplus)
763
+ }
764
+ #endif
include/dataset_creation/dataset_manipulation/bytes_to_file.h ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <string>
4
+ #include <fstream>
5
+ #include "../../../libraries/protobuf/build/midi.pb.h"
6
+ #include "../compression/lz4.h"
7
+
8
+
9
+ namespace dataset_manipulation {
10
+ class BytesToFile{
11
+ private:
12
+ std::string filepath_;
13
+ std::string header_filepath_;
14
+ std::fstream file_stream_;
15
+ std::fstream header_file_stream_;
16
+ midi::Dataset dataset_split_protobuf_;
17
+ int flush_count_;
18
+ bool can_write;
19
+
20
+ public:
21
+ BytesToFile(std::string external_filepath_);
22
+ void enableWrite();
23
+ void appendBytesToFileStream(std::string& bytes_as_string, size_t split_id);
24
+ void writeFile();
25
+ void close();
26
+ };
27
+ }
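A hedged usage sketch of this class, inferred from the declaration alone (the exact flush and serialization semantics live in the corresponding .cpp, so the call order below is an assumption; filename and payload are placeholders):

```cpp
#include <string>
#include "dataset_creation/dataset_manipulation/bytes_to_file.h"

int main() {
    // Hypothetical driver inferred from the method names.
    dataset_manipulation::BytesToFile writer("dataset.bin");
    writer.enableWrite();
    std::string bytes = "...";                       // e.g. a serialized midi::Piece
    writer.appendBytesToFileStream(bytes, /*split_id=*/0);
    writer.writeFile();
    writer.close();
    return 0;
}
```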
libraries/midifile ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 838c62c4a13245ced8e13a84e6c2a1994664acd5
libraries/protobuf/CMakeLists.txt ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ cmake_minimum_required(VERSION 3.8)
2
+
3
+ project(midigpt_proto)
4
+
5
+ set(PROTO_DEF
6
+ src/enum.proto
7
+ src/midi.proto
8
+ src/midi_internal.proto
9
+ src/track_type.proto
10
+ src/feature_extraction.proto)
11
+
12
+ find_package(Protobuf REQUIRED)
13
+
14
+ protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS
15
+ ${PROTO_DEF}
16
+ PROTOC_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR} # it's the default but it does not hurt to be explicit here...
17
+ )
18
+
19
+ add_library(midigpt_proto
20
+ ${PROTO_SRCS}
21
+ ${PROTO_HDRS})
22
+
23
+ target_include_directories(midigpt_proto
24
+ PUBLIC
25
+ ${Protobuf_INCLUDE_DIRS}
26
+ ${CMAKE_CURRENT_BINARY_DIR} # for generated protobuf files
27
+ )
28
+ target_link_libraries(midigpt_proto ${Protobuf_LIBRARIES})
libraries/protobuf/CMakeSettings.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "configurations": [
3
+ {
4
+ "name": "x64-Debug",
5
+ "generator": "Ninja",
6
+ "configurationType": "Debug",
7
+ "inheritEnvironments": [ "msvc_x64_x64" ],
8
+ "buildRoot": "${projectDir}\\out\\build\\${name}",
9
+ "installRoot": "${projectDir}\\out\\install\\${name}",
10
+ "cmakeCommandArgs": "",
11
+ "buildCommandArgs": "",
12
+ "ctestCommandArgs": ""
13
+ },
14
+ {
15
+ "name": "WSL-GCC-Debug",
16
+ "generator": "Ninja",
17
+ "configurationType": "Debug",
18
+ "buildRoot": "${projectDir}\\out\\build\\${name}",
19
+ "installRoot": "${projectDir}\\out\\install\\${name}",
20
+ "cmakeExecutable": "cmake",
21
+ "cmakeCommandArgs": "",
22
+ "buildCommandArgs": "",
23
+ "ctestCommandArgs": "",
24
+ "inheritEnvironments": [ "linux_x64" ],
25
+ "wslPath": "${defaultWSLPath}",
26
+ "variables": []
27
+ }
28
+ ]
29
+ }
libraries/protobuf/include/proto_library.h ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ namespace midigptProto {
4
+ void testMmmProto();
5
+ }
libraries/protobuf/src/enum.proto ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ syntax = "proto2";
2
+
3
+ package midi;
4
+
5
+ enum TOKEN_TYPE {
6
+ TOKEN_PIECE_START = 0;
7
+ TOKEN_NOTE_ONSET = 1;
8
+ TOKEN_PITCH = 3;
9
+ TOKEN_VELOCITY = 5;
10
+ TOKEN_TIME_ABSOLUTE_POS = 7;
11
+ TOKEN_INSTRUMENT = 8;
12
+ TOKEN_BAR = 9;
13
+ TOKEN_BAR_END = 10;
14
+ TOKEN_TRACK = 11;
15
+ TOKEN_TRACK_END = 12;
16
+ TOKEN_DRUM_TRACK = 13;
17
+ TOKEN_FILL_IN = 14;
18
+ TOKEN_FILL_IN_PLACEHOLDER = 15;
19
+ TOKEN_FILL_IN_START = 16;
20
+ TOKEN_FILL_IN_END = 17;
21
+ TOKEN_VELOCITY_LEVEL = 19;
22
+ TOKEN_GENRE = 20;
23
+ TOKEN_DENSITY_LEVEL = 21;
24
+ TOKEN_TIME_SIGNATURE = 22;
25
+ TOKEN_NOTE_DURATION = 26;
26
+ TOKEN_AV_POLYPHONY = 27;
27
+ TOKEN_MIN_POLYPHONY = 28;
28
+ TOKEN_MAX_POLYPHONY = 29;
29
+ TOKEN_MIN_NOTE_DURATION = 30;
30
+ TOKEN_MAX_NOTE_DURATION = 31;
31
+ TOKEN_NUM_BARS = 32;
32
+ TOKEN_MIN_POLYPHONY_HARD = 33;
33
+ TOKEN_MAX_POLYPHONY_HARD = 34;
34
+ TOKEN_MIN_NOTE_DURATION_HARD = 35;
35
+ TOKEN_MAX_NOTE_DURATION_HARD = 36;
36
+ TOKEN_BAR_LEVEL_ONSET_DENSITY = 40;
37
+ TOKEN_BAR_LEVEL_ONSET_POLYPHONY_MIN = 41;
38
+ TOKEN_BAR_LEVEL_ONSET_POLYPHONY_MAX = 42;
39
+ TOKEN_TRACK_LEVEL_ONSET_DENSITY = 43;
40
+ TOKEN_TRACK_LEVEL_ONSET_POLYPHONY_MIN = 44;
41
+ TOKEN_TRACK_LEVEL_ONSET_POLYPHONY_MAX = 45;
42
+ TOKEN_TRACK_LEVEL_ONSET_DENSITY_MIN = 46;
43
+ TOKEN_TRACK_LEVEL_ONSET_DENSITY_MAX = 47;
44
+ TOKEN_TRACK_LEVEL_PITCH_RANGE_MIN = 48;
45
+ TOKEN_TRACK_LEVEL_PITCH_RANGE_MAX = 49;
46
+ TOKEN_CONTAINS_NOTE_DURATION_THIRTY_SECOND = 59;
47
+ TOKEN_CONTAINS_NOTE_DURATION_SIXTEENTH = 60;
48
+ TOKEN_CONTAINS_NOTE_DURATION_EIGHTH = 61;
49
+ TOKEN_CONTAINS_NOTE_DURATION_QUARTER = 62;
50
+ TOKEN_CONTAINS_NOTE_DURATION_HALF = 63;
51
+ TOKEN_CONTAINS_NOTE_DURATION_WHOLE = 64;
52
+ TOKEN_DELTA = 67;
53
+ TOKEN_DELTA_DIRECTION = 68;
54
+ TOKEN_NONE = 72;
55
+ }
56
+
57
+ enum ATTRIBUTE_CONTROL_TYPE {
58
+ ATTRIBUTE_CONTROL_NOTE_DENSITY = 0;
59
+ ATTRIBUTE_CONTROL_TRACK_LEVEL_ONSET_POLYPHONY = 1;
60
+ ATTRIBUTE_CONTROL_TRACK_LEVEL_ONSET_DENSITY = 2;
61
+ ATTRIBUTE_CONTROL_PITCH_RANGE = 3;
62
+ ATTRIBUTE_CONTROL_GENRE = 4;
63
+ ATTRIBUTE_CONTROL_POLYPHONY_QUANTILE = 5;
64
+ ATTRIBUTE_CONTROL_NOTE_DURATION_QUANTILE = 6;
65
+ ATTRIBUTE_CONTROL_BAR_LEVEL_ONSET_DENSITY = 7;
66
+ ATTRIBUTE_CONTROL_BAR_LEVEL_ONSET_POLYPHONY = 8;
67
+ ATTRIBUTE_CONTROL_TRACK_LEVEL_NOTE_DURATION = 9;
68
+ ATTRIBUTE_CONTROL_END = 10;
69
+ }
70
+
71
+ enum GenreMusicmap {
72
+ GENRE_MUSICMAP_ANY = 0;
73
+ GENRE_MUSICMAP_ALTERNATIVE_ROCK = 1;
74
+ GENRE_MUSICMAP_AMBIENT = 2;
75
+ GENRE_MUSICMAP_BLUES = 3;
76
+ GENRE_MUSICMAP_BREAKBEAT = 4;
77
+ GENRE_MUSICMAP_CLASSICAL = 5;
78
+ GENRE_MUSICMAP_CLASSIC_ROCK = 6;
79
+ GENRE_MUSICMAP_CONTEMPORARY_ROCK = 7;
80
+ GENRE_MUSICMAP_COUNTRY = 8;
81
+ GENRE_MUSICMAP_DRUM_N_BASS = 9;
82
+ GENRE_MUSICMAP_FOLK = 10;
83
+ GENRE_MUSICMAP_GOSPEL = 11;
84
+ GENRE_MUSICMAP_HARDCORE_PUNK = 12;
85
+ GENRE_MUSICMAP_HARDCORE_TECHNO = 13;
86
+ GENRE_MUSICMAP_HEAVY_METAL = 14;
87
+ GENRE_MUSICMAP_HIP_HOP = 15;
88
+ GENRE_MUSICMAP_HOUSE = 16;
89
+ GENRE_MUSICMAP_INDUSTRIAL = 17;
90
+ GENRE_MUSICMAP_JAZZ = 18;
91
+ GENRE_MUSICMAP_LATIN = 19;
92
+ GENRE_MUSICMAP_POP = 20;
93
+ GENRE_MUSICMAP_PUNK = 21;
94
+ GENRE_MUSICMAP_PUNK_ROCK = 22;
95
+ GENRE_MUSICMAP_RANDB = 23;
96
+ GENRE_MUSICMAP_REGGAE = 24;
97
+ GENRE_MUSICMAP_ROCK_N_ROLL = 25;
98
+ GENRE_MUSICMAP_TECHNO = 26;
99
+ GENRE_MUSICMAP_TRANCE = 27;
100
+ GENRE_MUSICMAP_UTILITY = 28;
101
+ GENRE_MUSICMAP_WORLD = 29;
102
+ GENRE_MUSICMAP_NONE = 30;
103
+ };
104
+
105
+ enum GM_CATEGORY {
106
+ GM_CATEGORY_MONO = 0;
107
+ GM_CATEGORY_POLY = 1;
108
+ GM_CATEGORY_SOUND_FX = 2;
109
+ GM_CATEGORY_PERC = 3;
110
+ };
111
+
112
+ enum GM_TYPE {
113
+ any = 0;
114
+ piano = 1;
115
+ chromatic_perc = 2;
116
+ organ = 3;
117
+ guitar = 4;
118
+ bass = 5;
119
+ strings = 6;
120
+ ensemble = 7;
121
+ brass = 8;
122
+ reed = 9;
123
+ pipe = 10;
124
+ synth_lead = 11;
125
+ synth_pad = 12;
126
+ synth_effects = 13;
127
+ ethnic = 14;
128
+ percussive = 15;
129
+ sound_fx = 16;
130
+ no_drums = 17;
131
+ drums = 18;
132
+ acoustic_grand_piano = 19;
133
+ bright_acoustic_piano = 20;
134
+ electric_grand_piano = 21;
135
+ honky_tonk_piano = 22;
136
+ electric_piano_1 = 23;
137
+ electric_piano_2 = 24;
138
+ harpsichord = 25;
139
+ clavi = 26;
140
+ celesta = 27;
141
+ glockenspiel = 28;
142
+ music_box = 29;
143
+ vibraphone = 30;
144
+ marimba = 31;
145
+ xylophone = 32;
146
+ tubular_bells = 33;
147
+ dulcimer = 34;
148
+ drawbar_organ = 35;
149
+ percussive_organ = 36;
150
+ rock_organ = 37;
151
+ church_organ = 38;
152
+ reed_organ = 39;
153
+ accordion = 40;
154
+ harmonica = 41;
155
+ tango_accordion = 42;
156
+ acoustic_guitar_nylon = 43;
157
+ acoustic_guitar_steel = 44;
158
+ electric_guitar_jazz = 45;
159
+ electric_guitar_clean = 46;
160
+ electric_guitar_muted = 47;
161
+ overdriven_guitar = 48;
162
+ distortion_guitar = 49;
163
+ guitar_harmonics = 50;
164
+ acoustic_bass = 51;
165
+ electric_bass_finger = 52;
166
+ electric_bass_pick = 53;
167
+ fretless_bass = 54;
168
+ slap_bass_1 = 55;
169
+ slap_bass_2 = 56;
170
+ synth_bass_1 = 57;
171
+ synth_bass_2 = 58;
172
+ violin = 59;
173
+ viola = 60;
174
+ cello = 61;
175
+ contrabass = 62;
176
+ tremolo_strings = 63;
177
+ pizzicato_strings = 64;
178
+ orchestral_harp = 65;
179
+ timpani = 66;
180
+ string_ensemble_1 = 67;
181
+ string_ensemble_2 = 68;
182
+ synth_strings_1 = 69;
183
+ synth_strings_2 = 70;
184
+ choir_aahs = 71;
185
+ voice_oohs = 72;
186
+ synth_voice = 73;
187
+ orchestra_hit = 74;
188
+ trumpet = 75;
189
+ trombone = 76;
190
+ tuba = 77;
191
+ muted_trumpet = 78;
192
+ french_horn = 79;
193
+ brass_section = 80;
194
+ synth_brass_1 = 81;
195
+ synth_brass_2 = 82;
196
+ soprano_sax = 83;
197
+ alto_sax = 84;
198
+ tenor_sax = 85;
199
+ baritone_sax = 86;
200
+ oboe = 87;
201
+ english_horn = 88;
202
+ bassoon = 89;
203
+ clarinet = 90;
204
+ piccolo = 91;
205
+ flute = 92;
206
+ recorder = 93;
207
+ pan_flute = 94;
208
+ blown_bottle = 95;
209
+ shakuhachi = 96;
210
+ whistle = 97;
211
+ ocarina = 98;
212
+ lead_1_square = 99;
213
+ lead_2_sawtooth = 100;
214
+ lead_3_calliope = 101;
215
+ lead_4_chiff = 102;
216
+ lead_5_charang = 103;
217
+ lead_6_voice = 104;
218
+ lead_7_fifths = 105;
219
+ lead_8_bass__lead = 106;
220
+ pad_1_new_age = 107;
221
+ pad_2_warm = 108;
222
+ pad_3_polysynth = 109;
223
+ pad_4_choir = 110;
224
+ pad_5_bowed = 111;
225
+ pad_6_metallic = 112;
226
+ pad_7_halo = 113;
227
+ pad_8_sweep = 114;
228
+ fx_1_rain = 115;
229
+ fx_2_soundtrack = 116;
230
+ fx_3_crystal = 117;
231
+ fx_4_atmosphere = 118;
232
+ fx_5_brightness = 119;
233
+ fx_6_goblins = 120;
234
+ fx_7_echoes = 121;
235
+ fx_8_sci_fi = 122;
236
+ sitar = 123;
237
+ banjo = 124;
238
+ shamisen = 125;
239
+ koto = 126;
240
+ kalimba = 127;
241
+ bag_pipe = 128;
242
+ fiddle = 129;
243
+ shanai = 130;
244
+ tinkle_bell = 131;
245
+ agogo = 132;
246
+ steel_drums = 133;
247
+ woodblock = 134;
248
+ taiko_drum = 135;
249
+ melodic_tom = 136;
250
+ synth_drum = 137;
251
+ reverse_cymbal = 138;
252
+ guitar_fret_noise = 139;
253
+ breath_noise = 140;
254
+ seashore = 141;
255
+ bird_tweet = 142;
256
+ telephone_ring = 143;
257
+ helicopter = 144;
258
+ applause = 145;
259
+ gunshot = 146;
260
+ drum_0 = 147;
261
+ drum_1 = 148;
262
+ drum_2 = 149;
263
+ drum_3 = 150;
264
+ drum_4 = 151;
265
+ drum_5 = 152;
266
+ drum_6 = 153;
267
+ drum_7 = 154;
268
+ drum_8 = 155;
269
+ drum_9 = 156;
270
+ drum_10 = 157;
271
+ drum_11 = 158;
272
+ drum_12 = 159;
273
+ drum_13 = 160;
274
+ drum_14 = 161;
275
+ drum_15 = 162;
276
+ drum_16 = 163;
277
+ drum_17 = 164;
278
+ drum_18 = 165;
279
+ drum_19 = 166;
280
+ drum_20 = 167;
281
+ drum_21 = 168;
282
+ drum_22 = 169;
283
+ drum_23 = 170;
284
+ drum_24 = 171;
285
+ drum_25 = 172;
286
+ drum_26 = 173;
287
+ drum_27 = 174;
288
+ drum_28 = 175;
289
+ drum_29 = 176;
290
+ drum_30 = 177;
291
+ drum_31 = 178;
292
+ drum_32 = 179;
293
+ drum_33 = 180;
294
+ drum_34 = 181;
295
+ drum_35 = 182;
296
+ drum_36 = 183;
297
+ drum_37 = 184;
298
+ drum_38 = 185;
299
+ drum_39 = 186;
300
+ drum_40 = 187;
301
+ drum_41 = 188;
302
+ drum_42 = 189;
303
+ drum_43 = 190;
304
+ drum_44 = 191;
305
+ drum_45 = 192;
306
+ drum_46 = 193;
307
+ drum_47 = 194;
308
+ drum_48 = 195;
309
+ drum_49 = 196;
310
+ drum_50 = 197;
311
+ drum_51 = 198;
312
+ drum_52 = 199;
313
+ drum_53 = 200;
314
+ drum_54 = 201;
315
+ drum_55 = 202;
316
+ drum_56 = 203;
317
+ drum_57 = 204;
318
+ drum_58 = 205;
319
+ drum_59 = 206;
320
+ drum_60 = 207;
321
+ drum_61 = 208;
322
+ drum_62 = 209;
323
+ drum_63 = 210;
324
+ drum_64 = 211;
325
+ drum_65 = 212;
326
+ drum_66 = 213;
327
+ drum_67 = 214;
328
+ drum_68 = 215;
329
+ drum_69 = 216;
330
+ drum_70 = 217;
331
+ drum_71 = 218;
332
+ drum_72 = 219;
333
+ drum_73 = 220;
334
+ drum_74 = 221;
335
+ drum_75 = 222;
336
+ drum_76 = 223;
337
+ drum_77 = 224;
338
+ drum_78 = 225;
339
+ drum_79 = 226;
340
+ drum_80 = 227;
341
+ drum_81 = 228;
342
+ drum_82 = 229;
343
+ drum_83 = 230;
344
+ drum_84 = 231;
345
+ drum_85 = 232;
346
+ drum_86 = 233;
347
+ drum_87 = 234;
348
+ drum_88 = 235;
349
+ drum_89 = 236;
350
+ drum_90 = 237;
351
+ drum_91 = 238;
352
+ drum_92 = 239;
353
+ drum_93 = 240;
354
+ drum_94 = 241;
355
+ drum_95 = 242;
356
+ drum_96 = 243;
357
+ drum_97 = 244;
358
+ drum_98 = 245;
359
+ drum_99 = 246;
360
+ drum_100 = 247;
361
+ drum_101 = 248;
362
+ drum_102 = 249;
363
+ drum_103 = 250;
364
+ drum_104 = 251;
365
+ drum_105 = 252;
366
+ drum_106 = 253;
367
+ drum_107 = 254;
368
+ drum_108 = 255;
369
+ drum_109 = 256;
370
+ drum_110 = 257;
371
+ drum_111 = 258;
372
+ drum_112 = 259;
373
+ drum_113 = 260;
374
+ drum_114 = 261;
375
+ drum_115 = 262;
376
+ drum_116 = 263;
377
+ drum_117 = 264;
378
+ drum_118 = 265;
379
+ drum_119 = 266;
380
+ drum_120 = 267;
381
+ drum_121 = 268;
382
+ drum_122 = 269;
383
+ drum_123 = 270;
384
+ drum_124 = 271;
385
+ drum_125 = 272;
386
+ drum_126 = 273;
387
+ drum_127 = 274;
388
+ };
389
+
libraries/protobuf/src/feature_extraction.proto ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ syntax = "proto2";
2
+
3
+ package midi;
4
+
5
+ message PitchRange {
6
+ optional int32 instrument = 1;
7
+ optional int32 min = 2;
8
+ optional int32 max = 3;
9
+ }
10
+
11
+ message MetricDepth {
12
+ optional string filepath = 1;
13
+ optional int32 track_num = 2;
14
+ optional int32 instrument = 3;
15
+ optional bool is_drum = 4;
16
+ optional bool has_time_signatures = 5;
17
+ repeated int32 metric_depth = 6;
18
+ optional int32 tpq = 7;
19
+ }
20
+
21
+ message MedianMetricDepth {
22
+ optional string filepath = 1;
23
+ optional int32 track_num = 2;
24
+ optional int32 instrument = 3;
25
+ optional bool is_drum = 4;
26
+ optional bool has_time_signatures = 5;
27
+ optional int32 median_metric_depth = 6;
28
+ optional int32 tpq = 7;
29
+ }
30
+
31
+ message MostFrequentMetricDepth {
32
+ optional string filepath = 1;
33
+ optional int32 track_num = 2;
34
+ optional int32 instrument = 3;
35
+ optional bool is_drum = 4;
36
+ optional bool has_time_signatures = 5;
37
+ optional int32 most_frequent_metric_depth = 6;
38
+ optional int32 tpq = 7;
39
+ }
40
+
41
+ message DownbeatProportion {
42
+ optional string filepath = 1;
43
+ optional int32 track_num = 2;
44
+ optional int32 instrument = 3;
45
+ optional bool is_drum = 4;
46
+ optional float downbeat_proportion = 5;
47
+ }
48
+
49
+ message AlignedMetricDepth {
50
+ optional string filepath = 1;
51
+ optional int32 track_num = 2;
52
+ optional int32 instrument = 3;
53
+ optional bool is_drum = 4;
54
+
55
+ optional int32 aligned_offset = 5;
56
+ }
57
+
58
+ message SimultaneousOnset {
59
+ optional string filepath = 1;
60
+ optional int32 track_num = 2;
61
+ optional int32 instrument = 3;
62
+ optional bool is_drum = 4;
63
+
64
+ optional int32 simultaneous_onset_count = 5;
65
+ }
66
+
67
+ message BeatStability {
68
+ optional string filepath = 1;
69
+ //optional int32 track_num = 2;
70
+ //optional int32 instrument = 3;
71
+ //optional bool is_drum = 4;
72
+
73
+ optional float beat_stability_stdev = 5;
74
+ optional float beat_stability_median = 6;
75
+ }
76
+
77
+ message DrumPresence {
78
+ optional string filepath = 1;
79
+ optional float drum_presence = 2;
80
+ }
81
+
82
+ message Features {
83
+ repeated PitchRange pitch_range = 1;
84
+ repeated MetricDepth metric_depth = 2;
85
+ repeated DownbeatProportion downbeat_proportion = 3;
86
+ repeated AlignedMetricDepth aligned_metric_depth = 4;
87
+ repeated SimultaneousOnset simultaneous_onset = 5;
88
+ repeated BeatStability beat_stability = 6;
89
+ repeated DrumPresence drum_presence = 7;
90
+ repeated MedianMetricDepth median_metric_depth = 8;
91
+ repeated MostFrequentMetricDepth most_frequent_metric_depth = 9;
92
+ }
libraries/protobuf/src/midi.proto ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+
3
+ # Introduction
4
+
5
+ To generate material, three protobuf messages must be supplied to the midigpt::sample function. The specification for each of these messages (Piece, Status and SampleParam) is outlined in this document. Any field whose name starts with internal_ should be ignored, as these are for internal use only (unless you really know what you are doing ;)). There are examples of midi::Piece, midi::Status and midi::SampleParam objects in the docs folder. For working examples, please consult the testing suite, which is found in MMM_API/src/midigpt_api/test/unit.cpp.
6
+
7
+ # Functionality Overview
8
+
9
+ | Functionality | Scope | Description |
10
+ | ----------- | ----------- | ----------- |
11
+ | Velocity | Always Enabled | 32 levels of loudness for individual notes. |
12
+ | Instrument | Track | The General MIDI instrument (i.e. Timbre). |
13
+ | Max Polyphony Hard-Limit | Track | A hard-limit on the number of simultaneously sounding notes. |
14
+ | Note Duration (upper/lower soft bounds) | Non-Drum Tracks | Tells the model what the 15th (lower) and 85th (upper) quantiles of the note duration (e.g. quarter, whole) distribution should be. |
15
+ | Polyphony (upper/lower soft bounds) | Non-Drum Tracks | Tells the model what the 15th (lower) and 85th (upper) quantiles of the polyphony distribution should be. |
16
+ | Density (10 levels) | Drum Tracks | Tells the model the number of notes per bar to produce. |
17
+ | Auto-regressive Sampling Mode | Track | When enabled, bars are always sampled in chronological order. |
18
+ | Time Signature | Bar | A unique time-signature can be specified for each bar. |
19
+ | Temperature | per API call | A higher value increases entropy of generated output. Temperature=1 applies no modification to the probabilities produced by the model. |
20
+ | Context size (model_dim) | per API call | The number of bars that the model can process in one API call |
21
+
22
+
23
+ # Parameter Constraints and Considerations
24
+
25
+ There are two sampling methods: autoregressive generation, where we progressively sample musical material forwards in time on each track; and conditional generation (bar-infilling), where generated material is conditioned on past and future material.
26
+
27
+ Note that a single call to the model midigpt::sample() may involve both autoregressive and conditional generation, as these can be specified on a per-track basis. The constraints on these parameters are outlined below.
28
+
29
+ ## Sample Param Constraints
30
+
31
+ 1. tracks_per_step :
32
+ - must be on the range [1, number of tracks in piece]
33
+
34
+ 2. bars_per_step :
35
+ - must be on the range [1,model_dim]
36
+ - for conditional generation it is ill-advised for the user to have bars_per_step == model_dim, as this means generation will not be conditioned on any bars
37
+
38
+ 3. shuffle :
39
+ - this only applies in cases where one or more tracks are conditionally generated (i.e. resample = False && 1+ selected_bars = True)
40
+
41
+ 4. percentage :
42
+ - this only applies in cases where one or more tracks are conditionally generated (i.e. resample = False && 1+ selected_bars = True)
43
+
44
+ ## Status Constraints
45
+
46
+ 1. density :
47
+ - this control only applies to drum tracks. It works with both infilling and autoregressive modes.
48
+
49
+ 2. note duration / polyphony :
50
+ - this control only applies to non-drum tracks. It works with both infilling and autoregressive modes.
51
+
52
+ 3. autoregressive :
53
+ - you can only enable autoregressive mode (resample = True) when all the bars are selected in a track.
54
+ - note that you may have autoregressive disabled even when all bars are selected in a track
55
+
56
+ 4. ignore :
57
+ - tracks which have 1+ selected_bars = True may not be ignored, as they are needed to condition the generation
58
+
59
+
60
+ # Protobuf Specification
61
+
62
+ */
63
+ syntax = "proto2";
64
+
65
+ import "track_type.proto";
66
+ import "enum.proto";
67
+ import "midi_internal.proto";
68
+ import "feature_extraction.proto";
69
+
70
+ package midi;
71
+
72
+ /*
73
+ Specify the minimum or maximum amount of polyphony using these values. Using POLYPHONY_ANY lets the model choose the level of polyphony.
74
+ */
75
+ enum PolyphonyLevel {
76
+ POLYPHONY_ANY = 0;
77
+ POLYPHONY_ONE = 1;
78
+ POLYPHONY_TWO = 2;
79
+ POLYPHONY_THREE = 3;
80
+ POLYPHONY_FOUR = 4;
81
+ POLYPHONY_FIVE = 5;
82
+ POLYPHONY_SIX = 6;
83
+ }
84
+
85
+ /*
86
+ Specify the minimum or maximum bounds for note-duration using these values. Using DURATION_ANY lets the model choose the bounds for note duration.
87
+ */
88
+ enum NoteDurationLevel {
89
+ DURATION_ANY = 0;
90
+ DURATION_THIRTY_SECOND = 1;
91
+ DURATION_SIXTEENTH = 2;
92
+ DURATION_EIGHTH = 3;
93
+ DURATION_QUARTER = 4;
94
+ DURATION_HALF = 5;
95
+ DURATION_WHOLE = 6;
96
+ }
97
+
98
+ /*
99
+ Specify the minimum or maximum amount of note density using these values. Using DENSITY_ANY lets the model choose the level of density.
100
+ */
101
+ enum DensityLevel {
102
+ DENSITY_ANY = 0;
103
+ DENSITY_ONE = 1;
104
+ DENSITY_TWO = 2;
105
+ DENSITY_THREE = 3;
106
+ DENSITY_FOUR = 4;
107
+ DENSITY_FIVE = 5;
108
+ DENSITY_SIX = 6;
109
+ DENSITY_SEVEN = 7;
110
+ DENSITY_EIGHT = 8;
111
+ DENSITY_NINE = 9;
112
+ DENSITY_TEN = 10;
113
+ }
114
+
115
+ // specify the levels for bar level onset density
116
+ enum BarLevelOnsetDensityLevel {
117
+ BAR_LEVEL_ONSET_DENSITY_ANY = 0;
118
+ BAR_LEVEL_ONSET_DENSITY_ZERO = 1;
119
+ BAR_LEVEL_ONSET_DENSITY_ONE = 2;
120
+ BAR_LEVEL_ONSET_DENSITY_TWO = 3;
121
+ BAR_LEVEL_ONSET_DENSITY_THREE = 4;
122
+ BAR_LEVEL_ONSET_DENSITY_FOUR = 5;
123
+ BAR_LEVEL_ONSET_DENSITY_FIVE = 6;
124
+ BAR_LEVEL_ONSET_DENSITY_SIX = 7;
125
+ BAR_LEVEL_ONSET_DENSITY_SEVEN = 8;
126
+ BAR_LEVEL_ONSET_DENSITY_EIGHT = 9;
127
+ BAR_LEVEL_ONSET_DENSITY_NINE = 10;
128
+ BAR_LEVEL_ONSET_DENSITY_TEN = 11;
129
+ BAR_LEVEL_ONSET_DENSITY_ELEVEN = 12;
130
+ BAR_LEVEL_ONSET_DENSITY_TWELVE = 13;
131
+ BAR_LEVEL_ONSET_DENSITY_THIRTEEN = 14;
132
+ BAR_LEVEL_ONSET_DENSITY_FOURTEEN = 15;
133
+ BAR_LEVEL_ONSET_DENSITY_FIFTEEN = 16;
134
+ BAR_LEVEL_ONSET_DENSITY_SIXTEEN = 17;
135
+ }
136
+
137
+ // specify the levels for bar level onset polyphony
138
+ enum BarLevelOnsetPolyphonyLevel {
139
+ BAR_LEVEL_ONSET_POLYPHONY_ANY = 0;
140
+ BAR_LEVEL_ONSET_POLYPHONY_ONE = 1;
141
+ BAR_LEVEL_ONSET_POLYPHONY_TWO = 2;
142
+ BAR_LEVEL_ONSET_POLYPHONY_THREE = 3;
143
+ BAR_LEVEL_ONSET_POLYPHONY_FOUR = 4;
144
+ BAR_LEVEL_ONSET_POLYPHONY_FIVE = 5;
145
+ BAR_LEVEL_ONSET_POLYPHONY_SIX = 6;
146
+ }
147
+
148
+ enum SilenceProportionLevel {
149
+ SILENCE_PROPORTION_LEVEL_ANY = 0;
150
+ SILENCE_PROPORTION_LEVEL_ONE = 1;
151
+ SILENCE_PROPORTION_LEVEL_TWO = 2;
152
+ SILENCE_PROPORTION_LEVEL_THREE = 3;
153
+ SILENCE_PROPORTION_LEVEL_FOUR = 4;
154
+ SILENCE_PROPORTION_LEVEL_FIVE = 5;
155
+ SILENCE_PROPORTION_LEVEL_SIX = 6;
156
+ SILENCE_PROPORTION_LEVEL_SEVEN = 7;
157
+ SILENCE_PROPORTION_LEVEL_EIGHT = 8;
158
+ SILENCE_PROPORTION_LEVEL_NINE = 9;
159
+ SILENCE_PROPORTION_LEVEL_TEN = 10;
160
+ }
161
+
162
+ enum BooleanLevel {
163
+ BOOLEAN_ANY = 0;
164
+ BOOLEAN_FALSE = 1;
165
+ BOOLEAN_TRUE = 2;
166
+ }
167
+
168
+ /*
169
+ The Event Message is used to represent a MIDI note onset or offset.
170
+ */
171
+ message Event {
172
+ /*
173
+ The time of the event (either a note onset or note offset) relative to the current bar in quantized steps. Currently, most models quantize each quarter note beat into 12 subdivisions. As a result, if the event happens an eighth note after the start of the bar, this value would be 6. If the event occurs three quarter notes after the start of the bar, this value would be 3 * 12 = 36.
174
+ */
175
+ optional int32 time = 1 [(minval) = 0, (maxval) = 1000000];
176
+ /*
177
+ The MIDI velocity. This value must be 0 for note off messages.
178
+ */
179
+ optional int32 velocity = 2 [(minval) = 0, (maxval) = 127];
180
+ /*
181
+ The MIDI pitch value on the range [0,128).
182
+ */
183
+ optional int32 pitch = 3 [(minval) = 0, (maxval) = 127];
184
+
185
+ optional int32 internal_instrument = 4;
186
+ optional int32 internal_track_type = 10;
187
+ optional int32 internal_duration = 11;
188
+ optional int32 delta = 12;
189
+ }
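As a sketch using the generated proto2 C++ API (midi.pb.h; editor's addition), the timing arithmetic above plays out like this for an onset an eighth note into the bar:

```cpp
#include "midi.pb.h"

// Sketch: with resolution 12, an eighth note after the bar start is time 6.
midi::Event make_onset() {
    midi::Event ev;
    ev.set_time(6);        // 12 steps per quarter note / 2
    ev.set_pitch(60);      // middle C
    ev.set_velocity(96);   // nonzero velocity = onset; offsets must use 0
    return ev;
}
```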
190
+
191
+ /*
192
+ The Bar message specifies the events occurring in a bar.
193
+ */
194
+ message Bar {
195
+ /*
196
+ A list of integers, which are simply the indices of the messages found in the Piece.events repeated message. Note offsets which occur at the end of the bar (i.e. event.time = 48 with a time signature of 4/4 and piece.resolution of 12) should be included in the current bar rather than the next bar. In other words, no note offsets should ever have an event.time = 0, as these note offset events would belong in the previous bar.
197
+ */
198
+ repeated int32 events = 1 [(minval) = 0, (maxval) = 2147483647];
199
+ /*
200
+ Numerator for the time-signature of the bar. Note that while time signatures can vary from bar to bar, they cannot vary from track to track. In other words, if the second bar in track 0 has a time signature of 4/4, the second bar in track 1 must also have a time signature of 4/4.
201
+ */
202
+ optional int32 ts_numerator = 8 [(minval) = 1, (maxval) = 1000000];
203
+ /*
204
+ Denominator for the time-signature of the bar.
205
+ */
206
+ optional int32 ts_denominator = 9 [(minval) = 1, (maxval) = 1000000];
207
+
208
+ optional float internal_beat_length = 5;
209
+ optional bool internal_has_notes = 3;
210
+ repeated ContinuousFeature internal_feature = 10;
211
+ repeated BarFeatures internal_features = 11; // why isn't this just called bar features
212
+ }
213
+
214
+ /*
215
+ The Track message contains a list of bars, and specifies the instrument and track_type.
216
+ */
217
+ message Track {
218
+ /*
219
+ A list of bars. Note that each track must have the same number of bars.
220
+ */
221
+ repeated Bar bars = 1;
222
+ /*
223
+ The MIDI instrument number for the track.
224
+ */
225
+ optional int32 instrument = 3 [(minval) = 0, (maxval) = 139];
226
+ /* 127 original instruments, with the drum instrument separated into 12 individual drum instruments (TR-808) and the rest of the drums in a single instrument */
227
+ /*
228
+ This must be a value in the TRACK_TYPE enum. In most cases, using STANDARD_TRACK and STANDARD_DRUM_TRACK will suffice to denote a non-drum instrument track and a drum track respectively.
229
+ */
230
+ optional TRACK_TYPE track_type = 5;
231
+
232
+ repeated TRACK_TYPE internal_train_types = 6;
233
+ repeated TrackFeatures internal_features = 7; // why isn't this called track features
234
+ }
235
+
236
+ /*
237
+ The Piece message specifies the actual musical material in a track-separated event-based format, specifying the note onsets and offsets for each bar in each track.
238
+ */
239
+ message Piece {
240
+ /*
241
+ Organizes MIDI events into tracks and bars. In short, each track contains a list of bars, which in turn contain a list of event indices (corresponding to the repeated events message in the Piece).
242
+ */
243
+ repeated Track tracks = 1;
244
+ /*
245
+ A list of MIDI events which the tracks and bars reference.
246
+ */
247
+ repeated Event events = 2;
248
+ /*
249
+ The time resolution used to quantize / discretize musical material. Unless otherwise instructed, this should be set to 12.
250
+ */
251
+ optional int32 resolution = 3 [(minval) = 1, (maxval) = 24];
252
+ /*
253
+ Optionally, the tempo can be specified. However, this is not taken into consideration by the model.
254
+ */
255
+ //optional int32 tempo = 4 [(minval) = 1, (maxval) = 1000000];
256
+ optional int32 tempo = 4;
257
+
258
+ repeated int32 internal_valid_segments = 7;
259
+ repeated uint32 internal_valid_tracks = 8;
260
+ optional int32 internal_segment_length = 12;
261
+ repeated ValidTrack internal_valid_tracks_v2 = 13;
262
+ repeated GenreData internal_genre_data = 14;
263
+ optional MetadataLabels internal_metadata_labels = 15;
264
+ repeated PieceFeatures internal_features = 16;
265
+
266
+ optional int32 internal_ticks_per_quarter = 5;
267
+ optional bool internal_has_time_signatures = 6;
268
+ }
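A hedged sketch of the indexing scheme described above, via the generated C++ API (editor's addition): events live in Piece.events, and bars reference them by index.

```cpp
#include "midi.pb.h"

// Sketch: one track, one 4/4 bar, one quarter note (onset + offset by index).
midi::Piece make_piece() {
    midi::Piece piece;
    piece.set_resolution(12);
    midi::Event* on = piece.add_events();            // index 0
    on->set_time(0);   on->set_pitch(60);  on->set_velocity(96);
    midi::Event* off = piece.add_events();           // index 1
    off->set_time(12); off->set_pitch(60); off->set_velocity(0);
    midi::Track* track = piece.add_tracks();
    track->set_instrument(0);
    track->set_track_type(midi::STANDARD_TRACK);
    midi::Bar* bar = track->add_bars();
    bar->set_ts_numerator(4);
    bar->set_ts_denominator(4);
    bar->add_events(0);                              // onset in this bar
    bar->add_events(1);                              // offset stays in the same bar
    return piece;
}
```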
269
+
270
+ /*
271
+ The StatusBar message specifies per-bar information for generation.
272
+ */
273
+ message StatusBar {
274
+ optional int32 ts_numerator = 1;
275
+ optional int32 ts_denominator = 2;
276
+ optional BarLevelOnsetDensityLevel onset_density = 3;
277
+ optional BarLevelOnsetPolyphonyLevel onset_polyphony_min = 4;
278
+ optional BarLevelOnsetPolyphonyLevel onset_polyphony_max = 5;
279
+ }
280
+
281
+ /*
282
+ The StatusTrack message specifies per-track information for generation.
283
+ */
284
+ message StatusTrack {
285
+ /*
286
+ The index of a track in the Piece message. For a track to be seen by the model, it must be referenced by a StatusTrack message via the track_id field. Tracks that are not referenced by a StatusTrack message will not be considered by the model.
287
+ */
288
+ optional int32 track_id = 1 [(minval) = 0, (maxval) = 1000000];
289
+ /*
290
+ This must be a value in the TRACK_TYPE enum. This should be equivalent to the TRACK_TYPE specified in the corresponding Track, unless you are giving the model an option to choose either a drum or instrument track. In this case use the STANDARD_BOTH value here, and the TRACK_TYPE in the piece will be ignored.
291
+ */
292
+ optional TRACK_TYPE track_type = 2;
293
+ /*
294
+ This must be a value in the GM_TYPE enum. It specifies the set of possible instruments that the model may choose from. The mapping between GM_TYPE and instrument numbers can be found in src/midigpt_api/enum/gm.h. For example, using midi::GM_TYPE::piano will allow the model to use any piano instrument.
295
+ */
296
+ optional GM_TYPE instrument = 3;
297
+ /*
298
+ A list of boolean values which specifies whether a bar is to be generated (true) or conditioned on (false). This must be the same length as the number of bars in the corresponding Track message.
299
+ */
300
+ repeated bool selected_bars = 5;
301
+ /*
302
+ Indicates whether or not to use auto-regressive sampling. Note that you can only use auto-regressive sampling when each value in selected_bars is true (i.e. the entire track is being generated). Note that you do not have to use auto-regressive sampling when selected_bars is all true.
303
+ */
304
+ optional bool autoregressive = 6;
305
+ /*
306
+ This indicates that the track should be ignored. The model will not be conditioned on this track, and it will in no way affect the generated outcome.
307
+ */
308
+ optional bool ignore = 7;
309
+
310
+ optional DensityLevel density = 4;
311
+ optional PolyphonyLevel min_polyphony_q = 10;
312
+ optional PolyphonyLevel max_polyphony_q = 11;
313
+ optional NoteDurationLevel min_note_duration_q = 12;
314
+ optional NoteDurationLevel max_note_duration_q = 13;
315
+
316
+ optional BarLevelOnsetPolyphonyLevel onset_polyphony_min = 20;
317
+ optional BarLevelOnsetPolyphonyLevel onset_polyphony_max = 21;
318
+ optional BarLevelOnsetDensityLevel onset_density = 22;
319
+ optional BarLevelOnsetDensityLevel onset_density_min = 23;
320
+ optional BarLevelOnsetDensityLevel onset_density_max = 24;
321
+ optional int32 min_pitch = 25 [(minval) = 0, (maxval) = 127];
322
+ optional int32 max_pitch = 26 [(minval) = 0, (maxval) = 127];
323
+ optional GenreMusicmap genre = 28;
324
+ optional SilenceProportionLevel silence_proportion_min = 29;
325
+ optional SilenceProportionLevel silence_proportion_max = 30;
326
+ optional DensityLevel note_density_level = 31;
327
+
328
+ optional BooleanLevel contains_note_duration_thirty_second = 36;
329
+ optional BooleanLevel contains_note_duration_sixteenth = 37;
330
+ optional BooleanLevel contains_note_duration_eighth = 38;
331
+ optional BooleanLevel contains_note_duration_quarter = 39;
332
+ optional BooleanLevel contains_note_duration_half = 40;
333
+ optional BooleanLevel contains_note_duration_whole = 41;
334
+
335
+ /*
336
+ Sets a hard limit on the polyphony on this track. This is implemented by keeping a record of all the currently sounding notes, and preventing the model from generating note-onset tokens when the limit is reached.
337
+ */
338
+ optional int32 polyphony_hard_limit = 16 [(minval) = 0, (maxval) = 100];
339
+ /*
340
+ Allows for the entropy of generation to be adjusted. When temperature=1, the probability distributions output by the model are unaltered. When temperature<1 the probability distribution is increasingly biased towards the most probable tokens. With a very small temperature value this would be equivalent to argmax sampling. When temperature>1 the probability distribution moves towards a random uniform distribution. It is recommended to keep this value close to 1 in most cases.
341
+ */
342
+ optional float temperature = 17 [(fminval) = 0.5, (fmaxval) = 2.0];
343
+
344
+
345
+ repeated int32 internal_ts_numerators = 14;
346
+ repeated int32 internal_ts_denominators = 15;
347
+ optional string internal_genre = 9;
348
+
349
+ repeated StatusBar bars = 19;
350
+ }
351
+
352
+ /*
353
+ The Status message specifies which bars or tracks are to be generated/conditioned on, and provides extra information about conditioning such as instrument, density, polyphony and note-duration.
354
+ */
355
+ message Status {
356
+ repeated StatusTrack tracks = 1;
357
+
358
+ /*
359
+ For microtiming generation, the last sampling step must be decoded using the delta resolution.
360
+ */
361
+ optional bool decode_final = 2;
362
+ optional bool full_resolution = 3;
363
+ }
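A hedged sketch tying the Status fields together (generated C++ API; editor's addition, with illustrative track and bar counts): every bar of track 0 is selected, which makes autoregressive sampling legal per the constraints above.

```cpp
#include "midi.pb.h"

midi::Status make_status(int num_bars) {
    midi::Status status;
    midi::StatusTrack* st = status.add_tracks();
    st->set_track_id(0);                       // references Piece.tracks[0]
    st->set_track_type(midi::STANDARD_TRACK);
    st->set_instrument(midi::piano);           // let the model pick any piano
    for (int i = 0; i < num_bars; ++i)
        st->add_selected_bars(true);           // all bars selected...
    st->set_autoregressive(true);              // ...so autoregressive is allowed
    st->set_temperature(1.0f);
    return status;
}
```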
364
+
365
+ /*
366
+ The SampleParam message specifies hyper-parameters for generation.
367
+ */
368
+ message HyperParam {
369
+ /*
370
+ For multi-step generation (typically employed when the entire piece is too large to be considered by the model simultaneously) this parameter specifies the number of tracks that are generated in each step.
371
+ */
372
+ optional int32 tracks_per_step = 1 [(minval) = 1, (maxval) = 12];
373
+ /*
374
+ For multi-step generation this parameter specifies the number of bars that are generated in each step. This value should be set in relation to model_dim. If bars_per_step = model_dim, then there will be no horizontal conditioning, which will typically produce inferior results. A good rule of thumb is to use bars_per_step == model_dim / 2.
375
+ */
376
+ optional int32 bars_per_step = 2 [(minval) = 1, (maxval) = 8];
377
+ /*
378
+ The context size of the model in bars. In most cases this will be 4.
379
+ */
380
+ optional int32 model_dim = 3 [(minval) = 1, (maxval) = 8];
381
+ /*
382
+ The percentage of the selected material (selected bars in the Status message) that will be generated.
383
+ */
384
+ optional int32 percentage = 5 [(minval) = 1, (maxval) = 100];
385
+ /*
386
+ The number of outputs to be generated. Currently we only support batch_size=1.
387
+ With multi-step sampling its is likely more efficient to simply make several calls in series.
388
+ */
389
+ optional int32 batch_size = 7 [(minval) = 1, (maxval) = 1];
390
+ /*
391
+ Allows for the entropy of generation to be adjusted. When temperature=1, the probability distributions output by the model are unaltered. When temperature<1 the probability distribution is increasingly biased towards the most probable tokens. With a very small temperature value this would be equivalent to argmax sampling. When temperature>1 the probability distribution moves towards a random uniform distribution. It is recommended to keep this value close to 1 in most cases.
392
+ */
393
+ optional float temperature = 6 [(fminval) = 0.5, (fmaxval) = 2.0];
394
+ /*
395
+ This parameter turns per-track temperature control on and off.
396
+ */
397
+ optional bool use_per_track_temperature = 17;
398
+ /*
399
+ The max number of tokens to generate before terminating generation. Can be used to avoid memory overload. When this value is set to zero it is ignored, and no limit is placed on the number of generated tokens.
400
+ */
401
+ optional int32 max_steps = 13 [(minval) = 0, (maxval) = 2048];
402
+
403
+ optional int32 polyphony_hard_limit = 14 [(minval) = 0, (maxval) = 100];
404
+ /*
405
+ When shuffle=true the generation steps are randomly ordered. For obvious reasons, auto-regressive sampling cannot be used with shuffle=true, as it would cease to be auto-regressive.
406
+ */
407
+ optional bool shuffle = 4;
408
+ /*
409
+ Mainly for debugging purposes.
410
+ */
411
+ optional bool verbose = 8;
412
+ /*
413
+ The path to the ckpt, which should be either an absolute path or a path relative to the executable.
414
+ */
415
+ optional string ckpt = 9;
416
+ /*
417
+ Controls the probability of masking the top-k tokens.
418
+ */
419
+ optional float mask_top_k = 10 [(fminval) = 0, (fmaxval) = 1.];
420
+ /*
421
+ Controls the stochastic seed for reproducibility.
422
+ */
423
+ optional int32 sampling_seed = 11;
424
+
425
+ optional bool internal_skip_preprocess = 12;
426
+ optional bool internal_disable_masking = 16;
427
+
428
+ }
429
+
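And a hedged sketch of a HyperParam following the rule of thumb above, bars_per_step == model_dim / 2 (editor's addition; the checkpoint path is a placeholder):

```cpp
#include "midi.pb.h"

midi::HyperParam make_params() {
    midi::HyperParam hp;
    hp.set_model_dim(4);          // context size in bars
    hp.set_bars_per_step(2);      // model_dim / 2, per the rule of thumb
    hp.set_tracks_per_step(1);
    hp.set_percentage(100);
    hp.set_batch_size(1);         // only batch_size = 1 is supported
    hp.set_temperature(1.0f);
    hp.set_ckpt("model.ckpt");    // hypothetical path, relative to the executable
    return hp;
}
```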
libraries/protobuf/src/midi_internal.proto ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ syntax = "proto2";
2
+
3
+ import "enum.proto";
4
+ import "google/protobuf/descriptor.proto";
5
+
6
+ package midi;
7
+
8
+ extend google.protobuf.FieldOptions {
9
+ optional int32 maxval = 50001;
10
+ optional int32 minval = 50002;
11
+ optional float fmaxval = 50003;
12
+ optional float fminval = 50004;
13
+ }
14
+
15
+ message ContinuousFeature {
16
+ optional float av_polyphony = 1;
17
+ optional float note_duration = 3;
18
+ optional float note_duration_norm = 4;
19
+ }
20
+
21
+ message BarFeatures {
22
+ optional int32 onset_density = 1;
23
+ optional int32 onset_polyphony_min = 2;
24
+ optional int32 onset_polyphony_max = 3;
25
+ }
26
+
27
+ message TrackLevelAttributeControlDistributions {
28
+ repeated int32 polyphony_quantile = 1;
29
+ repeated int32 note_duration_quantile = 2;
30
+ repeated int32 note_density = 3;
31
+ repeated int32 onset_polyphony = 4;
32
+ repeated int32 onset_density = 5;
33
+ repeated int32 note_duration = 8;
34
+ }
35
+
36
+ message TrackFeatures {
37
+ optional int32 min_pitch = 1;
38
+ optional int32 max_pitch = 2;
39
+ optional float av_polyphony = 3;
40
+ optional int32 note_density = 4;
41
+ optional int32 note_density_v2 = 5;
42
+ optional int32 max_polyphony = 6;
43
+ optional bool should_prune = 7;
44
+ optional int32 order = 8;
45
+ optional float note_duration = 9;
46
+ optional string genre_str = 10;
47
+ optional int32 min_polyphony_q = 11;
48
+ optional int32 max_polyphony_q = 12;
49
+ optional int32 min_note_duration_q = 13;
50
+ optional int32 max_note_duration_q = 14;
51
+ repeated int32 polyphony_distribution = 15;
52
+ optional float note_density_value = 16;
53
+
54
+ optional int32 min_polyphony_hard = 18;
55
+ optional int32 max_polyphony_hard = 19;
56
+ optional int32 min_note_duration_hard = 20;
57
+ optional int32 max_note_duration_hard = 21;
58
+
59
+ optional int32 onset_polyphony_min = 24;
60
+ optional int32 onset_polyphony_max = 25;
61
+ optional int32 onset_density = 26;
62
+ optional int32 onset_density_min = 27;
63
+ optional int32 onset_density_max = 28;
64
+ repeated int32 duration_distribution = 29;
65
+
66
+ optional int32 genre = 32;
67
+ optional int32 note_density_level = 35;
68
+
69
+ optional int32 contains_note_duration_thirty_second = 40;
70
+ optional int32 contains_note_duration_sixteenth = 41;
71
+ optional int32 contains_note_duration_eighth = 42;
72
+ optional int32 contains_note_duration_quarter = 43;
73
+ optional int32 contains_note_duration_half = 44;
74
+ optional int32 contains_note_duration_whole = 45;
75
+
76
+ optional TrackLevelAttributeControlDistributions attribute_control_distributions = 30; // store them all here
77
+
78
+ }
79
+
80
+ message PieceFeatures {
81
+ optional string genre = 1;
82
+ }
83
+
84
+ message Note {
85
+ optional int32 start = 1;
86
+ optional int32 end = 2;
87
+ optional int32 pitch = 3;
88
+ optional int32 tick_delta = 4;
89
+ }
90
+
91
+ message ValidTrack {
92
+ repeated int32 tracks = 1;
93
+ }
94
+
95
+ message Item {
96
+ optional uint64 start = 1;
97
+ optional uint64 end = 2;
98
+ optional uint64 src_size = 3;
99
+ }
100
+
101
+ message Dataset {
102
+ repeated Item train = 1;
103
+ repeated Item valid = 2;
104
+ repeated Item test = 3;
105
+ }
106
+
107
+ message ModelMetadata {
108
+ optional string encoder = 1;
109
+ optional int32 num_layers = 2;
110
+ optional int32 num_heads = 3;
111
+ optional int32 num_hidden = 4;
112
+ optional int32 model_dim = 5;
113
+ optional bool new_state = 6;
114
+ }
115
+
116
+ message GenreData {
117
+ optional string discogs = 1;
118
+ optional string lastfm = 2;
119
+ optional string tagtraum = 3;
120
+ }
121
+
122
+ message MetadataLabels {
123
+ optional GenreMusicmap genre = 1;
124
+ optional int32 nomml = 6;
125
+ }
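The minval/maxval/fminval/fmaxval extensions declared above are ordinary custom FieldOptions, so declared ranges can be recovered from the generated descriptors at runtime. A sketch, assuming protoc generates a Python module named midi_internal_pb2 (the module name is an assumption):

import midi_internal_pb2 as mi  # assumed generated module name

def declared_range(field_descriptor):
    # Return the (minval, maxval) pair declared via the custom options, or None.
    opts = field_descriptor.GetOptions()
    if opts.HasExtension(mi.minval):
        return opts.Extensions[mi.minval], opts.Extensions[mi.maxval]
    return None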
libraries/protobuf/src/track_type.proto ADDED
@@ -0,0 +1,12 @@
1
+ syntax = "proto2";
2
+
3
+ package midi;
4
+
5
+ enum TRACK_TYPE {
6
+ AUX_DRUM_TRACK = 8;
7
+ AUX_INST_TRACK = 9;
8
+ STANDARD_TRACK = 10;
9
+ STANDARD_DRUM_TRACK = 11;
10
+ STANDARD_BOTH = 12;
11
+ NUM_TRACK_TYPES = 16;
12
+ }
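For reference, the generated Python bindings support name/value lookups on this enum in both directions; a small sketch, assuming the generated module is named track_type_pb2:

import track_type_pb2 as tt  # assumed generated module name

assert tt.TRACK_TYPE.Name(10) == "STANDARD_TRACK"
assert tt.TRACK_TYPE.Value("AUX_DRUM_TRACK") == 8
assert tt.STANDARD_DRUM_TRACK == 11  # values are also exposed as module constants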
libraries/pybind11 ADDED
@@ -0,0 +1 @@
1
+ Subproject commit 5ccb9e412d8974e8eeac1b061d9077ac0bd365e1
libraries/torch/CMakeLists.txt ADDED
@@ -0,0 +1,26 @@
1
+ cmake_minimum_required(VERSION 3.8)
2
+
3
+ project(midigpt_torch)
4
+
5
+ set(SRCS
6
+ src/torch_library.cpp
7
+ "include/torch_library.h")
8
+
9
+ set(CMAKE_PREFIX_PATH "${CMAKE_CURRENT_SOURCE_DIR}/../libtorch/")
10
+ find_package(Torch REQUIRED)
11
+
12
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
13
+
14
+ add_library(midigpt_torch
15
+ ${SRCS})
16
+
17
+ target_link_libraries(midigpt_torch PRIVATE "${TORCH_LIBRARIES}")
18
+
19
+ if (MSVC)
20
+ file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
21
+ add_custom_command(TARGET midigpt_torch
22
+ POST_BUILD
23
+ COMMAND ${CMAKE_COMMAND} -E copy_if_different
24
+ ${TORCH_DLLS}
25
+ $<TARGET_FILE_DIR:midigpt_torch>)
26
+ endif (MSVC)
libraries/torch/CMakeSettings.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "configurations": [
3
+ {
4
+ "name": "x64-Debug",
5
+ "generator": "Ninja",
6
+ "configurationType": "Debug",
7
+ "inheritEnvironments": [ "msvc_x64_x64" ],
8
+ "buildRoot": "${projectDir}\\out\\build\\${name}",
9
+ "installRoot": "${projectDir}\\out\\install\\${name}",
10
+ "cmakeCommandArgs": "",
11
+ "buildCommandArgs": "",
12
+ "ctestCommandArgs": ""
13
+ },
14
+ {
15
+ "name": "WSL-GCC-Debug",
16
+ "generator": "Ninja",
17
+ "configurationType": "Debug",
18
+ "buildRoot": "${projectDir}\\out\\build\\${name}",
19
+ "installRoot": "${projectDir}\\out\\install\\${name}",
20
+ "cmakeExecutable": "cmake",
21
+ "cmakeCommandArgs": "",
22
+ "buildCommandArgs": "",
23
+ "ctestCommandArgs": "",
24
+ "inheritEnvironments": [ "linux_x64" ],
25
+ "wslPath": "${defaultWSLPath}",
26
+ "variables": []
27
+ }
28
+ ]
29
+ }
libraries/torch/include/torch_library.h ADDED
@@ -0,0 +1,5 @@
1
+ #pragma once
2
+
3
+ namespace midigptTorch {
4
+ void testMmmTorch();
5
+ }
libraries/torch/src/torch_library.cpp ADDED
@@ -0,0 +1,10 @@
1
+ #include <iostream>
2
+ #include "../../libtorch/include/torch/csrc/api/include/torch/torch.h"
3
+
4
+ namespace midigptTorch {
5
+ void testMmmTorch() {
6
+ std::cout << "midigptTorch test function called" << std::endl;
7
+ torch::Tensor tensor = torch::rand({ 2, 3 });
8
+ std::cout << tensor << std::endl;
9
+ }
10
+ }
midigpt_setup_helper.sh ADDED
@@ -0,0 +1,182 @@
1
+ #!/bin/bash
2
+
3
+ clone=no
4
+ replace=no
5
+ test_train=no
6
+ inference=no
7
+ mac=no
8
+ python_path=/Library/Frameworks/Python.framework/Versions/3.8/bin
9
+ og=no
10
+
11
+ usage="
12
+ $(basename "$0") [-h] [-n] [-c] [-i] [-m] [-r] [-d
13
+ directory]
14
+ -- Script for setting up and testing the MIDI-GPT repository
15
+
16
+ where:
17
+ -h Show this help text
18
+ -n Test the training script imports
19
+ -c Clone the github MIDI-GPT repository
20
+ -i If you wish to setup repository for inference
21
+ -d Provide directory name where repo is/will be cloned
22
+ -r Replace directory if already exists
23
+ -m If on MacOS CPU
24
+ "
25
+
26
+
27
+ OPTSTRING=":cnhiomrd:k:"
28
+
29
+ while getopts ${OPTSTRING} opt; do
30
+ case ${opt} in
31
+ h)
32
+ echo "${usage}"
33
+ exit 0
34
+ ;;
35
+ i)
36
+ inference=yes
37
+ ;;
38
+ n)
39
+ test_train=yes
40
+ ;;
41
+ m)
42
+ mac=yes
43
+ ;;
44
+ r)
45
+ replace=yes
46
+ ;;
47
+ c)
48
+ clone=yes
49
+ ;;
50
+ d)
51
+ repo=${OPTARG}
52
+ ;;
53
+ :)
54
+ echo "Option -${OPTARG} requires an argument"
55
+ exit 1
56
+ ;;
57
+ ?)
58
+ echo "Invalid option: -${OPTARG}"
59
+ exit 1
60
+ ;;
61
+ esac
62
+ done
63
+
64
+ if test "${clone}" = "yes"
65
+ then
66
+ echo "Cloning MIDI-GPT"
67
+ fi
68
+
69
+ echo "In directory: ${repo}"
70
+
71
+ if test "${replace}" = "yes"
72
+ then
73
+ if [[ -d ${repo} ]]
74
+ then
75
+ echo "Directory ${repo} already exists, removing it"
76
+ rm -rf ${repo}
77
+ fi
78
+ fi
79
+
80
+ mkdir -p ${repo}
81
+ cd ${repo}
82
+
83
+ echo "Loading modules"
84
+
85
+ if test "${clone}" = "yes"
86
+ then
87
+ if [[ -d MIDI-GPT ]] || [[ -d ENV ]]
88
+ then
89
+ echo "MIDI-GPT or ENV directories already exist"
90
+ exit 1
91
+ fi
92
+ if test "${og}" = "yes"
93
+ then
94
+ {
95
+ git clone https://www.github.com/Metacreation-Lab/MIDI-GPT.git
96
+
97
+ } || {
98
+ echo "Cloning failed"
99
+ exit 1
100
+ }
101
+ else
102
+ {
103
+ git clone https://www.github.com/Metacreation-Lab/MIDI-GPT.git
104
+
105
+ } || {
106
+ echo "Cloning failed"
107
+ exit 1
108
+ }
109
+ fi
110
+
111
+ ${python_path}/virtualenv --no-download ./ENV
112
+
113
+ else
114
+ if ! [[ -d MIDI-GPT ]]
115
+ then
116
+ echo "MIDI-GPT doesn't exist, try cloning the repository with the -c option"
117
+ exit 1
118
+ fi
119
+ fi
120
+
121
+ {
122
+ source ./ENV/bin/activate
123
+ } || {
124
+ echo "ENV virtual environment doesn't exist"
125
+ exit 1
126
+ }
127
+
128
+ echo "pip installs"
129
+
130
+ pip install --no-index --upgrade pip
131
+ pip install torch==1.13.0
132
+ pip install transformers==4.26.1
133
+
134
+ cd MIDI-GPT
135
+ if test "${og}" = "no"
136
+ then
137
+ git checkout main
138
+ fi
139
+
140
+ echo "Starting python library build"
141
+
142
+ { if test "${inference}" = "yes"
143
+ then
144
+ echo "Building for inference"
145
+ if test "${mac}" = "yes"
146
+ then
147
+ echo "On MacOS CPU"
148
+ bash create_python_library.sh --mac_os
149
+ else
150
+ echo "On Compute Canada"
151
+ bash create_python_library.sh --compute_canada
152
+ fi
153
+ else
154
+ echo "Building for training only"
155
+ bash create_python_library.sh --test_build --compute_canada --no_torch
156
+ fi
157
+ } || {
158
+ echo "Build failed"
159
+ exit 1
160
+ }
161
+
162
+ if test "${test_train}" = "yes"
163
+ then
164
+
165
+ cd ../
166
+
167
+ deactivate
168
+
169
+ echo "Activating environment"
170
+
171
+ source $PWD/venv/bin/activate
172
+ cd $PWD/MIDI-GPT/python_scripts
173
+
174
+ echo "Testing training script"
175
+
176
+ python3 -c "import train"
177
+
178
+ echo "Import tests done"
179
+
180
+ fi
181
+
182
+ echo "Finished"
models/model.zip ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2370946e0b45c440972ef3099be17f27cd7157b7a40ce9fd24851f2c855e4d85
3
+ size 77523617
pip_requirements/common_requirements.txt ADDED
@@ -0,0 +1 @@
1
+ pretty_midi
pip_requirements/create_dataset_requirements.txt ADDED
@@ -0,0 +1,4 @@
1
+ attrs==23.1.0+computecanada
2
+ jsonlines==3.1.0+computecanada
3
+ numpy==1.24.2+computecanada
4
+ tqdm==4.46.1
pip_requirements/inference_requirements.txt ADDED
@@ -0,0 +1 @@
1
+ torch==2.0.0
pip_requirements/train_requirements.txt ADDED
@@ -0,0 +1,135 @@
1
+ anyio==3.6.2+computecanada
2
+ arff==0.9+computecanada
3
+ argon2-cffi==21.3.0+computecanada
4
+ argon2-cffi-bindings==21.2.0+computecanada
5
+ asttokens==2.2.1+computecanada
6
+ async-generator==1.10+computecanada
7
+ attrs==23.1.0+computecanada
8
+ backcall==0.2.0+computecanada
9
+ backports-abc==0.5+computecanada
10
+ backports-shutil-get-terminal-size==1.0.0+computecanada
11
+ bcrypt==4.0.1+computecanada
12
+ beautifulsoup4==4.11.2+computecanada
13
+ bitstring==4.0.1+computecanada
14
+ bleach==6.0.0+computecanada
15
+ certifi==2022.12.7+computecanada
16
+ cffi==1.15.1+computecanada
17
+ chardet==5.1.0+computecanada
18
+ charset-normalizer==3.0.1+computecanada
19
+ comm==0.1.2+computecanada
20
+ contourpy==1.0.7+computecanada
21
+ cryptography==39.0.1+computecanada
22
+ cycler==0.11.0+computecanada
23
+ Cython==0.29.33+computecanada
24
+ deap==1.3.3+computecanada
25
+ debugpy==1.6.6+computecanada
26
+ decorator==5.1.1+computecanada
27
+ defusedxml==0.7.1+computecanada
28
+ dnspython==2.3.0+computecanada
29
+ ecdsa==0.18.0+computecanada
30
+ entrypoints==0.4+computecanada
31
+ executing==1.2.0+computecanada
32
+ fastjsonschema==2.16.2+computecanada
33
+ filelock==3.12.2
34
+ fonttools==4.38.0+computecanada
35
+ fsspec==2023.6.0
36
+ funcsigs==1.0.2+computecanada
37
+ huggingface-hub==0.15.1
38
+ idna==3.4+computecanada
39
+ importlib-metadata==5.2.0+computecanada
40
+ importlib-resources==5.12.0+computecanada
41
+ ipykernel==6.21.2+computecanada
42
+ ipython==8.10.0+computecanada
43
+ ipython-genutils==0.2.0+computecanada
44
+ ipywidgets==8.0.4+computecanada
45
+ jedi==0.18.2+computecanada
46
+ Jinja2==3.1.2+computecanada
47
+ jsonlines==3.1.0+computecanada
48
+ jsonschema==4.17.3+computecanada
49
+ jupyter-client==8.0.3+computecanada
50
+ jupyter-core==5.2.0+computecanada
51
+ jupyter-events==0.6.3+computecanada
52
+ jupyter-server==2.3.0+computecanada
53
+ jupyter-server-terminals==0.4.4+computecanada
54
+ jupyterlab-pygments==0.2.2+computecanada
55
+ jupyterlab-widgets==3.0.5+computecanada
56
+ kiwisolver==1.4.4+computecanada
57
+ lockfile==0.12.2+computecanada
58
+ MarkupSafe==2.1.2+computecanada
59
+ matplotlib==3.7.0+computecanada
60
+ matplotlib-inline==0.1.6+computecanada
61
+ mistune==2.0.5+computecanada
62
+ mock==5.0.1+computecanada
63
+ mpmath==1.2.1+computecanada
64
+ nbclassic==0.5.2+computecanada
65
+ nbclient==0.7.2+computecanada
66
+ nbconvert==7.2.9+computecanada
67
+ nbformat==5.7.3+computecanada
68
+ nest-asyncio==1.5.6+computecanada
69
+ netaddr==0.8.0+computecanada
70
+ netifaces==0.11.0+computecanada
71
+ nose==1.3.7+computecanada
72
+ notebook==6.5.2+computecanada
73
+ notebook-shim==0.2.2+computecanada
74
+ numpy==1.24.2+computecanada
75
+ packaging==23.0+computecanada
76
+ pandas==1.5.3+computecanada
77
+ pandocfilters==1.5.0+computecanada
78
+ paramiko==3.0.0+computecanada
79
+ parso==0.8.3+computecanada
80
+ path==16.6.0+computecanada
81
+ path.py==12.5.0+computecanada
82
+ pathlib2==2.3.7+computecanada
83
+ paycheck==1.0.2+computecanada
84
+ pbr==5.11.1+computecanada
85
+ pexpect==4.8.0+computecanada
86
+ pickleshare==0.7.5+computecanada
87
+ Pillow==9.4.0+computecanada
88
+ pkgutil-resolve-name==1.3.10+computecanada
89
+ platformdirs==2.5.2+computecanada
90
+ prometheus-client==0.16.0+computecanada
91
+ prompt-toolkit==3.0.37+computecanada
92
+ protobuf==4.23.3
93
+ psutil==5.9.4+computecanada
94
+ ptyprocess==0.7.0+computecanada
95
+ pure-eval==0.2.2+computecanada
96
+ pycparser==2.21+computecanada
97
+ Pygments==2.14.0+computecanada
98
+ PyNaCl==1.5.0+computecanada
99
+ pyparsing==3.0.9+computecanada
100
+ pyrsistent==0.19.3+computecanada
101
+ python-dateutil==2.8.2+computecanada
102
+ python-json-logger==2.0.7+computecanada
103
+ pytz==2022.7.1+computecanada
104
+ PyYAML==6.0+computecanada
105
+ pyzmq==25.0.0+computecanada
106
+ regex==2022.10.31+computecanada
107
+ requests==2.28.2+computecanada
108
+ rfc3339-validator==0.1.4+computecanada
109
+ rfc3986-validator==0.1.1+computecanada
110
+ scipy==1.10.1+computecanada
111
+ send2trash==1.8.0+computecanada
112
+ simplegeneric==0.8.1+computecanada
113
+ singledispatch==4.0.0+computecanada
114
+ six==1.16.0+computecanada
115
+ sniffio==1.3.0+computecanada
116
+ soupsieve==2.4+computecanada
117
+ stack-data==0.6.2+computecanada
118
+ sympy==1.11.1+computecanada
119
+ tensorboardX==2.2+computecanada
120
+ terminado==0.17.1+computecanada
121
+ testpath==0.6.0+computecanada
122
+ tinycss2==1.2.1+computecanada
123
+ tokenizers==0.13.2+computecanada
124
+ torch==1.13.1+computecanada
125
+ tornado==6.2+computecanada
126
+ tqdm==4.46.1
127
+ traitlets==5.9.0+computecanada
128
+ transformers==4.26.1+computecanada
129
+ typing-extensions==4.5.0+computecanada
130
+ urllib3==1.26.14+computecanada
131
+ wcwidth==0.2.6+computecanada
132
+ webencodings==0.5.1+computecanada
133
+ websocket-client==1.5.1+computecanada
134
+ widgetsnbextension==4.0.5+computecanada
135
+ zipp==3.14.0+computecanada
python_scripts/config/bert.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "hidden_size" : 512,
3
+ "num_hidden_layers" : 6,
4
+ "nuim_attention_heads" : 8,
5
+ "intermediate_size" : 2048,
6
+ "max_position_embeddings" : 2048
7
+ }
python_scripts/config/bert_tiny.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "hidden_size" : 128,
3
+ "num_hidden_layers" : 2,
4
+ "num_attention_heads" : 2,
5
+ "intermediate_size" : 128,
6
+ "max_position_embeddings" : 2048
7
+ }
python_scripts/config/gpt2.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "n_positions": 2048,
3
+ "n_ctx": 2048,
4
+ "n_layer": 6,
5
+ "n_head": 8,
6
+ "n_embd": 512
7
+ }
python_scripts/config/gpt2_tiny.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "n_positions": 2048,
3
+ "n_ctx": 2048,
4
+ "n_layer": 2,
5
+ "n_head": 2,
6
+ "n_embd": 128
7
+ }
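These JSON files feed directly into the transformers config loaders, as convert.py and train.py below do; a minimal sketch of the load path (vocab_size is filled in separately from the midigpt encoder):

from transformers import GPT2Config

config = GPT2Config().from_json_file("config/gpt2_tiny.json")
print(config.n_layer, config.n_head, config.n_embd)  # 2 2 128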
python_scripts/convert.py ADDED
@@ -0,0 +1,204 @@
1
+ # Take a trained pytorch model and convert it to a traced TorchScript checkpoint with embedded metadata.
2
+ import os
3
+ import sys
4
+ sys.path.append(os.path.dirname(os.getcwd()) + "/python_lib")
5
+ print( os.path.dirname(os.getcwd()) + "/python_lib" )
6
+ import midigpt
7
+ import time
8
+ import json
9
+ import numpy as np
10
+ import torch
11
+ import torch.quantization
12
+ from transformers import GPT2LMHeadModel, GPT2Config
13
+ from transformers.modeling_utils import Conv1D
14
+
15
+ from custom_models import *
16
+
17
+ from torch import nn
18
+
19
+ class QuantWrapper(nn.Module):
20
+ def __init__(self, module):
21
+ super(QuantWrapper, self).__init__()
22
+ qconfig = module.qconfig if hasattr(module, 'qconfig') else None
23
+ self.add_module('quant', torch.quantization.QuantStub(qconfig))
24
+ self.add_module('dequant', torch.quantization.DeQuantStub())
25
+ self.add_module('module', module)
26
+ self.train(module.training)
27
+
28
+ def forward(self, X, P):
29
+ X = self.quant(X)
30
+ P = self.quant(P)
31
+ O = self.module(X,P)
32
+ return self.dequant(O)
33
+
34
+ def _conv1d_to_linear(module):
35
+ in_size, out_size = module.weight.shape
36
+ linear = torch.nn.Linear(in_size, out_size)
37
+ linear.weight.data = module.weight.data.T.contiguous()
38
+ linear.bias.data = module.bias.data
39
+ return linear
40
+
41
+ def conv1d_to_linear(model):
42
+ for name in list(model._modules):
43
+ module = model._modules[name]
44
+ if isinstance(module, Conv1D):
45
+ linear = _conv1d_to_linear(module)
46
+ model._modules[name] = linear
47
+ else:
48
+ conv1d_to_linear(module)
49
+
50
+ def score_model(model):
51
+ targets = np.load("target.npz")["data"]
52
+
53
+
54
+ def time_model(model):
55
+ start = time.time()
56
+ pkv = None
57
+ for _ in range(1000):
58
+ input_ids = torch.ones(1,1).type(torch.LongTensor)
59
+ outputs = model(input_ids, past_key_values=pkv)
60
+ pkv = outputs[1]
61
+ print("BATCH TIME : {}".format(time.time() - start))
62
+
63
+ def print_size_of_model(model):
64
+ import os
65
+ torch.save(model.state_dict(), "temp.p")
66
+ print('Size (MB):', os.path.getsize("temp.p")/1e6)
67
+ os.remove('temp.p')
68
+
69
+ def quantize_model(model):
70
+ conv1d_to_linear(model)
71
+ model = torch.quantization.quantize_dynamic(
72
+ model, {torch.nn.Linear}, dtype=torch.qint8)
73
+ return model
74
+
75
+ def static_quantize_model(model):
76
+ conv1d_to_linear(model)
77
+ model.qconfig = torch.quantization.default_qconfig
78
+ torch.quantization.prepare(model, inplace=True)
79
+ torch.quantization.convert(model, inplace=True)
80
+
81
+ return model
82
+
83
+ def prune_model(model):
84
+ import torch.nn.utils.prune as prune
85
+
86
+ conv1d_to_linear(model)
87
+ parameters_to_prune = []
88
+ for _,module in model.named_modules():
89
+ if isinstance(module, torch.nn.Linear):
90
+ prune.l1_unstructured(module, name="weight", amount=.8)
91
+ prune.remove(module, "weight")
92
+
93
+ for _,submodule in module.named_modules():
94
+ if isinstance(submodule, torch.nn.Linear):
95
+ prune.l1_unstructured(submodule, name="weight", amount=.8)
96
+ prune.remove(submodule, "weight")
97
+
98
+ return model
99
+
100
+ def inject_metadata(path, metadata_path, encoder, new_state):
101
+ model = torch.jit.load(path)
102
+ with open(metadata_path, "r") as f:
103
+ metadata = json.load(f)
104
+ metadata["encoder"] = encoder
105
+ metadata["new_state"] = new_state
106
+ extra_files = torch._C.ExtraFilesMap()
107
+ extra_files['metadata.json'] = json.dumps(metadata)
108
+ out_path = os.path.splitext(path)[0] + "_WMETA.pt"
109
+ torch.jit.save(model, out_path, _extra_files=extra_files)
110
+
111
+ def convert(model, path, quantize=False, prune=False, force=False, control=False, ckpt_path=None, encoderX=None):
112
+ if not os.path.exists(path) or force:
113
+ model.eval()
114
+ if quantize:
115
+ model = quantize_model(model)
116
+ if prune:
117
+ model = prune_model(model)
118
+ print_size_of_model(model)
119
+ example_input = torch.zeros(1,300).type(torch.LongTensor)
120
+ example_control = torch.zeros(1,300,3).type(torch.FloatTensor)
121
+ if control:
122
+ outputs = model(input_ids=example_input, control_ids=example_control, past_key_values=None)
123
+ print(len(outputs[1]))
124
+ traced_script_module = torch.jit.trace(model, [example_input,example_control,outputs[1]])
125
+ else:
126
+ outputs = model(input_ids=example_input)
127
+ traced_script_module = torch.jit.trace(model, [example_input, outputs[1]])
128
+
129
+ num_layers = len(outputs[1])
130
+ _,num_heads,_,num_hidden = outputs[1][0][0].detach().numpy().shape
131
+ encoder = encoderX
132
+
133
+ model_metadata = {
134
+ "encoder" : encoder,
135
+ "num_heads" : num_heads,
136
+ "num_hidden" : num_hidden,
137
+ "num_layers" : num_layers,
138
+ "model_dim" : -1,
139
+ "new_state" : True
140
+ }
141
+
142
+ print(model_metadata)
143
+
144
+ extra_files = {}
145
+ extra_files['metadata.json'] = json.dumps(model_metadata)
146
+ torch.jit.save(
147
+ traced_script_module, path, _extra_files=extra_files)
148
+
149
+
150
+ class GPT2LMHeadModelWMeta(GPT2LMHeadModel):
151
+ def extra_repr(self):
152
+ return "trent is the man"
153
+
154
+ if __name__ == "__main__":
155
+
156
+ import argparse
157
+ parser = argparse.ArgumentParser()
158
+ parser.add_argument("--ckpt_path", type=str, required=True)
159
+ parser.add_argument("--output", type=str, default="")
160
+ parser.add_argument("--metadata_path", type=str, default="")
161
+ parser.add_argument("--config", type=str, default="")
162
+ parser.add_argument("--encoder", type=str, default="NONE")
163
+ parser.add_argument("--init", action="store_true")
164
+ parser.add_argument("--inject", action="store_true")
165
+ parser.add_argument("--new_state", action="store_true")
166
+ parser.add_argument("--quantize", action="store_true")
167
+ parser.add_argument("--prune", action="store_true")
168
+ parser.add_argument("--control", action="store_true")
169
+
170
+ args = parser.parse_args()
171
+
172
+
173
+ if args.inject:
174
+ assert len(args.metadata_path)
175
+ inject_metadata(
176
+ args.ckpt_path, args.metadata_path, args.encoder, True if args.new_state else False)
177
+
178
+ else:
179
+ assert len(args.output)
180
+ if args.init:
181
+ encoder_mode = midigpt.getEncoderType(args.encoder)
182
+ assert encoder_mode is not midigpt.ENCODER_TYPE.NO_ENCODER
183
+ encoder = midigpt.getEncoder(encoder_mode)
184
+ vocab_size = encoder.vocab_size()
185
+ if args.control:
186
+ config = GPT2LMHeadModelContConfig().from_json_file(args.config)
187
+ # encoder knows the size of the embedding
188
+ config.n_control_dim = encoder.config.embed_dim
189
+ model_cls = GPT2LMHeadModelCont
190
+
191
+ else:
192
+ config = GPT2Config().from_json_file(args.config)
193
+ config.vocab_size = vocab_size
194
+ model_cls = GPT2LMHeadModel
195
+
196
+ model = model_cls(config)
197
+ else:
198
+ if args.control:
199
+ model = GPT2LMHeadModelCont.from_pretrained(args.ckpt_path, torchscript=True)
200
+ else:
201
+ model = GPT2LMHeadModel.from_pretrained(args.ckpt_path, torchscript=True)
202
+
203
+ convert(model, args.output, quantize=args.quantize, prune=args.prune, control=args.control, ckpt_path=args.ckpt_path, encoderX=args.encoder)
204
+
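For context, the dynamic quantization in convert.py is the standard torch recipe, applied after converting transformers Conv1D layers to Linear; a standalone sketch on a plain model:

import torch

model = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.ReLU())
# Swap float Linear weights for int8 dynamically-quantized equivalents.
qmodel = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
print(qmodel)  # the Linear layer should now print as a dynamically quantized module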
python_scripts/create_dataset.py ADDED
@@ -0,0 +1,226 @@
1
+ import os
2
+ import glob
3
+ import json
4
+ import numpy as np
5
+ import csv
6
+ from tqdm import tqdm
7
+ from multiprocessing import Pool
8
+
9
+ from utils import *
10
+
11
+ import sys
12
+ import os
13
+ sys.path.append(os.path.dirname(os.getcwd()) + "/python_lib")
14
+ import midigpt
15
+
16
+ def worker(args):
17
+ path,sid,labels,nomml,tcjson,encoding = args
18
+ tc = midigpt.TrainConfig()
19
+ tc.from_json(tcjson)
20
+ labels["nomml"] = nomml
21
+
22
+ encoder_mode = midigpt.getEncoderType(encoding)
23
+ assert encoder_mode is not midigpt.ENCODER_TYPE.NO_ENCODER
24
+ encoder = midigpt.getEncoder(encoder_mode)
25
+
26
+ try:
27
+ return sid, midigpt.midi_to_json_bytes(path,tc,json.dumps(labels))
28
+ except Exception as e:
29
+ print(e)
30
+ return None,None
31
+
32
+ def load_json(path):
33
+ if not os.path.exists(path):
34
+ return {}
35
+ with open(path, "r") as f:
36
+ return json.load(f)
37
+
38
+ DEFAULT_LABELS = {
39
+ "genre": "GENRE_MUSICMAP_ANY",
40
+ "valence_spotify": -1,
41
+ "energy_spotify": -1,
42
+ "danceability_spotify": -1,
43
+ "tension": []
44
+ }
45
+
46
+ DATA_TYPES = [
47
+ "Drum",
48
+ "Drum+Music",
49
+ "Music-No-Drum"
50
+ ]
51
+
52
+ def load_metadata_labels(genre_data_path, spotify_data_path, tension_data_path):
53
+ data = {}
54
+ genre_data = load_json(genre_data_path)
55
+ spotify_data = load_json(spotify_data_path)
56
+ tension_data = load_json(tension_data_path)
57
+ md5s = list(set(list(genre_data.keys()) + list(spotify_data.keys()) + list(tension_data.keys())))
58
+ for md5 in md5s:
59
+ data[md5] = {}
60
+ if md5 in spotify_data:
61
+ data[md5]["valence_spotify"] = np.mean(spotify_data[md5]["valence"])
62
+ data[md5]["energy_spotify"] = np.mean(spotify_data[md5]["energy"])
63
+ data[md5]["danceability_spotify"] = np.mean(spotify_data[md5]["danceability"])
64
+ else:
65
+ for k,v in DEFAULT_LABELS.items():
66
+ data[md5][k] = v
67
+ data[md5]["genre"] = genre_data.get(md5, DEFAULT_LABELS["genre"])
68
+ data[md5]["tension"] = tension_data.get(md5, DEFAULT_LABELS["tension"])
69
+ return data
70
+
71
+ if __name__ == "__main__":
72
+
73
+ import argparse
74
+ parser = argparse.ArgumentParser()
75
+ parser.add_argument("--data_dir", type=str, required=True)
76
+ parser.add_argument("--output", type=str, required=True)
77
+ parser.add_argument("--num_bars", type=int, default=4)
78
+ parser.add_argument("--expressive", action="store_true")
79
+ parser.add_argument("--ignore_score", type=bool, default=0)
80
+ parser.add_argument("--nthreads", type=int, default=8)
81
+ parser.add_argument("--max_size", type=int, default=-1)
82
+ parser.add_argument("--genre_data", type=str, default="")
83
+ parser.add_argument("--spotify_data", type=str, default="")
84
+ parser.add_argument("--tension_data", type=str, default="")
85
+ parser.add_argument("--encoding", type=str, default="TRACK_ENCODER")
86
+ parser.add_argument("--resolution", type=int, default=12)
87
+ parser.add_argument("--delta_resolution", type=int, default=1920)
88
+ parser.add_argument("--metadata", type=str, required=True)
89
+ parser.add_argument("--type", type=str, default="Drum+Music")
90
+ parser.add_argument("--test", type=str, default="no")
91
+ args = parser.parse_args()
92
+
93
+ args.ignore_score = bool(args.ignore_score)
94
+ if args.test != "no":
95
+ test_script = True
96
+ else:
97
+ test_script = False
98
+
99
+ assert args.type in DATA_TYPES
100
+ args.type = "-" + args.type + "-"
101
+
102
+ import os
103
+ os.system("taskset -p 0xffff %d" % os.getpid())
104
+
105
+ # multi thread approach takes about 2 minutes
106
+ pool = Pool(args.nthreads)
107
+ output = os.path.splitext(args.output)[0]
108
+ ss=""
109
+ if args.max_size > 0:
110
+ ss=f"_MAX_{args.max_size}"
111
+ if args.expressive:
112
+ output += "/{}_NUM_BARS={}_RESOLUTION_{}_DELTA_{}{}.arr".format(args.encoding,args.num_bars,args.resolution, args.delta_resolution,ss)
113
+ else:
114
+ output += "/{}_NUM_BARS={}_RESOLUTION_{}{}.arr".format(args.encoding,args.num_bars,args.resolution,ss)
115
+ print(output)
116
+ if not test_script:
117
+ jag = midigpt.BytesToFile(output)
118
+
119
+
120
+ paths = list(glob.glob(args.data_dir + "/**/*.mid", recursive=True))
121
+
122
+ import random
123
+ import time
124
+ random.seed(int(time.time()))
125
+
126
+ tc = midigpt.TrainConfig()
127
+ tc.num_bars = args.num_bars
128
+ tc.use_microtiming = args.expressive
129
+ tc.resolution = args.resolution
130
+ tc.delta_resolution = args.delta_resolution
131
+ tc = tc.to_json()
132
+ print(tc)
133
+
134
+ paths_exp = []
135
+ sids_exp = []
136
+ paths_non_exp = []
137
+ sids_non_exp = []
138
+ paths_all = []
139
+ sids_all = []
140
+ nomml_alls = []
141
+ nomml_scores = []
142
+
143
+ try:
144
+ with open(args.metadata) as meta:
145
+ reader = csv.DictReader(meta, delimiter=',')
146
+ for row in tqdm(reader):
147
+ path = row["filepath"]
148
+ nomml = int(row["medianMetricDepth"])
149
+ if (".mid" in path and args.type in path):
150
+ if "-Train-" in path:
151
+ group = 0
152
+ elif "-Val-" in path:
153
+ group = 1
154
+ elif "-Test-" in path:
155
+ group = 2
156
+ else:
157
+ raise RuntimeError("data format incorrect")
158
+ if (nomml < 12):
159
+ paths_non_exp.append(os.path.join(args.data_dir,path))
160
+ sids_non_exp.append(group)
161
+ nomml_scores.append(nomml)
162
+ else:
163
+ paths_exp.append(os.path.join(args.data_dir,path))
164
+ sids_exp.append(group)
165
+ paths_all.append(os.path.join(args.data_dir,path))
166
+ sids_all.append(group)
167
+ nomml_alls.append(nomml)
168
+
169
+ except Exception: # metadata CSV missing or malformed; fall back to path-based split
170
+ paths_all = list(glob.glob(args.data_dir + "/**/*.mid", recursive=True))
171
+ for path in paths_all:
172
+ if "-train-" in path:
173
+ sids_all.append(0)
174
+ elif "-valid-" in path:
175
+ sids_all.append(1)
176
+ elif "-test-" in path:
177
+ sids_all.append(2)
178
+ else:
179
+ raise RuntimeError("data format incorrect")
180
+
181
+ nomml_vals = []
182
+ if args.expressive:
183
+ if args.ignore_score:
184
+ paths = paths_exp
185
+ sids = sids_exp
186
+ nomml_vals = [12 for _ in sids]
187
+ else:
188
+ paths = paths_all
189
+ sids = sids_all
190
+ nomml_vals = nomml_alls
191
+ else:
192
+ paths = paths_all
193
+ sids = sids_all
194
+ nomml_vals = nomml_alls
195
+
196
+ metadata_label_data = load_metadata_labels(args.genre_data, args.spotify_data, args.tension_data)
197
+ metadata_labels = [metadata_label_data.get(os.path.splitext(os.path.basename(p))[0],DEFAULT_LABELS) for p in paths]
198
+ print("LOADED {} METADATA LABELS".format(len(metadata_labels)))
199
+
200
+ tcs = [tc for _ in paths]
201
+ encoding = [args.encoding for _ in paths]
202
+ inputs = list(zip(paths,sids,metadata_labels,nomml_vals,tcs,encoding))
203
+ random.shuffle(inputs)
204
+
205
+ for k,v in DEFAULT_LABELS.items():
206
+ print("{} FILES HAVE {} METADATA".format(sum([m[k] != v for m in metadata_labels]),k))
207
+
208
+ if args.max_size > 0:
209
+ inputs = inputs[:args.max_size]
210
+
211
+ if not test_script:
212
+ total_count = 0
213
+ success_count = 0
214
+ pool = Pool(args.nthreads)
215
+ progress_bar = tqdm(pool.imap_unordered(worker, inputs), total=len(inputs))
216
+ for sid,b in progress_bar:
217
+ if b is not None and len(b):
218
+ jag.append_bytes_to_file_stream(b,sid)
219
+ success_count += 1
220
+ total_count += 1
221
+ status_str = "{}/{}".format(success_count,total_count)
222
+ progress_bar.set_description(status_str)
223
+ jag.close()
224
+ else:
225
+ print("Test successful")
226
+ sys.exit(0)
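A standalone restatement of how the split id is derived from paths in the metadata CSV above (0=train, 1=validation, 2=test); the example path shape is illustrative only:

def split_id(path):
    for tag, group in (("-Train-", 0), ("-Val-", 1), ("-Test-", 2)):
        if tag in path:
            return group
    raise RuntimeError("data format incorrect")

print(split_id("lmd-Drum+Music--Train-/x.mid"))  # 0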
python_scripts/custom_models.py ADDED
@@ -0,0 +1,121 @@
1
+ from transformers import *
2
+ import torch.nn as nn
3
+ import torch
4
+
5
+ class GPT2Encoder(GPT2PreTrainedModel):
6
+ def __init__(self, config):
7
+ super().__init__(config)
8
+ self.transformer = GPT2Model(config)
9
+ self.score = nn.Linear(config.n_embd, 128, bias=False)
10
+ self.init_weights()
11
+ self.model_parallel = False
12
+ self.device_map = None
13
+
14
+ negative_importance = torch.tensor(5.33).float()
15
+ negative_threshold = torch.tensor(4.).float()
16
+ entropy_importance = torch.tensor(0.05).float()
17
+ self.register_buffer('negative_importance', negative_importance)
18
+ self.register_buffer('negative_threshold', negative_threshold)
19
+ self.register_buffer('entropy_importance', entropy_importance)
20
+
21
+ def forward(
22
+ self,
23
+ input_ids=None,
24
+ past_key_values=None,
25
+ attention_mask=None,
26
+ token_type_ids=None,
27
+ position_ids=None,
28
+ head_mask=None,
29
+ inputs_embeds=None,
30
+ labels=None,
31
+ use_cache=None,
32
+ output_attentions=None,
33
+ output_hidden_states=None,
34
+ return_dict=None,
35
+ sequence_lengths=None,
36
+ ):
37
+ assert sequence_lengths is not None
38
+ outputs = self.transformer(
39
+ input_ids,
40
+ past_key_values=past_key_values,
41
+ attention_mask=attention_mask,
42
+ token_type_ids=token_type_ids,
43
+ position_ids=position_ids,
44
+ head_mask=head_mask,
45
+ inputs_embeds=inputs_embeds,
46
+ use_cache=use_cache,
47
+ output_attentions=output_attentions,
48
+ output_hidden_states=output_hidden_states,
49
+ return_dict=return_dict,
50
+ )
51
+ hidden_states = outputs[0]
52
+ logits = self.score(hidden_states)
53
+ return logits[range(input_ids.shape[0]),sequence_lengths-1]
54
+
55
+ class GPT2LMHeadModelContConfig(GPT2Config):
56
+ def __init__(
57
+ self,
58
+ n_control_embd=64,
59
+ n_control_dim=3,
60
+ **kwargs
61
+ ):
62
+ super().__init__(**kwargs)
63
+ self.n_control_embd = n_control_embd
64
+ self.n_control_dim = n_control_dim
65
+
66
+ class GPT2LMHeadModelCont(GPT2LMHeadModel):
67
+ def __init__(self, config):
68
+ super().__init__(config)
69
+ token_embd = config.n_embd - config.n_control_embd
70
+ self.wte = nn.Embedding(config.vocab_size, token_embd)
71
+ self.ctrle = nn.Linear(config.n_control_dim, config.n_control_embd)
72
+
73
+ def forward(
74
+ self,
75
+ input_ids=None,
76
+ control_ids=None,
77
+ past_key_values=None,
78
+ attention_mask=None,
79
+ token_type_ids=None,
80
+ position_ids=None,
81
+ head_mask=None,
82
+ inputs_embeds=None,
83
+ labels=None,
84
+ use_cache=None,
85
+ output_attentions=None,
86
+ output_hidden_states=None,
87
+ return_dict=None,
88
+ sequence_lengths=None
89
+ ):
90
+ shape = control_ids.shape
91
+ input_shape = (shape[0]*shape[1],shape[2])
92
+ output_shape = (shape[0],shape[1],self.config.n_control_embd)
93
+ control_embd = self.ctrle(torch.reshape(control_ids, input_shape))
94
+ control_embd = torch.reshape(control_embd, output_shape)
95
+ token_embd = self.wte(input_ids)
96
+ inputs_embeds = torch.cat([token_embd,control_embd], axis=-1)
97
+ return super().forward(
98
+ past_key_values=past_key_values,
99
+ attention_mask=attention_mask,
100
+ inputs_embeds=inputs_embeds,
101
+ labels=labels
102
+ )
103
+
104
+ if __name__ == "__main__":
105
+
106
+ batch_size = 3
107
+
108
+ config = GPT2Config().from_json_file("config/gpt2_tiny.json")
109
+
110
+ #model = GPT2LMHeadModelCont(config)
111
+ model = GPT2Encoder(config)
112
+
113
+ kwargs = {
114
+ "input_ids" : torch.randint(config.vocab_size, size=(batch_size,100)),
115
+ "labels" : torch.randint(config.vocab_size, size=(batch_size,100)),
116
+ #"control_ids" : torch.randint(CONTROL_VOCAB_SIZE, size=(1,100))
117
+ "sequence_lengths" : [99,90,90]
118
+ }
119
+
120
+ out = model.forward(**kwargs)
121
+ print(out.shape)
python_scripts/data_split.py ADDED
@@ -0,0 +1,54 @@
1
+ from os import listdir
2
+ from os.path import isfile, join
3
+ from shutil import move
4
+ import random
5
+ import sys
6
+ from tqdm import tqdm
7
+
8
+ if __name__ == "__main__":
9
+
10
+ root = sys.argv[1]
11
+ new_root = sys.argv[2]
12
+
13
+ # Get all midi files
14
+ print("Getting MIDI files")
15
+ onlyfiles = [f for f in listdir(root)]
16
+ n = len(onlyfiles)
17
+ print("Num files: " + str(n))
18
+
19
+ # Generate random test/train/valid indices
20
+
21
+ idx = [i for i in range(n)]
22
+ random.shuffle(idx) # shuffles in place; random.shuffle returns None
23
+
24
+ train_len = int(0.8 * n)
25
+ test_len = int(0.1 * n)
26
+ valid_len = n - train_len - test_len
27
+
28
+ train_idx = idx[:train_len]
29
+ test_idx = idx[train_len:test_len + train_len]
30
+ valid_idx = idx[test_len + train_len:]
31
+
32
+ # Move files to respective folder
33
+
34
+ o = 0
35
+
36
+ print('Splitting Train Set')
37
+ for i in tqdm(range(train_len)):
38
+ move(join(root, onlyfiles[train_idx[i]]), join(new_root, "train", onlyfiles[train_idx[i]]))
39
+ o += 1
40
+
41
+ print('Splitting Test Set')
42
+ for i in tqdm(range(test_len)):
43
+ move(join(root, onlyfiles[test_idx[i]]), join(new_root, "test", onlyfiles[test_idx[i]]))
44
+ o += 1
45
+
46
+ print('Splitting Validation Set')
47
+ for i in tqdm(range(valid_len)):
48
+ move(join(root, onlyfiles[valid_idx[i]]), join(new_root, "valid", onlyfiles[valid_idx[i]]))
49
+ o += 1
50
+
51
+ print("Succes Test: " + str(o == n))
52
+
53
+ print(join(new_root, "valid", onlyfiles[valid_idx[100]]))
54
+ print(isfile(join(new_root, "valid", onlyfiles[valid_idx[100]])))
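Note that the script assumes train/test/valid subfolders already exist under the destination; a sketch of preparing them beforehand (the destination name here is hypothetical):

import os

new_root = "dataset_split"
for split in ("train", "test", "valid"):
    os.makedirs(os.path.join(new_root, split), exist_ok=True)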
python_scripts/losses.py ADDED
@@ -0,0 +1,37 @@
1
+ import torch
2
+
3
+ def standard_loss(self, model, inputs, return_outputs=False):
4
+ outputs = model(**inputs)
5
+ if self.args.past_index >= 0:
6
+ self._past = outputs[self.args.past_index]
7
+ loss = outputs[0].mean()
8
+ print("loss : ", loss)
9
+ return (loss,outputs) if return_outputs else loss
10
+
11
+ def hinge_cost(m, a, b):
12
+ dist = m - torch.sqrt(torch.sum((a - b)**2, axis=1))
13
+ return torch.mean(torch.clamp(dist,0,float('inf'))**2)
14
+
15
+ def sim_metric_loss(self, model, inputs, return_outputs=False):
16
+
17
+ # single pass version
18
+ batch_size = len(inputs["input_ids"])//4
19
+ outputs = model(**inputs)
20
+ x_p = outputs[:batch_size]
21
+ x_n = outputs[batch_size:2*batch_size]
22
+ y_p = outputs[2*batch_size:3*batch_size]
23
+ y_n = outputs[3*batch_size:]
24
+
25
+ model_attr = model
26
+ if isinstance(model, torch.nn.DataParallel):
27
+ model_attr = model.module
28
+
29
+ cost_p = torch.mean(torch.sum((x_p - y_p)**2, axis=1))
30
+ cost_n = model_attr.negative_importance*hinge_cost(
31
+ model_attr.negative_threshold, x_n, y_n)
32
+ cost_e = model_attr.entropy_importance*torch.mean(
33
+ torch.sum(x_p**2, axis=1) + torch.sum(y_p**2, axis=1))
34
+ loss = cost_p + cost_n + cost_e
35
+
36
+ print(loss)
37
+ return (loss,None) if return_outputs else loss
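A quick standalone check of hinge_cost: pairs farther apart than the margin m contribute nothing, while closer pairs are penalized quadratically. The inputs below are toy tensors:

import torch

def hinge_cost(m, a, b):  # copied from losses.py above
    dist = m - torch.sqrt(torch.sum((a - b) ** 2, axis=1))
    return torch.mean(torch.clamp(dist, 0, float("inf")) ** 2)

a = torch.zeros(1, 2)
print(hinge_cost(4.0, a, torch.tensor([[5.0, 0.0]])))  # tensor(0.): distance 5 > margin 4
print(hinge_cost(4.0, a, torch.tensor([[1.0, 0.0]])))  # tensor(9.): (4 - 1)^2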
python_scripts/train.py ADDED
@@ -0,0 +1,224 @@
1
+ import sys
2
+ import os
3
+ sys.path.append(os.path.dirname(os.getcwd()) + "/python_lib")
4
+ import midigpt
5
+
6
+ from transformers import *
7
+
8
+ import os
9
+ import json
10
+ import time
11
+ import torch
12
+
13
+ import datetime
14
+ import numpy as np
15
+ from tqdm import tqdm
16
+
17
+ from subprocess import check_output
18
+
19
+ from losses import sim_metric_loss, standard_loss
20
+ from custom_models import *
21
+ from train_dataset import *
22
+ from transformers import Trainer, TrainingArguments
23
+ from callbacks import MemoryUsageCallback, ProfilerCallback
24
+
25
+ if __name__ == "__main__":
26
+
27
+ import argparse
28
+ parser = argparse.ArgumentParser()
29
+ parser.add_argument("--arch", type=str, required=True)
30
+ parser.add_argument("--config", type=str, required=True)
31
+ parser.add_argument("--encoding", type=str, required=True)
32
+ parser.add_argument("--dataset", type=str, required=True)
33
+ parser.add_argument("--pad_value", type=int, default=-100)
34
+
35
+ parser.add_argument("--expressive", action="store_true")
36
+ parser.add_argument("--num_bars", type=int, default=4)
37
+ parser.add_argument("--min_tracks", type=int, default=2)
38
+ parser.add_argument("--max_tracks", type=int, default=12)
39
+ parser.add_argument("--max_seq_len", type=int, default=2048)
40
+ parser.add_argument("--no_max_length", type=int, default=0)
41
+ parser.add_argument("--resolution", type=int, default=12)
42
+ parser.add_argument("--delta_resolution", type=int, default=1920)
43
+ parser.add_argument("--abs_pos_vocab_size", type=int, default=196)
44
+ parser.add_argument("--delta_vocab_size", type=int, default=96)
45
+
46
+ parser.add_argument("--ngpu", type=int, default=4)
47
+ parser.add_argument("--accum_steps", type=int, default=1)
48
+ parser.add_argument("--batch_size", type=int, default=32)
49
+ parser.add_argument("--batches_per_epoch", type=int, default=1000)
50
+ parser.add_argument("--lr", type=float, default=1e-4)
51
+
52
+ parser.add_argument("--overwrite", type=int, default=1)
53
+ parser.add_argument("--save_steps", type=int, default=5000)
54
+ parser.add_argument("--log_steps", type=int, default=100)
55
+ parser.add_argument("--step", type=int, default=0)
56
+ parser.add_argument("--label", type=str, default="version3")
57
+ parser.add_argument("--profiler_steps", type=int, default=50)
58
+
59
+ parser.add_argument("--dry", action="store_true")
60
+ parser.add_argument("--metric", action="store_true")
61
+
62
+ parser.add_argument("--ckpt", type=str, default="")
63
+ parser.add_argument("--ckpt_num", type=int, default=5000)
64
+ parser.add_argument("--output", type=str, default="")
65
+ parser.add_argument("--log", type=str, default="")
66
+
67
+ parser.add_argument("--test_only", action="store_true")
68
+ parser.add_argument("--memory_metrics", action="store_true")
69
+
70
+ args = parser.parse_args()
71
+ args.expressive = (args.encoding == "EXPRESSIVE_ENCODER") and args.expressive
72
+
73
+ dataset_cls = CustomDataset
74
+ loss_fn = standard_loss
75
+
76
+ np.random.seed(int(time.time()))
77
+
78
+ # determine vocab size
79
+ date_str = datetime.datetime.now().strftime('%b_%d_%H_%M')
80
+ encoder_mode = midigpt.getEncoderType(args.encoding)
81
+ assert encoder_mode is not midigpt.ENCODER_TYPE.NO_ENCODER
82
+ encoder = midigpt.getEncoder(encoder_mode)
83
+ if args.expressive:
84
+ encoder.set_scheme(args.resolution, args.delta_resolution, args.delta_vocab_size, args.abs_pos_vocab_size)
85
+ vocab_size = encoder.vocab_size()
86
+
87
+ current_git_commit_hash = check_output(["git", "rev-parse", "HEAD"], text=True).strip()
88
+
89
+ load_checkpoint = False
90
+ if args.ckpt == "":
91
+ name = "_".join([args.encoding, args.arch, args.label, date_str, "num_bars", str(args.num_bars), str(args.max_tracks), "GIT_HASH", current_git_commit_hash])
92
+ else:
93
+ name = args.ckpt
94
+ load_checkpoint = True
95
+
96
+ if args.dry:
97
+ while True:
98
+ dataset = dataset_cls(split_id=0, is_training=True, **vars(args))
99
+ for batch in tqdm(dataset,smoothing=0):
100
+ np_inputs = batch["input_ids"].detach().numpy()
101
+ print( [encoder.pretty(t) for t in np_inputs[0][:100]] )
102
+ print( {k:v.shape for k,v in batch.items()} )
103
+
104
+ if os.getenv("SLURM_TMPDIR") is not None:
105
+ # we are on compute canada and should attempt to copy
106
+ # dataset to tmpdir for faster access
107
+ from shutil import copyfile
108
+ tmpdir = os.getenv("SLURM_TMPDIR")
109
+ dataset_path = os.path.join(tmpdir, os.path.basename(args.dataset))
110
+ if not os.path.exists(dataset_path):
111
+ copyfile(args.dataset, dataset_path)
112
+ copyfile(args.dataset + ".header", dataset_path + ".header")
113
+ args.dataset = dataset_path
114
+
115
+ # setup datasets
116
+ train_dataset = dataset_cls(split_id=0, is_training=True, **vars(args))
117
+ eval_dataset = dataset_cls(split_id=2, is_training=False, overload_batches_per_epoch=1, **vars(args))
118
+ Trainer.get_train_dataloader = lambda *_args,**_kwargs: train_dataset
119
+ Trainer.get_eval_dataloader = lambda *_args,**_kwargs: eval_dataset
120
+ Trainer.compute_loss = loss_fn
121
+
122
+ print("MODEL NAME : " + name)
123
+ print("VOCAB SIZE : " + str(vocab_size))
124
+ print("ARGS : " + json.dumps(vars(args),indent=4))
125
+ print("MODEL CONFIG : " + json.dumps(json.load(open(args.config,"r")),indent=4))
126
+ print("ENCODER CONFIG : " + json.dumps(encoder.config.ToJson(),indent=4))
127
+
128
+ logging_dir = os.path.join(args.log, "{}".format(name))
129
+ output_dir = os.path.join(args.output, "checkpoints/{}".format(name))
130
+
131
+ print("LOGGING PATH : " + logging_dir)
132
+ print("OUTPUT PATH : " + output_dir)
133
+
134
+ os.makedirs(logging_dir, exist_ok=True)
135
+ os.makedirs(output_dir, exist_ok=True)
136
+
137
+ # =================================================================
138
+ # model selection
139
+
140
+ if args.arch == "gpt2":
141
+ config = GPT2Config().from_json_file(args.config)
142
+ model_cls = GPT2LMHeadModel
143
+ elif args.arch == "xl":
144
+ config = TransfoXLConfig().from_json_file(args.config)
145
+ model_cls = TransfoXLLMHeadModel
146
+ elif args.arch == "metric":
147
+ config = GPT2Config().from_json_file(args.config)
148
+ model_cls = GPT2Encoder
149
+ elif args.arch == "control":
150
+ config = GPT2LMHeadModelContConfig().from_json_file(args.config)
151
+ # encoder knows the size of the embedding
152
+ config.n_control_dim = encoder.config.embed_dim
153
+ model_cls = GPT2LMHeadModelCont
154
+ elif args.arch == "bert":
155
+ config = BertConfig().from_json_file(args.config)
156
+ model_cls = BertForMaskedLM
157
+ else:
158
+ raise NotImplementedError
159
+
160
+ config.vocab_size = vocab_size
161
+ print("MODEL CONFIG : " + str(config))
162
+
163
+
164
+ if len(args.ckpt.strip()) == 0:
165
+ print('Model initialization')
166
+ ckpt_path = None
167
+ model = model_cls(config)
168
+ else:
169
+ try:
170
+ print('Trying to load checkpoint')
171
+ ckpt_path = os.path.join(output_dir, f"checkpoint-{args.ckpt_num}")
172
+ model = model_cls.from_pretrained(ckpt_path)
173
+ except Exception as e:
174
+ print(e)
175
+ print('Returning to default model initialization')
176
+ model = model_cls(config)
177
+
178
+
179
+ # Create Memory metrics callback
180
+
181
+ # =================================================================
182
+ # training
183
+
184
+ training_args = TrainingArguments(
185
+ logging_dir=logging_dir,
186
+ report_to="tensorboard",
187
+ output_dir=output_dir,
188
+ overwrite_output_dir=bool(args.overwrite),
189
+ num_train_epochs=(500000/args.batches_per_epoch)*args.accum_steps,
190
+ logging_steps=args.log_steps,
191
+ save_steps=args.save_steps,
192
+ save_total_limit=None,
193
+ learning_rate=args.lr,
194
+ gradient_accumulation_steps=args.accum_steps,
195
+ per_device_train_batch_size=args.batch_size//args.ngpu//args.accum_steps,
196
+ per_device_eval_batch_size=args.batch_size//args.ngpu//args.accum_steps,
197
+ evaluation_strategy="epoch",
198
+ prediction_loss_only=True,
199
+ skip_memory_metrics=True
200
+ )
201
+
202
+ # Custom memory metrics: these don't work reliably and multiply training time by ~100!
203
+ if args.memory_metrics:
204
+ callbacks = [MemoryUsageCallback, ProfilerCallback]
205
+ else:
206
+ callbacks = []
207
+
208
+ trainer = Trainer(
209
+ model=model,
210
+ args=training_args,
211
+ data_collator=None,
212
+ train_dataset=None,
213
+ eval_dataset=None,
214
+ callbacks=callbacks
215
+ )
216
+
217
+ trainer.train_dataset = train_dataset
218
+ trainer.eval_dataset = eval_dataset
219
+
220
+ if not args.test_only:
221
+ trainer.train(ckpt_path)
222
+ else:
223
+ model = trainer._wrap_model(trainer.model)
224
+ model.save_pretrained(output_dir)
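Note the batch-size arithmetic above: the per-device batch size is batch_size // ngpu // accum_steps, so the product of the three recovers the requested global batch size. A worked check with the default arguments:

batch_size, ngpu, accum_steps = 32, 4, 1
per_device = batch_size // ngpu // accum_steps
print(per_device)                       # 8
print(per_device * ngpu * accum_steps)  # 32, the requested global batch size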
python_scripts/train_dataset.py ADDED
@@ -0,0 +1,96 @@
1
+ #from transformers import Trainer, TrainingArguments
2
+
3
+ import os
4
+ import json
5
+ import time
6
+ import torch
7
+ import tqdm
8
+ #from torch.utils.data import Dataset
9
+
10
+ import datetime
11
+ import numpy as np
12
+
13
+ import sys
14
+ sys.path.append(os.path.dirname(os.getcwd()) + "/python_lib")
15
+ import midigpt
16
+
17
+ class CustomDataset:
18
+ def __init__(self, split_id=0, is_training=True, batch_size=32, dataset=None, num_bars=4, min_tracks=2, max_tracks=12, max_seq_len=2048, expressive=False, no_max_length=False, resolution=12, encoding=None, pad_value=-100, arch="gpt2", accum_steps=1, batches_per_epoch=1000, overload_batches_per_epoch=None, **kwargs):
19
+ # settings
20
+ self.is_training = is_training
21
+ self.batch_size = batch_size // accum_steps
22
+ self.split_id = split_id
23
+ self.max_seq_len = max_seq_len
24
+ self.batches_per_epoch = batches_per_epoch if overload_batches_per_epoch is None else overload_batches_per_epoch
25
+ self.dataset = list(range(self.batches_per_epoch)) # number of examples ??
26
+ self.pad_value = pad_value
27
+ self.arch = arch
28
+
29
+ # create dataloader
30
+ self.dataloader = midigpt.Jagged(dataset)
31
+ self.dataloader.set_num_bars(num_bars)
32
+ self.dataloader.set_min_tracks(min_tracks)
33
+ self.dataloader.set_max_tracks(max_tracks)
34
+ self.dataloader.set_max_seq_len(max_seq_len)
35
+ seed = np.random.randint(2**20)
36
+ self.dataloader.set_seed(seed)
37
+ self.encoder_mode = midigpt.getEncoderType(encoding)
38
+
39
+ # create train_config
40
+ self.tc = midigpt.TrainConfig()
41
+ self.tc.num_bars = num_bars
42
+ self.tc.min_tracks = min_tracks
43
+ self.tc.max_tracks = max_tracks
44
+ self.tc.use_microtiming = expressive
45
+ self.tc.no_max_length = no_max_length
46
+ self.tc.resolution = resolution
47
+
48
+ self.current = 0
49
+
50
+ def _get_batch(self):
51
+ input_ids, mask = self.dataloader.read_batch_v2(
52
+ self.batch_size, self.split_id, self.encoder_mode, self.tc)
53
+ input_ids = np.array(input_ids)
54
+ mask = np.array(mask)
55
+ labels = np.copy(input_ids)
56
+ labels += (1-mask) * self.pad_value # set masked tokens to pad_value
57
+ batch = {
58
+ "input_ids" : torch.from_numpy(input_ids),
59
+ "attention_mask" : torch.from_numpy(mask),
60
+ "labels" : torch.from_numpy(labels)
61
+ }
62
+ if self.arch == "xl":
63
+ batch.pop("attention_mask")
64
+ assert np.all(np.sum(mask,axis=1)==self.max_seq_len)
65
+ if self.arch == "bert":
66
+ batch.pop("labels")
67
+ return batch
68
+
69
+ def _get_batch_test(self):
70
+ inputs = torch.ones((32,800), dtype=torch.int64)
71
+ return {
72
+ "input_ids" : inputs,
73
+ "labels" : inputs
74
+ }
75
+
76
+ def __iter__(self):
77
+ self.current = 0
78
+ return self
79
+
80
+ def __next__(self):
81
+ self.current += 1
82
+ if self.current <= self.batches_per_epoch:
83
+ while True:
84
+ try:
85
+ return self._get_batch()
86
+ except Exception as e:
87
+ print("ERROR IN BATCHER : ", e)
88
+ raise StopIteration
89
+
90
+ def __len__(self):
91
+ return self.batches_per_epoch
92
+
93
+ def pad(seqs, pad_value):
94
+ seqlens = np.array([len(seq) for seq in seqs])
95
+ maxlen = np.max(seqlens)
96
+ return np.array([np.pad(seq, (0,maxlen-len(seq)), mode="constant", constant_values=pad_value) for seq in seqs]), seqlens
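The pad helper right-pads variable-length sequences to the longest one and also returns the original lengths; a toy example:

import numpy as np

def pad(seqs, pad_value):  # copied from train_dataset.py above
    seqlens = np.array([len(seq) for seq in seqs])
    maxlen = np.max(seqlens)
    return np.array([np.pad(seq, (0, maxlen - len(seq)), mode="constant",
                            constant_values=pad_value) for seq in seqs]), seqlens

padded, lens = pad([[1, 2, 3], [4]], pad_value=-100)
print(padded)  # [[1 2 3] [4 -100 -100]]
print(lens)    # [3 1]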
python_scripts/utils.py ADDED
@@ -0,0 +1,41 @@
1
+ import json
2
+ import jsonlines
3
+ from multiprocessing import Pool
4
+ from tqdm import tqdm
5
+
6
+ def load_json(path):
7
+ with open(path,"r") as f:
8
+ return json.load(f)
9
+
10
+ def dump_json(x, path):
11
+ with open(path,"w") as f:
12
+ json.dump(x,f,indent=4)
13
+
14
+ def apply_func(x, func):
15
+ pool = Pool(8)
16
+ for inval,outval in tqdm(pool.imap_unordered(func,x),total=len(x)):
17
+ yield inval,outval
18
+
19
+ def load_jsonl(path,max_items=None):
20
+ with jsonlines.open(path) as reader:
21
+ for ii,item in enumerate(reader):
22
+ yield item
23
+ if max_items is not None and ii >= max_items:
24
+ break
25
+
26
+ def dump_jsonl(data, path):
27
+ assert isinstance(data, list)
28
+ with jsonlines.open(path, mode="w") as wr:
29
+ for item in tqdm(data,leave=False):
30
+ wr.write(item)
31
+
32
+ class dump_jsonl_multistage:
33
+ def __init__(self, path, mode="a"):
34
+ self.wr = jsonlines.open(path, mode=mode, flush=True)
35
+ def add(self, item):
36
+ self.wr.write(item)
37
+ def extend(self, items):
38
+ for item in items:
39
+ self.add(item)
40
+ def close(self):
41
+ self.wr.close()
python_scripts_for_testing/midigpt_gen.mid ADDED
Binary file (591 Bytes).
python_scripts_for_testing/mtest.mid ADDED
Binary file (455 Bytes).
python_scripts_for_testing/pythoninferencetest.py ADDED
@@ -0,0 +1,67 @@
1
+ import sys, os
2
+ sys.path.append(os.path.dirname(os.getcwd()) + "/python_lib")
3
+ import midigpt
4
+ import json
5
+ import random
6
+
7
+ if __name__ == "__main__":
8
+
9
+ import argparse
10
+ parser = argparse.ArgumentParser()
11
+ parser.add_argument("--midi", type=str, required=True)
12
+ parser.add_argument("--ckpt", type=str, required=True)
13
+ parser.add_argument("--out", type=str, default='')
14
+ args = parser.parse_args()
15
+
16
+ ckpt = args.ckpt
17
+ midi_input = args.midi
18
+ if args.out != '':
19
+ midi_dest = args.out
20
+ else:
21
+ midi_dest = os.path.join(os.path.split(midi_input)[0], 'midigpt_gen.mid')
22
+ e = midigpt.ExpressiveEncoder()
23
+ midi_json_input = json.loads(e.midi_to_json(midi_input))
24
+ valid_status={'tracks':
25
+ [
26
+ {
27
+ 'track_id': 0,
28
+ 'temperature' : 0.5,
29
+ 'instrument': 'acoustic_grand_piano',
30
+ 'density': 10,
31
+ 'track_type': 10,
32
+ 'ignore': False,
33
+ 'selected_bars': [False, False, True, False ],
34
+ 'min_polyphony_q': 'POLYPHONY_ANY',
35
+ 'max_polyphony_q': 'POLYPHONY_ANY',
36
+ 'autoregressive': False,
37
+ 'polyphony_hard_limit': 9
38
+ }
39
+ ]
40
+ }
41
+ parami={
42
+ 'tracks_per_step': 1,
43
+ 'bars_per_step': 1,
44
+ 'model_dim': 4,
45
+ 'percentage': 100,
46
+ 'batch_size': 1,
47
+ 'temperature': 1.0,
48
+ 'max_steps': 200,
49
+ 'polyphony_hard_limit': 6,
50
+ 'shuffle': True,
51
+ 'verbose': True,
52
+ 'ckpt': ckpt,
53
+ 'sampling_seed': -1,
54
+ 'mask_top_k': 0
55
+ }
56
+
57
+ piece = json.dumps(midi_json_input)
58
+ status = json.dumps(valid_status)
59
+ param = json.dumps(parami)
60
+ callbacks = midigpt.CallbackManager()
61
+ max_attempts = 3
62
+ midi_str = midigpt.sample_multi_step(piece, status, param, max_attempts, callbacks)
63
+ midi_str=midi_str[0]
64
+ midi_json = json.loads(midi_str)
65
+
66
+ e = midigpt.ExpressiveEncoder()
67
+ e.json_to_midi(midi_str, midi_dest)
setup.py ADDED
@@ -0,0 +1,67 @@
1
+ #CODE IN PROGRESS, NOT WORKING YET
2
+
3
+ from setuptools import setup, Extension
4
+ from setuptools.command.build_ext import build_ext
5
+ import sys
6
+ import setuptools
7
+
8
+ __version__ = '0.0.1'
9
+
10
+ class get_pybind_include(object):
11
+ """Helper class to determine the pybind11 include path"""
12
+
13
+ def __init__(self, user=False):
14
+ self.user = user
15
+
16
+ def __str__(self):
17
+ import pybind11
18
+ return pybind11.get_include(self.user)
19
+
20
+ ext_modules = [
21
+ Extension(
22
+ 'midigpt',
23
+ ['src/lib.cpp','src/common/data_structures/train_config.cpp','src/dataset_creation/compression/lz4.c',
24
+ 'src/dataset_creation/dataset_manipulation/bytes_to_file.cpp'],
25
+ include_dirs=[
26
+ get_pybind_include(),
27
+ get_pybind_include(user=True)
28
+ ],
29
+ language='c++'
30
+ ),
31
+ ]
32
+
33
+ class BuildExt(build_ext):
34
+ """A custom build extension for adding compiler-specific options."""
35
+ c_opts = {
36
+ 'msvc': ['/EHsc'],
37
+ 'unix': [],
38
+ }
39
+
40
+ if sys.platform == 'darwin':
41
+ c_opts['unix'] += ['-stdlib=libc++', '-mmacosx-version-min=10.7']
42
+
43
+ def build_extensions(self):
44
+ ct = self.compiler.compiler_type
45
+ opts = self.c_opts.get(ct, [])
46
+ if ct == 'unix':
47
+ opts.append('-DVERSION_INFO="%s"' % self.distribution.get_version())
48
+ opts.append('-std=c++20')
49
+ elif ct == 'msvc':
50
+ opts.append('/DVERSION_INFO=\\"%s\\"' % self.distribution.get_version())
51
+ for ext in self.extensions:
52
+ ext.extra_compile_args = opts
53
+ build_ext.build_extensions(self)
54
+
55
+ setup(
56
+ name='midigpt',
57
+ version=__version__,
58
+ author='Jeff Ens, Rafael Arias',
59
+ author_email='[email protected]',
60
+ url='',
61
+ description='A Python wrapper for midigpt project',
62
+ long_description='',
63
+ ext_modules=ext_modules,
64
+ install_requires=['pybind11>=2.5.0'],
65
+ cmdclass={'build_ext': BuildExt},
66
+ zip_safe=False,
67
+ )
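
Assuming the standard setuptools workflow, the extension would be built with `python setup.py build_ext --inplace` (or `pip install .`); BuildExt then injects `-std=c++20` and the version define for unix compilers and `/EHsc` for MSVC. As the leading comment warns, this file is not yet functional.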
src/common/data_structures/encoder_config.h ADDED
@@ -0,0 +1,95 @@
+ #pragma once
+
+ #include <vector>
+ #include <tuple>
+ #include <map>
+ #include <set>        // std::set, used by multi_fill below
+ #include <string>     // std::string / std::stoi
+ #include <cmath>      // std::round
+ #include <stdexcept>  // std::out_of_range, std::invalid_argument
+ #include <random>
+
+ namespace data_structures {
+   class EncoderConfig {
+   public:
+     EncoderConfig() {
+       both_in_one = false;
+       unquantized = false;
+       do_multi_fill = false;
+       use_velocity_levels = false;
+       use_microtiming = false;
+       transpose = 0;
+       resolution = 12;
+       decode_resolution = resolution;
+       decode_final = false;
+       delta_resolution = 1920;
+     }
+
+     std::map<std::string, std::string> ToJson() {
+       std::map<std::string, std::string> json_config;
+       json_config["both_in_one"] = std::to_string((int)both_in_one);
+       json_config["unquantized"] = std::to_string((int)unquantized);
+       json_config["do_multi_fill"] = std::to_string((int)do_multi_fill);
+       json_config["use_velocity_levels"] = std::to_string((int)use_velocity_levels);
+       json_config["use_microtiming"] = std::to_string((int)use_microtiming);
+       json_config["transpose"] = std::to_string(transpose);
+       json_config["resolution"] = std::to_string(resolution);
+       json_config["decode_resolution"] = std::to_string(decode_resolution);
+       json_config["decode_final"] = std::to_string((int)decode_final);
+       json_config["delta_resolution"] = std::to_string(delta_resolution);
+       return json_config;
+     }
+
+     void FromJson(const std::map<std::string, std::string>& json_config) {
+       try {
+         both_in_one = (bool)std::stoi(json_config.at("both_in_one"));
+         unquantized = (bool)std::stoi(json_config.at("unquantized"));
+         do_multi_fill = (bool)std::stoi(json_config.at("do_multi_fill"));
+         use_velocity_levels = (bool)std::stoi(json_config.at("use_velocity_levels"));
+         use_microtiming = (bool)std::stoi(json_config.at("use_microtiming"));
+         transpose = std::stoi(json_config.at("transpose"));
+         resolution = std::stoi(json_config.at("resolution"));
+         decode_resolution = std::stoi(json_config.at("decode_resolution"));
+         decode_final = (bool)std::stoi(json_config.at("decode_final"));
+         delta_resolution = std::stoi(json_config.at("delta_resolution"));
+       } catch (const std::out_of_range& e) {
+         throw std::invalid_argument("Missing required key in JSON config: " + std::string(e.what()));
+       } catch (const std::invalid_argument& e) {
+         throw std::invalid_argument("Invalid value type in JSON config: " + std::string(e.what()));
+       }
+     }
+
+     // convert a MIDI-tick delta into quantization steps (always 0 when
+     // microtiming is disabled)
+     int delta_to_step(int delta, int res) {
+       if (!use_microtiming) {
+         return 0;
+       }
+       return (int)(delta * res / delta_resolution);
+     }
+
+     // inverse mapping; computed in double so integer operands do not
+     // truncate before rounding
+     int step_to_delta(float step, int res) {
+       if (!use_microtiming) {
+         return 0;
+       }
+       return (int)std::round((double)delta_resolution * step / res);
+     }
+
+     int step_to_delta(int step, int res) {
+       if (!use_microtiming) {
+         return 0;
+       }
+       return (int)std::round((double)delta_resolution * step / res);
+     }
+
+     bool both_in_one;
+     bool unquantized;
+     bool do_multi_fill;
+     bool use_velocity_levels;
+     bool use_microtiming;
+     int transpose;
+     int resolution;
+     int decode_resolution;
+     bool decode_final;
+     int delta_resolution;
+     std::set<std::tuple<int, int>> multi_fill;
+   };
+ }
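
For orientation, a minimal usage sketch (hypothetical driver code, not part of the repo; include path assumed): with the defaults above, one step at resolution 12 spans 1920 / 12 = 160 ticks, so 480 ticks map to step 3 and back.

#include <iostream>
#include "encoder_config.h"  // include path assumed

int main() {
  data_structures::EncoderConfig cfg;
  cfg.use_microtiming = true;  // both conversions return 0 otherwise

  int step  = cfg.delta_to_step(480, cfg.resolution);   // 480 * 12 / 1920 = 3
  int delta = cfg.step_to_delta(step, cfg.resolution);  // 1920 * 3 / 12 = 480
  std::cout << step << " " << delta << "\n";            // prints "3 480"

  // round-trip through the string-map serialization
  data_structures::EncoderConfig copy;
  copy.FromJson(cfg.ToJson());
  std::cout << copy.use_microtiming << "\n";            // prints "1"
}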
src/common/data_structures/token_sequence.h ADDED
@@ -0,0 +1,38 @@
+ #pragma once
+
+ #include <vector>
+ #include <memory>  // std::shared_ptr
+ #include "midi.pb.h"
+ #include "../encoder/representation.h"
+
+ namespace data_structures {
+
+   class TokenSequence {
+   public:
+     TokenSequence(const std::shared_ptr<encoder::REPRESENTATION> &rep) {
+       bar_num = 0;
+       track_num = 0;
+     }
+
+     void push_back(int token) {
+       tokens.push_back(token);
+     }
+
+     // parameter renamed so it does not shadow the tokens member
+     void insert(const std::vector<int> &new_tokens) {
+       for (auto token : new_tokens) {
+         push_back(token);
+       }
+     }
+
+     void on_track_start(midi::Piece *x, const std::shared_ptr<encoder::REPRESENTATION> &rep) {
+       track_num++;
+       bar_num = 0;
+     }
+
+     void on_bar_start(midi::Piece *x, const std::shared_ptr<encoder::REPRESENTATION> &rep) {
+       bar_num++;
+     }
+
+     int bar_num;
+     int track_num;
+     std::vector<int> tokens;
+   };
+ }
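
An illustrative sketch only: the constructor ignores its representation argument and neither callback dereferences the Piece pointer, so passing nullptr works for demonstration purposes.

data_structures::TokenSequence seq(nullptr);
seq.on_track_start(nullptr, nullptr);   // track_num -> 1, bar_num reset to 0
seq.on_bar_start(nullptr, nullptr);     // bar_num -> 1
std::vector<int> bar_tokens = {5, 17, 42};
seq.insert(bar_tokens);                 // tokens = {5, 17, 42}
seq.push_back(99);                      // tokens = {5, 17, 42, 99}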
src/common/data_structures/track_type.h ADDED
@@ -0,0 +1,20 @@
+ #pragma once
+
+ #include <map>
+ #include "midi.pb.h"
+
+ // START OF NAMESPACE
+ namespace data_structures {
+
+   // inline so this header can be included from multiple translation
+   // units without violating the one-definition rule
+   inline const std::map<midi::TRACK_TYPE, bool> TRACK_TYPE_IS_DRUM = {
+     {midi::AUX_DRUM_TRACK, true},
+     {midi::AUX_INST_TRACK, false},
+     {midi::STANDARD_TRACK, false},
+     {midi::STANDARD_DRUM_TRACK, true},
+   };
+
+   // at() keeps the map const and throws on an unknown track type
+   // instead of silently inserting a default entry
+   inline bool is_drum_track(int tt) {
+     return TRACK_TYPE_IS_DRUM.at(static_cast<midi::TRACK_TYPE>(tt));
+   }
+
+ }
+ // END OF NAMESPACE
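
The lookup makes the drum/instrument distinction queryable from a plain integer, e.g. is_drum_track(midi::STANDARD_DRUM_TRACK) returns true while is_drum_track(midi::STANDARD_TRACK) returns false; with at(), an out-of-range value now raises std::out_of_range rather than silently growing the table.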
src/common/data_structures/train_config.cpp ADDED
@@ -0,0 +1,46 @@
+ #include "train_config.h"
+
+ namespace data_structures {
+
+   TrainConfig::TrainConfig() {
+     num_bars = 4;
+     min_tracks = 1;
+     max_tracks = 12;
+     max_mask_percentage = 0.75;
+     use_microtiming = false;
+     microtiming = 0.9;
+     no_max_length = false;
+     resolution = 12;
+     delta_resolution = 1920;
+     decode_resolution = delta_resolution;
+   }
+
+   std::map<std::string, std::string> TrainConfig::ToJson() {
+     std::map<std::string, std::string> json_config;
+     json_config["num_bars"] = std::to_string(num_bars);
+     json_config["min_tracks"] = std::to_string(min_tracks);
+     json_config["max_tracks"] = std::to_string(max_tracks);
+     json_config["max_mask_percentage"] = std::to_string(max_mask_percentage);
+     json_config["use_microtiming"] = std::to_string((int)use_microtiming);
+     json_config["microtiming"] = std::to_string(microtiming);
+     json_config["no_max_length"] = std::to_string((int)no_max_length);
+     json_config["resolution"] = std::to_string(resolution);
+     json_config["decode_resolution"] = std::to_string(decode_resolution);
+     json_config["delta_resolution"] = std::to_string(delta_resolution);
+     return json_config;
+   }
+
+   void TrainConfig::FromJson(std::map<std::string, std::string>& json_config) {
+     num_bars = stoi(json_config["num_bars"]);
+     min_tracks = stoi(json_config["min_tracks"]);
+     max_tracks = stoi(json_config["max_tracks"]);
+     max_mask_percentage = stof(json_config["max_mask_percentage"]);
+     // microtiming is a float (0.9 by default), so parse with stof;
+     // stoi would truncate "0.900000" to 0
+     microtiming = stof(json_config["microtiming"]);
+     use_microtiming = (bool)stoi(json_config["use_microtiming"]);
+     no_max_length = (bool)stoi(json_config["no_max_length"]);
+     resolution = stoi(json_config["resolution"]);
+     decode_resolution = stoi(json_config["decode_resolution"]);
+     delta_resolution = stoi(json_config["delta_resolution"]);
+   }
+ }
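
The stof fix above matters because std::to_string on a floating-point value produces a decimal string; a tiny standalone check (hypothetical, not part of the repo) shows how stoi silently truncates it:

#include <iostream>
#include <string>

int main() {
  std::string serialized = std::to_string(0.9);  // "0.900000"
  std::cout << std::stoi(serialized) << "\n";    // prints 0 (truncated)
  std::cout << std::stof(serialized) << "\n";    // prints 0.9
}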