Paul Triana committed
Commit 6229e10
Parent(s): 61c7027
initial commit

This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitignore +87 -0
- .gitmodules +6 -0
- CMakeLists.txt +87 -0
- CMakeSettings.json +28 -0
- README.md +176 -0
- create_dataset_compute_canada.sh +77 -0
- create_python_library.sh +232 -0
- data_split.sh +16 -0
- include/dataset_creation/compression/lz4.h +764 -0
- include/dataset_creation/dataset_manipulation/bytes_to_file.h +27 -0
- libraries/midifile +1 -0
- libraries/protobuf/CMakeLists.txt +28 -0
- libraries/protobuf/CMakeSettings.json +29 -0
- libraries/protobuf/include/proto_library.h +5 -0
- libraries/protobuf/src/enum.proto +389 -0
- libraries/protobuf/src/feature_extraction.proto +92 -0
- libraries/protobuf/src/midi.proto +429 -0
- libraries/protobuf/src/midi_internal.proto +125 -0
- libraries/protobuf/src/track_type.proto +12 -0
- libraries/pybind11 +1 -0
- libraries/torch/CMakeLists.txt +26 -0
- libraries/torch/CMakeSettings.json +29 -0
- libraries/torch/include/torch_library.h +5 -0
- libraries/torch/src/torch_library.cpp +10 -0
- midigpt_setup_helper.sh +182 -0
- models/model.zip +3 -0
- pip_requirements/common_requirements.txt +1 -0
- pip_requirements/create_dataset_requirements.txt +4 -0
- pip_requirements/inference_requirements.txt +1 -0
- pip_requirements/train_requirements.txt +135 -0
- python_scripts/config/bert.json +7 -0
- python_scripts/config/bert_tiny.json +7 -0
- python_scripts/config/gpt2.json +7 -0
- python_scripts/config/gpt2_tiny.json +7 -0
- python_scripts/convert.py +204 -0
- python_scripts/create_dataset.py +226 -0
- python_scripts/custom_models.py +121 -0
- python_scripts/data_split.py +54 -0
- python_scripts/losses.py +37 -0
- python_scripts/train.py +224 -0
- python_scripts/train_dataset.py +96 -0
- python_scripts/utils.py +41 -0
- python_scripts_for_testing/midigpt_gen.mid +0 -0
- python_scripts_for_testing/mtest.mid +0 -0
- python_scripts_for_testing/pythoninferencetest.py +67 -0
- setup.py +67 -0
- src/common/data_structures/encoder_config.h +95 -0
- src/common/data_structures/token_sequence.h +38 -0
- src/common/data_structures/track_type.h +20 -0
- src/common/data_structures/train_config.cpp +46 -0
.gitignore
ADDED
@@ -0,0 +1,87 @@
# .gitignore

build.sh

# feature_extraction stuff
python_experiment_scripts/**/*.json
python_experiment_scripts/**/*.csv
python_experiment_scripts/**/*.pdf
python_experiment_scripts/**/*.numbers
python_experiment_scripts/**/*.txt
python_experiment_scripts/examples/**/*.mid

python_scripts_for_testing/examples/*.mid
python_scripts_for_testing/examples/

# Ignore compilation folders / dependencies
#midifile/
libtorch/
libtorch_cpu/
#pybind11/
build/
python_lib/
CMakeFiles/
CMakeScripts/
CMakeCache.txt
cmake_install.cmake
Debug/

# Ignore the libtorch directory
/libraries/libtorch
/libraries/libtorch_cpu
#/libraries/pybind11
#/libraries/midifile

/python_lib/
/python_scripts/training_files

# ignore dataset files
*.arr
*.arr.header

# ignore dstore
.DS_Store

# Ignore the model.pt file
**/*.pt
*.pt

#ignore visual studio metadata
.vs/
.vscode

#ignore build directory
/out/

*.out

#ignore random files
2539bytes.txt
CMakeLists_CPP.txt

/libraries/protobuf/.vs/
/libraries/protobuf/out/

/libraries/torch/.vs/
/libraries/torch/out/

# exceptions
!python_experiment_scripts/onset_density_test.pdf
!python_experiment_scripts/onset_density_test.json
!python_experiment_scripts/onset_polyphony_test.json
!python_experiment_scripts/musicmap_genre_data.json

!python_experiment_scripts/onset_polyphony_test_mono.json
!python_experiment_scripts/onset_polyphony_test_poly.json
!python_experiment_scripts/onset_polyphony_v_original_test_mono.json
!python_experiment_scripts/onset_polyphony_v_original_test_poly.json

# pycache folders
*/__pycache__/*

#training outputs
*/checkpoints/*
*/logs/*
*/midigpt.cpython-38-x86_64-linux-gnu.so

notes/*
.gitmodules
ADDED
@@ -0,0 +1,6 @@
[submodule "libraries/pybind11"]
	path = libraries/pybind11
	url = https://github.com/pybind/pybind11.git
[submodule "libraries/midifile"]
	path = libraries/midifile
	url = https://github.com/craigsapp/midifile.git
CMakeLists.txt
ADDED
@@ -0,0 +1,87 @@
cmake_minimum_required(VERSION 3.8)

SET(CMAKE_CXX_STANDARD 20)
SET(CMAKE_CXX_STANDARD_REQUIRED ON)
SET(CMAKE_POSITION_INDEPENDENT_CODE ON)
#we add the following line to fix a linkage issue between torch and midifile
#https://stackoverflow.com/questions/68922557/c-linker-error-undefined-reference-when-linking-package-libtorch-and-shared
#add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)

project(midigpt)

option(compute_canada "Build for Compute Canada" OFF)
option(mac_os "Build for Mac OS" OFF)
option(no_torch "No Torch" OFF)
option(no_pybind "No Pybind" OFF)
option(trace "Trace" OFF)

#Find the necessary packages to be able to link the libraries correctly
find_package(Protobuf REQUIRED)
include_directories(${Protobuf_INCLUDE_DIRS})

if(no_torch)
    add_definitions(-DNO_TORCH)
endif()

if(NOT no_torch)
    if(mac_os)
        message("USING PYTHON PYTORCH INSTEAD")
    else()
        set(CMAKE_PREFIX_PATH "${CMAKE_CURRENT_SOURCE_DIR}/libraries/libtorch/")
    endif()
    find_package(Torch REQUIRED)

    # This is necessary to avoid a symbol linkage error https://github.com/pytorch/pytorch/issues/38122
    # https://github.com/DeepVAC/libdeepvac/blob/master/python/CMakeLists.txt
    find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib")
endif()

if(compute_canada)
    include_directories("/cvmfs/soft.computecanada.ca/easybuild/software/2020/avx512/Core/python/3.8.2/include/python3.8")
endif()

#Add the directories of libraries so the project can CMake them too
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/libraries/protobuf)
if(NOT no_torch)
    add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/libraries/torch)
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/libraries/pybind11)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/libraries/midifile)

#https://stackoverflow.com/questions/8934295/add-source-in-a-subdirectory-to-a-cmake-project/54285898#54285898
#https://crascit.com/2016/01/31/enhanced-source-file-handling-with-target_sources/


set(SRCS
    src/common/data_structures/train_config.cpp
    src/dataset_creation/compression/lz4.c
    src/dataset_creation/dataset_manipulation/bytes_to_file.cpp
    src/common/encoder/encoder_all.h
    src/lib.cpp
)
PYBIND11_ADD_MODULE(midigpt ${SRCS})

#Adding include folders of libraries to our target so we can reference them with #include
#Add subdirectory adds those to main project so they can be CMAKEd. Include dirs allows us to reference functions in main.
target_include_directories(midigpt PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/libraries/protobuf/include)
if (NOT no_torch)
    target_include_directories(midigpt PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/libraries/torch/include)
endif()
target_include_directories(midigpt PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/libraries/midifile/include)

#Linking all the libraries
target_link_libraries(midigpt PRIVATE midigpt_proto) #Our protobuf custom library
target_link_libraries(midigpt PRIVATE midifile)
if (NOT no_torch)
    target_link_libraries(midigpt PRIVATE midigpt_torch) #Our torch custom library
    #This is necessary to avoid a symbol linkage error https://github.com/pytorch/pytorch/issues/38122
    target_link_libraries(midigpt PRIVATE "${TORCH_LIBRARIES}" ${TORCH_PYTHON_LIBRARY})
endif()

if (trace)
    add_library(tracer STATIC src/trace.cpp)
    target_link_libraries(midigpt PRIVATE tracer)
    target_compile_options(midigpt PRIVATE -Wall -Wextra -Wpedantic -finstrument-functions)
elseif(NOT WIN32)
    target_compile_options(midigpt PRIVATE -Wall -Wextra -Wpedantic)
endif()
CMakeSettings.json
ADDED
@@ -0,0 +1,28 @@
{
  "configurations": [
    {
      "name": "x64-Debug",
      "generator": "Ninja",
      "configurationType": "Debug",
      "inheritEnvironments": [ "msvc_x64_x64" ],
      "buildRoot": "${projectDir}\\out\\build\\${name}",
      "installRoot": "${projectDir}\\out\\install\\${name}",
      "cmakeCommandArgs": "",
      "buildCommandArgs": "",
      "ctestCommandArgs": ""
    },
    {
      "name": "WSL-GCC-Debug",
      "generator": "Ninja",
      "configurationType": "Debug",
      "buildRoot": "${projectDir}\\out\\build\\${name}",
      "installRoot": "${projectDir}\\out\\install\\${name}",
      "cmakeExecutable": "cmake",
      "cmakeCommandArgs": "",
      "buildCommandArgs": "",
      "ctestCommandArgs": "",
      "inheritEnvironments": [ "linux_x64" ],
      "wslPath": "${defaultWSLPath}"
    }
  ]
}
README.md
CHANGED
@@ -9,3 +9,179 @@ short_description: MIDI-GPT-inference-docker
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

[](https://metacreation.net/category/projects/)

# MIDI-GPT Guide

This is the repository for MIDI-GPT, a generative system based on the Transformer architecture that is designed for computer-assisted music composition workflows. This work was presented at the 39th Annual AAAI Conference in Philadelphia, USA, in this [paper](https://arxiv.org/abs/2501.17011).

# Using MIDI-GPT

## Installation

To install the midigpt Python library, use the script ```midigpt_setup_helper.sh```. You can download this script on its own and run it; it will clone the repository and build the library. Below is an example of its usage:

```sh
bash midigpt_setup_helper.sh -d midigpt_dir
```

>**Note:** Python 3.8 is required for the library.
>**Note:** If you're building on Mac, use the ```-m``` argument.

## Inference

Once installed, MIDI-GPT is ready to use. ```python_scripts_for_testing/pythoninferencetest.py``` is an example of using MIDI-GPT. In summary, three objects need to be created before sampling:
- Piece: loads the MIDI file into a JSON representation of the MIDI piece
- Status: a dict indicating the desired sampling process (which tracks, continuation/resampling/infilling, etc.) as well as attribute control values
- Param: a dict indicating sampling parameters such as the temperature or the number of generated bars per step

You must provide an input MIDI file, the checkpoint model file, and an optional output MIDI file. Our model is provided in the ```models/model.zip``` file.

Then, using the ```midigpt``` Python API, call the sample function with these objects as arguments (a minimal sketch follows below). After sampling, the result can be converted and saved into a MIDI file.
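Below is a minimal, illustrative sketch of that flow. It is not taken from the repository: the helper names (```midi_to_json```, ```sample_multi_step```, ```json_to_midi```) and the exact fields of the status and param dicts are assumptions here, so check them against ```python_scripts_for_testing/pythoninferencetest.py``` before use.

```python
# Illustrative sketch only: helper names and dict fields below are assumptions,
# not the confirmed midigpt API -- see pythoninferencetest.py for the real calls.
import json
import midigpt

# Piece: a JSON representation of the input MIDI file
piece = midigpt.midi_to_json("python_scripts_for_testing/mtest.mid")  # assumed helper name

# Status: which tracks to sample and how (continuation / resampling / infilling),
# plus any attribute control values
status = {"tracks": [{"track_id": 0, "selected_bars": [True, True, True, True]}]}  # assumed fields

# Param: sampling parameters such as temperature and bars generated per step
param = {"temperature": 1.0, "bars_per_step": 2, "ckpt": "model.pt"}  # model.pt from models/model.zip

# Sample, then convert the result back into a MIDI file
result = midigpt.sample_multi_step(piece, json.dumps(status), json.dumps(param), 1)  # assumed signature
midigpt.json_to_midi(result[0], "midigpt_gen.mid")  # assumed helper name
```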

# Training MIDI-GPT

Training the model was done on Compute Canada computing clusters, so the training scripts are tailored to that platform, but they can easily be adapted to similar platforms. Training was done using the GigaMIDI dataset, first serialized into a compressed file using ```create_dataset_compute_canada.sh``` and ```python_scripts/create_dataset.py```. Training was executed using ```python_scripts/train.py```. Finally, the model weights file is converted from the training checkpoint using ```convert.py```.

If you're unfamiliar with Compute Canada, make sure to check the introductory .md [here]().

## Installation - Cedar and Niagara
0. You might want to allocate an interactive session with salloc:

>**Note:** You DON'T need to do this on Niagara.

```sh
salloc --time=3:0:0 --nodes 1 --cpus-per-task 32 --mem=128000 --account=user
```

1. First, make sure to clone the MIDI-GPT repository into a folder on your CC machine:
```sh
https://github.com/Metacreation-Lab/MIDI-GPT
```
2. Then we must load the standard environments and some dependencies:

>**Note:** If you're building on Niagara, load this first:
```sh
module load CCEnv arch/avx512
```
Then proceed to load the rest (if you're on Cedar, start from here):
```sh
module load StdEnv/2020
module load cmake/3.23.1
module load gcc/11.3.0
module load protobuf/3.12.3
module load python/3.8.2
```
3. Then we must create an environment and activate it:
```sh
virtualenv --no-download ./ENV # ENV is the name of the environment
source ./ENV/bin/activate
pip install --no-index --upgrade pip

# For training only
pip install torch==1.13.0+computecanada
pip install transformers==4.26.1+computecanada
```
4. Finally, just call the bash script with the correct arguments:
```sh
bash create_python_library.sh --test_build --compute_canada
```
Or, if you are only planning to train the model, add the argument that excludes the torch library, which is required only for inference:
```sh
bash create_python_library.sh --no_torch --compute_canada
```
5. To test the library imports for training, run the train.py script by importing it:
```sh
cd python_scripts
python3 -c "import train"
```
> **Note:** A helper script ```midigpt_setup_helper.sh``` does all these steps automatically (for training or inference). Download it individually and run it where you wish to clone the repository.
> **Note:** If you run the script without the --test_build flag, it will still compile and create the Python library, but it won't test it with the current model in production.
> **Note:** The other flag (--compute_canada) is necessary to build the code properly.

That's it!

Everything should get installed correctly in your Python environment! If you log out of and back in to CC, make sure to activate the environment in which you installed the API.

## Training

### Dataset Building

In order to train a new model, you must first build a dataset. You can upload the files you need using Globus (check the CC [guide]()).

> **Note**: Remember that to copy from the shared folder to your own folders you must use absolute paths.

The data should be organized so that all MIDI files are contained within three folders: ```train```, ```test```, and ```valid```. Further directories can be used to organize the MIDI files as long as they are within these three directories.

If your dataset is a single folder containing all the MIDI files, we provide a helper script that automatically splits the dataset 80%-10%-10% (a sketch of such a split follows below). Simply modify ```data_split.sh``` to match your case and run it.
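For illustration, here is a minimal sketch of the kind of 80%-10%-10% split such a helper performs. It is an assumption-level sketch, not the repository's ```python_scripts/data_split.py```, which remains the authoritative implementation:

```python
# Minimal sketch of an 80/10/10 MIDI dataset split (illustrative only;
# the repository's python_scripts/data_split.py is the authoritative version).
import random
import shutil
import sys
from pathlib import Path

def split_dataset(root: str, new_root: str) -> None:
    midi_files = sorted(Path(root).rglob("*.mid"))
    random.shuffle(midi_files)
    n = len(midi_files)
    bounds = {"train": (0, int(0.8 * n)),
              "test": (int(0.8 * n), int(0.9 * n)),
              "valid": (int(0.9 * n), n)}
    for split, (lo, hi) in bounds.items():
        dest = Path(new_root) / split
        dest.mkdir(parents=True, exist_ok=True)
        for f in midi_files[lo:hi]:
            shutil.copy(f, dest / f.name)  # copy, so the original dataset stays intact

if __name__ == "__main__":
    split_dataset(sys.argv[1], sys.argv[2])
```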

Once you have the folder with the data, run the following command:
```sh
sh create_dataset_compute_canada.sh --root_dir=<root_dir> --encoding=<encoding> --data_dir=<data_dir> --output=<output>
```
where:
- ```<root_dir>``` is the root folder where the midigpt repository folder is located
- ```<encoding>``` is the encoder to use. We suggest using ```EXPRESSIVE_ENCODER```
- ```<data_dir>``` is the dataset folder containing the three ```train```, ```test```, and ```valid``` folders.
- ```<output>``` is the location of the output ```.arr``` file. The resulting file will be ```<output>_NUM_BARS=<num_bars>_RESOLUTION_<resolution>.arr```
>**Note:** If you are on Compute Canada, we suggest you run these commands through an sbatch job as they can take some time.

### Training a Model

To train a model, run the train.py file. Different lab members have managed to set the paths differently; what works for me is to use global paths. An example would be:
```sh
python train.py --arch gpt2 --config /home/user/scratch/TRAINING-master/config/gpt2_tiny.json --encoding EXPRESSIVE_ENCODER --ngpu 4 --dataset /home/user/scratch/test_NUM_BARS=4_OPZ_False.arr --batch_size 32 --label DELETE_ME
```

### Running Jobs

To read the CC documentation, click [here](https://docs.alliancecan.ca/wiki/Running_jobs). You can run small snippets of code to test things out without allocating any resources. However, to train a model or perform any time- or resource-consuming task, you must schedule a job. A list of different types of job scheduling will be added here.

#### Interactive Jobs
You can start an interactive session on a compute node with salloc.
```sh
salloc --time=3:0:0 --nodes 1 --cpus-per-task 32 --mem=128000 --account=user
```

#### Scheduled jobs (use this for training)
For time-expensive tasks it is better to create a bash file and submit a job with sbatch:
```sh
sbatch simple_job.sh
```

Here is an example of the contents of a bash file to submit a midigpt training job:
```sh
#!/bin/bash
#SBATCH --gres=gpu:v100l:4
#SBATCH --cpus-per-task=32
#SBATCH --exclusive
#SBATCH --mem=0
#SBATCH --time=2-23:00
#SBATCH --account=user
#SBATCH --mail-user [email protected]        # <---- MAKE SURE TO PUT YOUR EMAIL
#SBATCH --mail-type ALL
#SBATCH --output=CCLOG/FILENAME.out        # <---- MAKE SURE TO CHANGE THE NAME OF THE FILE

source $SCRATCH/PY_3610/bin/activate       # <---- THIS IS THE DIRECTORY TO THE ENV WHERE YOU HAVE THE midigpt_api INSTALLED
cd $SCRATCH/MMM_TRAINING-master
module load StdEnv/2020 protobuf python/3.6.10
source $SCRATCH/PY_3610/bin/activate       # <---- SAME HERE, MAKE SURE THE DIRECTORY IS PLACED CORRECTLY
python train.py --arch reformer --config /home/user/scratch/MMM_TRAINING-master/config/reformer.json --encoding EXPRESSIVE_ENCODER --ngpu 4 --dataset /home/user/scratch/dataset_NUM_BARS=4.arr --batch_size 32 --label DELETE_ME
```

In this case we are using 4 v100l GPUs (**gres** argument) and we're asking for 2 days and 23 hours of time to run the job (**time** argument).

#### Check jobs and eliminate sessions
To show all the users:
```sh
who -u
```

To kill all of a user's sessions:
```sh
pkill -u username
```
create_dataset_compute_canada.sh
ADDED
@@ -0,0 +1,77 @@
#!/bin/bash
#SBATCH --cpus-per-task=32
#SBATCH --time=10:00:00

root_dir=""   # Model directory to replace MODEL_NAME
data_dir=""   # Data directory to replace DATA_DIR
encoding=""   # Encoding to replace ENCODING
output=""     # Output to replace OUTPUT
metadata=""
test="no"
zip="no"
res="12"
max="-1"

# Parse arguments
for arg in "$@"
do
    case $arg in
        --root_dir=*)
        root_dir="${arg#*=}"
        shift # Remove --root_dir= from processing
        ;;
        --metadata=*)
        metadata="${arg#*=}"
        shift # Remove --metadata= from processing
        ;;
        --res=*)
        res="${arg#*=}"
        shift # Remove --res= from processing
        ;;
        --test=*)
        test="${arg#*=}"
        shift # Remove --test= from processing
        ;;
        --data_dir=*)
        data_dir="${arg#*=}"
        shift # Remove --data_dir= from processing
        ;;
        --encoding=*)
        encoding="${arg#*=}"
        shift # Remove --encoding= from processing
        ;;
        --output=*)
        output="${arg#*=}"
        shift # Remove --output= from processing
        ;;
        --zip=*)
        zip="${arg#*=}"
        shift # Remove --zip= from processing
        ;;
        --max=*)
        max="${arg#*=}"
        shift # Remove --max= from processing
        ;;
    esac
done

module load CCEnv arch/avx512
module load StdEnv/2020
module load cmake/3.23.1
module load gcc/11.3.0
module load protobuf/3.12.3
module load python/3.8.2

mkdir -p $root_dir/CCLOG
source $root_dir/venv/bin/activate

cp $root_dir/MIDI-GPT/python_lib/midigpt.cpython-38-x86_64-linux-gnu.so $root_dir/MIDI-GPT/python_scripts

python3 $root_dir/MIDI-GPT/python_scripts/create_dataset.py --nthreads 40 --max_size $max --data_dir $data_dir --encoding $encoding --output $output --metadata $metadata --test $test --expressive --resolution $res

if [[ "$zip" == "yes" ]]
then
    cd $output
    cd ../
    zip -r EXPRESSIVE_GIGAMIDI_24_1920.zip $output
fi
create_python_library.sh
ADDED
@@ -0,0 +1,232 @@
#!/bin/bash

compute_canada=false
mac_os=false
cpu=false
no_torch=false
niagara=false
env_name="venv"
parent_dir="midigpt_workspace" # default parent directory name
cuda=false
trace=false


# Parse arguments
for ((i=1;i<=$#;i++)); do
    case ${!i} in
        --trace)
        trace=true
        ;;
        --compute_canada)
        compute_canada=true
        ;;
        --cpu)
        cpu=true
        ;;
        --mac_os)
        mac_os=true
        ;;
        --no_torch)
        no_torch=true
        ;;
        --niagara)
        niagara=true
        ;;
        --env_name)
        i=$((i+1))
        env_name=${!i}
        ;;
        -n=*|--name=*) # new parent directory name option
        parent_dir="${!i#*=}"
        shift
        ;;
        --cuda)
        if $no_torch; then
            echo "Cannot use --cuda and --no_torch at the same time."
            exit 1
        fi
        cuda=true
        ;;
        *)
        echo "Unknown option ${!i}"
        exit 1
        ;;
    esac
done


# Get the current directory name
dir_name=$(basename `pwd`)

# Get the parent directory name
parent_dir_name=$(basename $(dirname `pwd`))

if [ "$parent_dir_name" != "$parent_dir" ]; then

    # Go to the parent directory
    cd ..

    # Create the new parent directory
    mkdir $parent_dir

    # Move the old directory into the new parent directory
    mv $dir_name $parent_dir/

    # Change to the old directory, which is now inside the parent directory
    cd $parent_dir/$dir_name
fi

# Load modules if we are on Compute Canada and Niagara
if $compute_canada; then
    if $niagara; then
        module load CCEnv arch/avx512
    fi
    module load StdEnv/2020
    module load cmake/3.23.1
    module load gcc/11.3.0
    module load protobuf/3.12.3
    module load python/3.8.2

    mkdir ../CCLOG
fi

# Environment creation
if [[ -d ../$env_name ]]; then
    echo "Environment $env_name already exists, activating it..."
else
    echo "Environment $env_name does not exist, creating it..."
    if $compute_canada; then
        virtualenv ../$env_name
    else
        python3 -m venv ../$env_name
    fi
fi

source ../$env_name/bin/activate

# Install requirements
pip install -r pip_requirements/common_requirements.txt

if $compute_canada; then
    pip install -r pip_requirements/create_dataset_requirements.txt
fi
if $mac_os; then
    pip install -r pip_requirements/inference_requirements.txt
fi

if $compute_canada && ! $niagara; then # and if no torch
    pip install -r pip_requirements/train_requirements.txt
fi

#deactivate

# Set CMake flags based on command line arguments
cmake_flags=""
if $compute_canada; then
    cmake_flags="$cmake_flags -Dcompute_canada=ON"
fi

if $no_torch; then
    cmake_flags="$cmake_flags -Dno_torch=ON"
fi

if $trace; then
    cmake_flags="$cmake_flags -Dtrace=ON"
fi

# Code to check if libtorch and pybind11 are already downloaded
if ! $no_torch; then
    libtorch_path="libraries/libtorch"
    libtorch_url="https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.0.0%2Bcpu.zip"
    if $cuda; then
        libtorch_url="https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.0.0%2Bcu118.zip"
    fi
fi

pybind11_path="libraries/pybind11"
midifile_path="libraries/midifile"


pybind11_url="https://github.com/pybind/pybind11.git"
midifile_url="https://github.com/craigsapp/midifile"

if ! $no_torch; then
    if $mac_os; then
        libtorch_url="https://download.pytorch.org/libtorch/cpu/libtorch-macos-2.0.1.zip"
    fi

    if $cpu; then
        libtorch_url="https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.0.0%2Bcpu.zip"
    fi

    # Check if libtorch folder exists and is not empty
    if [ ! -d "$libtorch_path" ] || [ -z "$(ls -A "$libtorch_path")" ]; then
        echo "libtorch folder does not exist or is empty. Downloading and extracting..."
        mkdir -p "$libtorch_path"
        curl -L "$libtorch_url" -o libtorch.zip
        unzip -q libtorch.zip -d libraries/
        rm libtorch.zip
        echo "libtorch downloaded and extracted."
    else
        echo "libtorch folder exists and is not empty. No need to download."
    fi
fi

# Check if pybind11 folder exists and is not empty
if [ ! -d "$pybind11_path" ] || [ -z "$(ls -A "$pybind11_path")" ]; then
    echo "pybind11 folder does not exist or is empty. Cloning the repository..."
    mkdir -p libraries
    git clone "$pybind11_url" "$pybind11_path"
    echo "pybind11 downloaded."
    cd libraries/pybind11
    git reset --hard 5ccb9e4
    cd ../../
    echo "pybind11 reset to working build"
else
    echo "pybind11 folder exists and is not empty. No need to download."
fi

# Check if midifile folder exists and is not empty
if [ ! -d "$midifile_path" ] || [ -z "$(ls -A "$midifile_path")" ]; then
    echo "midifile folder does not exist or is empty. Cloning the repository..."
    mkdir -p libraries
    git clone "$midifile_url" "$midifile_path"
    echo "midifile downloaded."
    cd libraries/midifile
    git reset --hard 838c62c
    cd ../../
    echo "midifile reset to working build"
else
    echo "midifile folder exists and is not empty. No need to download."
fi

# Middle section of the script to build the python library
rm -rf ./python_lib
mkdir ./python_lib
rm -rf ./libraries/protobuf/build
mkdir ./libraries/protobuf/build

cd ./libraries/protobuf/src
protoc --cpp_out ../build *.proto
cd ../../..

cd ./python_lib

if $mac_os; then
    cmake $cmake_flags .. -Dmac_os=ON -DCMAKE_PREFIX_PATH=$(python3 -c 'import torch;print(torch.utils.cmake_prefix_path)')
else
    cmake $cmake_flags ..
fi
make
python3 -c "import midigpt; print('midigpt python library built successfully')"

cd ..
if $compute_canada; then
    dos2unix create_dataset_compute_canada.sh
    dos2unix train_dataset.sh
fi
cd ./python_lib

cd ..
data_split.sh
ADDED
@@ -0,0 +1,16 @@
#!/bin/bash

root=$SCRATCH/datasets/GigaMIDI_Cleaned/Cleaned_Ver_EP_Class-GigaMIDI/Cleaned_GigaMIDI
new_root=$SCRATCH/datasets/GigaMIDI_Cleaned/Cleaned_Ver_EP_Class-GigaMIDI/Cleaned_GigaMIDI_Split
parent=$SCRATCH/workspace_train/parent_dir

mkdir -p $new_root
cd $new_root
mkdir -p train
mkdir -p test
mkdir -p valid

cd $SCRATCH

source $parent/venv/bin/activate
python $parent/MIDI-GPT/python_scripts/data_split.py $root $new_root
include/dataset_creation/compression/lz4.h
ADDED
@@ -0,0 +1,764 @@
1 |
+
/*
|
2 |
+
* LZ4 - Fast LZ compression algorithm
|
3 |
+
* Header File
|
4 |
+
* Copyright (C) 2011-present, Yann Collet.
|
5 |
+
|
6 |
+
BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php)
|
7 |
+
|
8 |
+
Redistribution and use in source and binary forms, with or without
|
9 |
+
modification, are permitted provided that the following conditions are
|
10 |
+
met:
|
11 |
+
|
12 |
+
* Redistributions of source code must retain the above copyright
|
13 |
+
notice, this list of conditions and the following disclaimer.
|
14 |
+
* Redistributions in binary form must reproduce the above
|
15 |
+
copyright notice, this list of conditions and the following disclaimer
|
16 |
+
in the documentation and/or other materials provided with the
|
17 |
+
distribution.
|
18 |
+
|
19 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
20 |
+
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
21 |
+
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
22 |
+
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
23 |
+
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
24 |
+
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
25 |
+
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
26 |
+
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
27 |
+
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
28 |
+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
29 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
30 |
+
|
31 |
+
You can contact the author at :
|
32 |
+
- LZ4 homepage : http://www.lz4.org
|
33 |
+
- LZ4 source repository : https://github.com/lz4/lz4
|
34 |
+
*/
|
35 |
+
#if defined (__cplusplus)
|
36 |
+
extern "C" {
|
37 |
+
#endif
|
38 |
+
|
39 |
+
#ifndef LZ4_H_2983827168210
|
40 |
+
#define LZ4_H_2983827168210
|
41 |
+
|
42 |
+
/* --- Dependency --- */
|
43 |
+
#include <stddef.h> /* size_t */
|
44 |
+
|
45 |
+
|
46 |
+
/**
|
47 |
+
Introduction
|
48 |
+
|
49 |
+
LZ4 is lossless compression algorithm, providing compression speed >500 MB/s per core,
|
50 |
+
scalable with multi-cores CPU. It features an extremely fast decoder, with speed in
|
51 |
+
multiple GB/s per core, typically reaching RAM speed limits on multi-core systems.
|
52 |
+
|
53 |
+
The LZ4 compression library provides in-memory compression and decompression functions.
|
54 |
+
It gives full buffer control to user.
|
55 |
+
Compression can be done in:
|
56 |
+
- a single step (described as Simple Functions)
|
57 |
+
- a single step, reusing a context (described in Advanced Functions)
|
58 |
+
- unbounded multiple steps (described as Streaming compression)
|
59 |
+
|
60 |
+
lz4.h generates and decodes LZ4-compressed blocks (doc/lz4_Block_format.md).
|
61 |
+
Decompressing such a compressed block requires additional metadata.
|
62 |
+
Exact metadata depends on exact decompression function.
|
63 |
+
For the typical case of LZ4_decompress_safe(),
|
64 |
+
metadata includes block's compressed size, and maximum bound of decompressed size.
|
65 |
+
Each application is free to encode and pass such metadata in whichever way it wants.
|
66 |
+
|
67 |
+
lz4.h only handle blocks, it can not generate Frames.
|
68 |
+
|
69 |
+
Blocks are different from Frames (doc/lz4_Frame_format.md).
|
70 |
+
Frames bundle both blocks and metadata in a specified manner.
|
71 |
+
Embedding metadata is required for compressed data to be self-contained and portable.
|
72 |
+
Frame format is delivered through a companion API, declared in lz4frame.h.
|
73 |
+
The `lz4` CLI can only manage frames.
|
74 |
+
*/
|
75 |
+
|
76 |
+
/*^***************************************************************
|
77 |
+
* Export parameters
|
78 |
+
*****************************************************************/
|
79 |
+
/*
|
80 |
+
* LZ4_DLL_EXPORT :
|
81 |
+
* Enable exporting of functions when building a Windows DLL
|
82 |
+
* LZ4LIB_VISIBILITY :
|
83 |
+
* Control library symbols visibility.
|
84 |
+
*/
|
85 |
+
#ifndef LZ4LIB_VISIBILITY
|
86 |
+
# if defined(__GNUC__) && (__GNUC__ >= 4)
|
87 |
+
# define LZ4LIB_VISIBILITY __attribute__ ((visibility ("default")))
|
88 |
+
# else
|
89 |
+
# define LZ4LIB_VISIBILITY
|
90 |
+
# endif
|
91 |
+
#endif
|
92 |
+
#if defined(LZ4_DLL_EXPORT) && (LZ4_DLL_EXPORT==1)
|
93 |
+
# define LZ4LIB_API __declspec(dllexport) LZ4LIB_VISIBILITY
|
94 |
+
#elif defined(LZ4_DLL_IMPORT) && (LZ4_DLL_IMPORT==1)
|
95 |
+
# define LZ4LIB_API __declspec(dllimport) LZ4LIB_VISIBILITY /* It isn't required but allows to generate better code, saving a function pointer load from the IAT and an indirect jump.*/
|
96 |
+
#else
|
97 |
+
# define LZ4LIB_API LZ4LIB_VISIBILITY
|
98 |
+
#endif
|
99 |
+
|
100 |
+
/*------ Version ------*/
|
101 |
+
#define LZ4_VERSION_MAJOR 1 /* for breaking interface changes */
|
102 |
+
#define LZ4_VERSION_MINOR 9 /* for new (non-breaking) interface capabilities */
|
103 |
+
#define LZ4_VERSION_RELEASE 2 /* for tweaks, bug-fixes, or development */
|
104 |
+
|
105 |
+
#define LZ4_VERSION_NUMBER (LZ4_VERSION_MAJOR *100*100 + LZ4_VERSION_MINOR *100 + LZ4_VERSION_RELEASE)
|
106 |
+
|
107 |
+
#define LZ4_LIB_VERSION LZ4_VERSION_MAJOR.LZ4_VERSION_MINOR.LZ4_VERSION_RELEASE
|
108 |
+
#define LZ4_QUOTE(str) #str
|
109 |
+
#define LZ4_EXPAND_AND_QUOTE(str) LZ4_QUOTE(str)
|
110 |
+
#define LZ4_VERSION_STRING LZ4_EXPAND_AND_QUOTE(LZ4_LIB_VERSION)
|
111 |
+
|
112 |
+
LZ4LIB_API int LZ4_versionNumber (void); /**< library version number; useful to check dll version */
|
113 |
+
LZ4LIB_API const char* LZ4_versionString (void); /**< library version std::string; useful to check dll version */
|
114 |
+
|
115 |
+
|
116 |
+
/*-************************************
|
117 |
+
* Tuning parameter
|
118 |
+
**************************************/
|
119 |
+
/*!
|
120 |
+
* LZ4_MEMORY_USAGE :
|
121 |
+
* Memory usage formula : N->2^N Bytes (examples : 10 -> 1KB; 12 -> 4KB ; 16 -> 64KB; 20 -> 1MB; etc.)
|
122 |
+
* Increasing memory usage improves compression ratio.
|
123 |
+
* Reduced memory usage may improve speed, thanks to better cache locality.
|
124 |
+
* Default value is 14, for 16KB, which nicely fits into Intel x86 L1 cache
|
125 |
+
*/
|
126 |
+
#ifndef LZ4_MEMORY_USAGE
|
127 |
+
# define LZ4_MEMORY_USAGE 14
|
128 |
+
#endif
|
129 |
+
|
130 |
+
|
131 |
+
/*-************************************
|
132 |
+
* Simple Functions
|
133 |
+
**************************************/
|
134 |
+
/*! LZ4_compress_default() :
|
135 |
+
* Compresses 'srcSize' bytes from buffer 'src'
|
136 |
+
* into already allocated 'dst' buffer of size 'dstCapacity'.
|
137 |
+
* Compression is guaranteed to succeed if 'dstCapacity' >= LZ4_compressBound(srcSize).
|
138 |
+
* It also runs faster, so it's a recommended setting.
|
139 |
+
* If the function cannot compress 'src' into a more limited 'dst' budget,
|
140 |
+
* compression stops *immediately*, and the function result is zero.
|
141 |
+
* In which case, 'dst' content is undefined (invalid).
|
142 |
+
* srcSize : max supported value is LZ4_MAX_INPUT_SIZE.
|
143 |
+
* dstCapacity : size of buffer 'dst' (which must be already allocated)
|
144 |
+
* @return : the number of bytes written into buffer 'dst' (necessarily <= dstCapacity)
|
145 |
+
* or 0 if compression fails
|
146 |
+
* Note : This function is protected against buffer overflow scenarios (never writes outside 'dst' buffer, nor read outside 'source' buffer).
|
147 |
+
*/
|
148 |
+
LZ4LIB_API int LZ4_compress_default(const char* src, char* dst, int srcSize, int dstCapacity);
|
149 |
+
|
150 |
+
/*! LZ4_decompress_safe() :
|
151 |
+
* compressedSize : is the exact complete size of the compressed block.
|
152 |
+
* dstCapacity : is the size of destination buffer (which must be already allocated), presumed an upper bound of decompressed size.
|
153 |
+
* @return : the number of bytes decompressed into destination buffer (necessarily <= dstCapacity)
|
154 |
+
* If destination buffer is not large enough, decoding will stop and output an error code (negative value).
|
155 |
+
* If the source stream is detected malformed, the function will stop decoding and return a negative result.
|
156 |
+
* Note 1 : This function is protected against malicious data packets :
|
157 |
+
* it will never writes outside 'dst' buffer, nor read outside 'source' buffer,
|
158 |
+
* even if the compressed block is maliciously modified to order the decoder to do these actions.
|
159 |
+
* In such case, the decoder stops immediately, and considers the compressed block malformed.
|
160 |
+
* Note 2 : compressedSize and dstCapacity must be provided to the function, the compressed block does not contain them.
|
161 |
+
* The implementation is free to send / store / derive this information in whichever way is most beneficial.
|
162 |
+
* If there is a need for a different format which bundles together both compressed data and its metadata, consider looking at lz4frame.h instead.
|
163 |
+
*/
|
164 |
+
LZ4LIB_API int LZ4_decompress_safe (const char* src, char* dst, int compressedSize, int dstCapacity);
|
165 |
+
|
166 |
+
|
167 |
+
/*-************************************
|
168 |
+
* Advanced Functions
|
169 |
+
**************************************/
|
170 |
+
#define LZ4_MAX_INPUT_SIZE 0x7E000000 /* 2 113 929 216 bytes */
|
171 |
+
#define LZ4_COMPRESSBOUND(isize) ((unsigned)(isize) > (unsigned)LZ4_MAX_INPUT_SIZE ? 0 : (isize) + ((isize)/255) + 16)
|
172 |
+
|
173 |
+
/*! LZ4_compressBound() :
|
174 |
+
Provides the maximum size that LZ4 compression may output in a "worst case" scenario (input data not compressible)
|
175 |
+
This function is primarily useful for memory allocation purposes (destination buffer size).
|
176 |
+
Macro LZ4_COMPRESSBOUND() is also provided for compilation-time evaluation (stack memory allocation for example).
|
177 |
+
Note that LZ4_compress_default() compresses faster when dstCapacity is >= LZ4_compressBound(srcSize)
|
178 |
+
inputSize : max supported value is LZ4_MAX_INPUT_SIZE
|
179 |
+
return : maximum output size in a "worst case" scenario
|
180 |
+
or 0, if input size is incorrect (too large or negative)
|
181 |
+
*/
|
182 |
+
LZ4LIB_API int LZ4_compressBound(int inputSize);
|
183 |
+
|
184 |
+
/*! LZ4_compress_fast() :
|
185 |
+
Same as LZ4_compress_default(), but allows selection of "acceleration" factor.
|
186 |
+
The larger the acceleration value, the faster the algorithm, but also the lesser the compression.
|
187 |
+
It's a trade-off. It can be fine tuned, with each successive value providing roughly +~3% to speed.
|
188 |
+
An acceleration value of "1" is the same as regular LZ4_compress_default()
|
189 |
+
Values <= 0 will be replaced by ACCELERATION_DEFAULT (currently == 1, see lz4.c).
|
190 |
+
*/
|
191 |
+
LZ4LIB_API int LZ4_compress_fast (const char* src, char* dst, int srcSize, int dstCapacity, int acceleration);
|
192 |
+
|
193 |
+
|
194 |
+
/*! LZ4_compress_fast_extState() :
|
195 |
+
* Same as LZ4_compress_fast(), using an externally allocated memory space for its state.
|
196 |
+
* Use LZ4_sizeofState() to know how much memory must be allocated,
|
197 |
+
* and allocate it on 8-bytes boundaries (using `malloc()` typically).
|
198 |
+
* Then, provide this buffer as `void* state` to compression function.
|
199 |
+
*/
|
200 |
+
LZ4LIB_API int LZ4_sizeofState(void);
|
201 |
+
LZ4LIB_API int LZ4_compress_fast_extState (void* state, const char* src, char* dst, int srcSize, int dstCapacity, int acceleration);
|
202 |
+
|
203 |
+
|
204 |
+
/*! LZ4_compress_destSize() :
|
205 |
+
* Reverse the logic : compresses as much data as possible from 'src' buffer
|
206 |
+
* into already allocated buffer 'dst', of size >= 'targetDestSize'.
|
207 |
+
* This function either compresses the entire 'src' content into 'dst' if it's large enough,
|
208 |
+
* or fill 'dst' buffer completely with as much data as possible from 'src'.
|
209 |
+
* note: acceleration parameter is fixed to "default".
|
210 |
+
*
|
211 |
+
* *srcSizePtr : will be modified to indicate how many bytes where read from 'src' to fill 'dst'.
|
212 |
+
* New value is necessarily <= input value.
|
213 |
+
* @return : Nb bytes written into 'dst' (necessarily <= targetDestSize)
|
214 |
+
* or 0 if compression fails.
|
215 |
+
*/
|
216 |
+
LZ4LIB_API int LZ4_compress_destSize (const char* src, char* dst, int* srcSizePtr, int targetDstSize);
|
217 |
+
|
218 |
+
|
219 |
+
/*! LZ4_decompress_safe_partial() :
|
220 |
+
* Decompress an LZ4 compressed block, of size 'srcSize' at position 'src',
|
221 |
+
* into destination buffer 'dst' of size 'dstCapacity'.
|
222 |
+
* Up to 'targetOutputSize' bytes will be decoded.
|
223 |
+
* The function stops decoding on reaching this objective,
|
224 |
+
* which can boost performance when only the beginning of a block is required.
|
225 |
+
*
|
226 |
+
* @return : the number of bytes decoded in `dst` (necessarily <= dstCapacity)
|
227 |
+
* If source stream is detected malformed, function returns a negative result.
|
228 |
+
*
|
229 |
+
* Note : @return can be < targetOutputSize, if compressed block contains less data.
|
230 |
+
*
|
231 |
+
* Note 2 : this function features 2 parameters, targetOutputSize and dstCapacity,
|
232 |
+
* and expects targetOutputSize <= dstCapacity.
|
233 |
+
* It effectively stops decoding on reaching targetOutputSize,
|
234 |
+
* so dstCapacity is kind of redundant.
|
235 |
+
* This is because in a previous version of this function,
|
236 |
+
* decoding operation would not "break" a sequence in the middle.
|
237 |
+
* As a consequence, there was no guarantee that decoding would stop at exactly targetOutputSize,
|
238 |
+
* it could write more bytes, though only up to dstCapacity.
|
239 |
+
* Some "margin" used to be required for this operation to work properly.
|
240 |
+
* This is no longer necessary.
|
241 |
+
* The function nonetheless keeps its signature, in an effort to not break API.
|
242 |
+
*/
|
243 |
+
LZ4LIB_API int LZ4_decompress_safe_partial (const char* src, char* dst, int srcSize, int targetOutputSize, int dstCapacity);
|
244 |
+
|
245 |
+
|
246 |
+
/*-*********************************************
|
247 |
+
* Streaming Compression Functions
|
248 |
+
***********************************************/
|
249 |
+
typedef union LZ4_stream_u LZ4_stream_t; /* incomplete type (defined later) */
|
250 |
+
|
251 |
+
LZ4LIB_API LZ4_stream_t* LZ4_createStream(void);
|
252 |
+
LZ4LIB_API int LZ4_freeStream (LZ4_stream_t* streamPtr);
|
253 |
+
|
254 |
+
/*! LZ4_resetStream_fast() : v1.9.0+
|
255 |
+
* Use this to prepare an LZ4_stream_t for a new chain of dependent blocks
|
256 |
+
* (e.g., LZ4_compress_fast_continue()).
|
257 |
+
*
|
258 |
+
* An LZ4_stream_t must be initialized once before usage.
|
259 |
+
* This is automatically done when created by LZ4_createStream().
|
260 |
+
* However, should the LZ4_stream_t be simply declared on stack (for example),
|
261 |
+
* it's necessary to initialize it first, using LZ4_initStream().
|
262 |
+
*
|
263 |
+
* After init, start any new stream with LZ4_resetStream_fast().
|
264 |
+
* A same LZ4_stream_t can be re-used multiple times consecutively
|
265 |
+
* and compress multiple streams,
|
266 |
+
* provided that it starts each new stream with LZ4_resetStream_fast().
|
267 |
+
*
|
268 |
+
* LZ4_resetStream_fast() is much faster than LZ4_initStream(),
|
269 |
+
* but is not compatible with memory regions containing garbage data.
|
270 |
+
*
|
271 |
+
* Note: it's only useful to call LZ4_resetStream_fast()
|
272 |
+
* in the context of streaming compression.
|
273 |
+
* The *extState* functions perform their own resets.
|
274 |
+
* Invoking LZ4_resetStream_fast() before is redundant, and even counterproductive.
|
275 |
+
*/
|
276 |
+
LZ4LIB_API void LZ4_resetStream_fast (LZ4_stream_t* streamPtr);
|
277 |
+
|
278 |
+
/*! LZ4_loadDict() :
|
279 |
+
* Use this function to reference a static dictionary into LZ4_stream_t.
|
280 |
+
* The dictionary must remain available during compression.
|
281 |
+
* LZ4_loadDict() triggers a reset, so any previous data will be forgotten.
|
282 |
+
* The same dictionary will have to be loaded on decompression side for successful decoding.
|
283 |
+
* Dictionary are useful for better compression of small data (KB range).
|
284 |
+
* While LZ4 accept any input as dictionary,
|
285 |
+
* results are generally better when using Zstandard's Dictionary Builder.
|
286 |
+
* Loading a size of 0 is allowed, and is the same as reset.
|
287 |
+
* @return : loaded dictionary size, in bytes (necessarily <= 64 KB)
|
288 |
+
*/
|
289 |
+
LZ4LIB_API int LZ4_loadDict (LZ4_stream_t* streamPtr, const char* dictionary, int dictSize);
|
290 |
+
|
291 |
+
/*! LZ4_compress_fast_continue() :
|
292 |
+
* Compress 'src' content using data from previously compressed blocks, for better compression ratio.
|
293 |
+
* 'dst' buffer must be already allocated.
|
294 |
+
* If dstCapacity >= LZ4_compressBound(srcSize), compression is guaranteed to succeed, and runs faster.
|
295 |
+
*
|
296 |
+
* @return : size of compressed block
|
297 |
+
* or 0 if there is an error (typically, cannot fit into 'dst').
|
298 |
+
*
|
299 |
+
* Note 1 : Each invocation to LZ4_compress_fast_continue() generates a new block.
|
300 |
+
* Each block has precise boundaries.
|
301 |
+
* Each block must be decompressed separately, calling LZ4_decompress_*() with relevant metadata.
|
302 |
+
* It's not possible to append blocks together and expect a single invocation of LZ4_decompress_*() to decompress them together.
|
303 |
+
*
|
304 |
+
* Note 2 : The previous 64KB of source data is __assumed__ to remain present, unmodified, at same address in memory !
|
305 |
+
*
|
306 |
+
* Note 3 : When input is structured as a double-buffer, each buffer can have any size, including < 64 KB.
|
307 |
+
* Make sure that buffers are separated, by at least one byte.
|
308 |
+
* This construction ensures that each block only depends on previous block.
|
309 |
+
*
|
310 |
+
* Note 4 : If input buffer is a ring-buffer, it can have any size, including < 64 KB.
|
311 |
+
*
|
312 |
+
* Note 5 : After an error, the stream status is undefined (invalid), it can only be reset or freed.
|
313 |
+
*/
|
314 |
+
LZ4LIB_API int LZ4_compress_fast_continue (LZ4_stream_t* streamPtr, const char* src, char* dst, int srcSize, int dstCapacity, int acceleration);
|
315 |
+
|
316 |
+
/*! LZ4_saveDict() :
|
317 |
+
* If last 64KB data cannot be guaranteed to remain available at its current memory location,
|
318 |
+
* save it into a safer place (char* safeBuffer).
|
319 |
+
* This is schematically equivalent to a memcpy() followed by LZ4_loadDict(),
|
320 |
+
* but is much faster, because LZ4_saveDict() doesn't need to rebuild tables.
|
321 |
+
* @return : saved dictionary size in bytes (necessarily <= maxDictSize), or 0 if error.
|
322 |
+
*/
|
323 |
+
LZ4LIB_API int LZ4_saveDict (LZ4_stream_t* streamPtr, char* safeBuffer, int maxDictSize);
|
324 |
+
|
325 |
+
|
326 |
+
/*-**********************************************
 *  Streaming Decompression Functions
 *  Bufferless synchronous API
 ************************************************/
typedef union LZ4_streamDecode_u LZ4_streamDecode_t;   /* tracking context */

/*! LZ4_createStreamDecode() and LZ4_freeStreamDecode() :
 *  creation / destruction of a streaming decompression tracking context.
 *  A tracking context can be re-used multiple times.
 */
LZ4LIB_API LZ4_streamDecode_t* LZ4_createStreamDecode(void);
LZ4LIB_API int                 LZ4_freeStreamDecode (LZ4_streamDecode_t* LZ4_stream);

/*! LZ4_setStreamDecode() :
 *  An LZ4_streamDecode_t context can be allocated once and re-used multiple times.
 *  Use this function to start decompression of a new stream of blocks.
 *  A dictionary can optionally be set. Use NULL or size 0 for a reset order.
 *  Dictionary is presumed stable : it must remain accessible and unmodified during the next decompression.
 * @return : 1 if OK, 0 if error
 */
LZ4LIB_API int LZ4_setStreamDecode (LZ4_streamDecode_t* LZ4_streamDecode, const char* dictionary, int dictSize);

/*! LZ4_decoderRingBufferSize() : v1.8.2+
 *  Note : in a ring buffer scenario (optional),
 *  blocks are presumed decompressed next to each other
 *  up to the moment there is not enough remaining space for the next block (remainingSize < maxBlockSize),
 *  at which stage decoding resumes from the beginning of the ring buffer.
 *  When setting such a ring buffer for streaming decompression,
 *  this function provides the minimum size of the ring buffer
 *  to be compatible with any source respecting the maxBlockSize condition.
 * @return : minimum ring buffer size,
 *           or 0 if there is an error (invalid maxBlockSize).
 */
LZ4LIB_API int LZ4_decoderRingBufferSize(int maxBlockSize);
#define LZ4_DECODER_RING_BUFFER_SIZE(maxBlockSize) (65536 + 14 + (maxBlockSize))  /* for static allocation; maxBlockSize presumed valid */

/*! LZ4_decompress_*_continue() :
 *  These decoding functions allow decompression of consecutive blocks in "streaming" mode.
 *  A block is an unsplittable entity; it must be presented entirely to a decompression function.
 *  Decompression functions only accept one block at a time.
 *  The last 64KB of previously decoded data *must* remain available and unmodified at the memory position where they were decoded.
 *  If less than 64KB of data has been decoded, all the data must be present.
 *
 *  Special : if the decompression side sets a ring buffer, it must respect one of the following conditions :
 *  - Decompression buffer size is _at least_ LZ4_decoderRingBufferSize(maxBlockSize).
 *    maxBlockSize is the maximum size of any single block. It can have any value > 16 bytes.
 *    In which case, encoding and decoding buffers do not need to be synchronized.
 *    Actually, data can be produced by any source compliant with the LZ4 format specification, and respecting maxBlockSize.
 *  - Synchronized mode :
 *    Decompression buffer size is _exactly_ the same as compression buffer size,
 *    and follows exactly the same update rule (block boundaries at the same positions),
 *    and the decoding function is provided with the exact decompressed size of each block (exception for the last block of the stream),
 *    _then_ decoding & encoding ring buffers can have any size, including small ones ( < 64 KB).
 *  - Decompression buffer is larger than the encoding buffer, by a minimum of maxBlockSize more bytes.
 *    In which case, encoding and decoding buffers do not need to be synchronized,
 *    and the encoding ring buffer can have any size, including small ones ( < 64 KB).
 *
 *  Whenever these conditions are not possible,
 *  save the last 64KB of decoded data into a safe buffer where it can't be modified during decompression,
 *  then indicate where this data is saved using LZ4_setStreamDecode(), before decompressing the next block.
 */
LZ4LIB_API int LZ4_decompress_safe_continue (LZ4_streamDecode_t* LZ4_streamDecode, const char* src, char* dst, int srcSize, int dstCapacity);


/*! LZ4_decompress_*_usingDict() :
 *  These decoding functions work the same as
 *  a combination of LZ4_setStreamDecode() followed by LZ4_decompress_*_continue().
 *  They are stand-alone, and don't need an LZ4_streamDecode_t structure.
 *  Dictionary is presumed stable : it must remain accessible and unmodified during decompression.
 *  Performance tip : Decompression speed can be substantially increased
 *  when dst == dictStart + dictSize.
 */
LZ4LIB_API int LZ4_decompress_safe_usingDict (const char* src, char* dst, int srcSize, int dstCapacity, const char* dictStart, int dictSize);

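/* A minimal sketch of the matching streaming decode loop, assuming the
 * application's framing supplies each block's exact compressed size.
 * read_block() and use_data() are hypothetical callbacks. */
extern int  read_block(char* dst, int capacity);   /* hypothetical: returns compressed size, 0 at end */
extern void use_data(const char* data, int size);  /* hypothetical */

inline void streaming_decompress_sketch()
{
    LZ4_streamDecode_t* const stream = LZ4_createStreamDecode();
    LZ4_setStreamDecode(stream, NULL, 0);          /* fresh stream, no dictionary */
    static char decBuf[2][4096];                   /* mirrors the encoder's double buffer */
    char cmpBuf[LZ4_COMPRESSBOUND(4096)];
    int idx = 0;
    for (;;) {
        const int cSize = read_block(cmpBuf, (int)sizeof(cmpBuf));
        if (cSize <= 0) break;
        const int dSize = LZ4_decompress_safe_continue(stream,
                              cmpBuf, decBuf[idx], cSize, 4096);
        if (dSize < 0) break;                      /* malformed or malicious input */
        use_data(decBuf[idx], dSize);
        idx ^= 1;                                  /* previous 64KB must remain unmodified */
    }
    LZ4_freeStreamDecode(stream);
}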
#endif /* LZ4_H_2983827168210 */


/*^*************************************
 * !!!!!!   STATIC LINKING ONLY   !!!!!!
 ***************************************/

/*-****************************************************************************
 *  Experimental section
 *
 *  Symbols declared in this section must be considered unstable. Their
 *  signatures or semantics may change, or they may be removed altogether in the
 *  future. They are therefore only safe to depend on when the caller is
 *  statically linked against the library.
 *
 *  To protect against unsafe usage, not only are the declarations guarded,
 *  the definitions are hidden by default
 *  when building LZ4 as a shared/dynamic library.
 *
 *  In order to access these declarations,
 *  define LZ4_STATIC_LINKING_ONLY in your application
 *  before including LZ4's headers.
 *
 *  In order to make their implementations accessible dynamically, you must
 *  define LZ4_PUBLISH_STATIC_FUNCTIONS when building the LZ4 library.
 ******************************************************************************/

#ifdef LZ4_STATIC_LINKING_ONLY

#ifndef LZ4_STATIC_3504398509
#define LZ4_STATIC_3504398509

#ifdef LZ4_PUBLISH_STATIC_FUNCTIONS
#define LZ4LIB_STATIC_API LZ4LIB_API
#else
#define LZ4LIB_STATIC_API
#endif


/*! LZ4_compress_fast_extState_fastReset() :
 *  A variant of LZ4_compress_fast_extState().
 *
 *  Using this variant avoids an expensive initialization step.
 *  It is only safe to call if the state buffer is known to be correctly initialized already
 *  (see above comment on LZ4_resetStream_fast() for a definition of "correctly initialized").
 *  From a high level, the difference is that
 *  this function initializes the provided state with a call to something like LZ4_resetStream_fast(),
 *  while LZ4_compress_fast_extState() starts with a call to LZ4_resetStream().
 */
LZ4LIB_STATIC_API int LZ4_compress_fast_extState_fastReset (void* state, const char* src, char* dst, int srcSize, int dstCapacity, int acceleration);

/*! LZ4_attach_dictionary() :
 *  This is an experimental API that allows
 *  efficient use of a static dictionary many times.
 *
 *  Rather than re-loading the dictionary buffer into a working context before
 *  each compression, or copying a pre-loaded dictionary's LZ4_stream_t into a
 *  working LZ4_stream_t, this function introduces a no-copy setup mechanism,
 *  in which the working stream references the dictionary stream in-place.
 *
 *  Several assumptions are made about the state of the dictionary stream.
 *  Currently, only streams which have been prepared by LZ4_loadDict() should
 *  be expected to work.
 *
 *  Alternatively, the provided dictionaryStream may be NULL,
 *  in which case any existing dictionary stream is unset.
 *
 *  If a dictionary is provided, it replaces any pre-existing stream history.
 *  The dictionary contents are the only history that can be referenced and
 *  logically immediately precede the data compressed in the first subsequent
 *  compression call.
 *
 *  The dictionary will only remain attached to the working stream through the
 *  first compression call, at the end of which it is cleared. The dictionary
 *  stream (and source buffer) must remain in-place / accessible / unchanged
 *  through the completion of the first compression call on the stream.
 */
LZ4LIB_STATIC_API void LZ4_attach_dictionary(LZ4_stream_t* workingStream, const LZ4_stream_t* dictionaryStream);


/*! In-place compression and decompression
 *
 *  It's possible to have input and output sharing the same buffer,
 *  for highly constrained memory environments.
 *  In both cases, it requires the input to lie at the end of the buffer,
 *  and decompression to start at the beginning of the buffer.
 *  Buffer size must feature some margin, hence be larger than the final size.
 *
 *  |<------------------------buffer--------------------------------->|
 *                              |<-----------compressed data--------->|
 *  |<-----------decompressed size------------------>|
 *                                                   |<----margin---->|
 *
 *  This technique is more useful for decompression,
 *  since decompressed size is typically larger,
 *  and margin is short.
 *
 *  In-place decompression will work inside any buffer
 *  whose size is >= LZ4_DECOMPRESS_INPLACE_BUFFER_SIZE(decompressedSize).
 *  This presumes that decompressedSize > compressedSize.
 *  Otherwise, it means compression actually expanded the data,
 *  and it would be more efficient to store such data with a flag indicating it's not compressed.
 *  This can happen when data is not compressible (already compressed, or encrypted).
 *
 *  For in-place compression, margin is larger, as it must be able to cope with both
 *  history preservation, requiring input data to remain unmodified up to LZ4_DISTANCE_MAX,
 *  and data expansion, which can happen when input is not compressible.
 *  As a consequence, buffer size requirements are much higher,
 *  and memory savings offered by in-place compression are more limited.
 *
 *  There are ways to limit this cost for compression :
 *  - Reduce history size, by modifying LZ4_DISTANCE_MAX.
 *    Note that it is a compile-time constant, so all compressions will apply this limit.
 *    Lower values will reduce compression ratio, except when input_size < LZ4_DISTANCE_MAX,
 *    so it's a reasonable trick when inputs are known to be small.
 *  - Require the compressor to deliver a "maximum compressed size".
 *    This is the `dstCapacity` parameter in `LZ4_compress*()`.
 *    When this size is < LZ4_COMPRESSBOUND(inputSize), then compression can fail,
 *    in which case, the return code will be 0 (zero).
 *    The caller must be ready for these cases to happen,
 *    and typically design a backup scheme to send data uncompressed.
 *  The combination of both techniques can significantly reduce
 *  the amount of margin required for in-place compression.
 *
 *  In-place compression can work in any buffer
 *  whose size is >= (maxCompressedSize),
 *  with maxCompressedSize == LZ4_COMPRESSBOUND(srcSize) for guaranteed compression success.
 *  LZ4_COMPRESS_INPLACE_BUFFER_SIZE() depends on both maxCompressedSize and LZ4_DISTANCE_MAX,
 *  so it's possible to reduce memory requirements by playing with them.
 */

#define LZ4_DECOMPRESS_INPLACE_MARGIN(compressedSize)        (((compressedSize) >> 8) + 32)
#define LZ4_DECOMPRESS_INPLACE_BUFFER_SIZE(decompressedSize) ((decompressedSize) + LZ4_DECOMPRESS_INPLACE_MARGIN(decompressedSize))  /**< note: presumes that compressedSize < decompressedSize. note2: margin is overestimated a bit, since it could use compressedSize instead */

#ifndef LZ4_DISTANCE_MAX   /* history window size; can be user-defined at compile time */
#  define LZ4_DISTANCE_MAX 65535   /* set to maximum value by default */
#endif

#define LZ4_COMPRESS_INPLACE_MARGIN                         (LZ4_DISTANCE_MAX + 32)   /* LZ4_DISTANCE_MAX can be safely replaced by srcSize when it's smaller */
#define LZ4_COMPRESS_INPLACE_BUFFER_SIZE(maxCompressedSize) ((maxCompressedSize) + LZ4_COMPRESS_INPLACE_MARGIN)  /**< maxCompressedSize is generally LZ4_COMPRESSBOUND(inputSize), but can be set to any lower value, with the risk that compression can fail (return code 0 (zero)) */

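/* Worked example for the sizing macros above, assuming the default
 * LZ4_DISTANCE_MAX of 65535 and a 64 KB payload:
 *   LZ4_DECOMPRESS_INPLACE_BUFFER_SIZE(65536) = 65536 + (65536 >> 8) + 32 = 65824
 *   LZ4_COMPRESS_INPLACE_BUFFER_SIZE(LZ4_COMPRESSBOUND(65536)) = 65809 + 65535 + 32 = 131376
 * i.e. in-place decompression needs roughly 0.5% of headroom, while in-place
 * compression pays for a full history window on top of the compressed bound. */
static char inplaceBuf[LZ4_DECOMPRESS_INPLACE_BUFFER_SIZE(65536)];   /* 65824 bytes, statically sized */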
#endif   /* LZ4_STATIC_3504398509 */
#endif   /* LZ4_STATIC_LINKING_ONLY */



#ifndef LZ4_H_98237428734687
#define LZ4_H_98237428734687

/*-************************************************************
 *  PRIVATE DEFINITIONS
 **************************************************************
 *  Do not use these definitions directly.
 *  They are only exposed to allow static allocation of `LZ4_stream_t` and `LZ4_streamDecode_t`.
 *  Accessing members will expose code to API and/or ABI break in future versions of the library.
 **************************************************************/
#define LZ4_HASHLOG       (LZ4_MEMORY_USAGE-2)
#define LZ4_HASHTABLESIZE (1 << LZ4_MEMORY_USAGE)
#define LZ4_HASH_SIZE_U32 (1 << LZ4_HASHLOG)   /* required as macro for static allocation */

#if defined(__cplusplus) || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */)
#include <stdint.h>

typedef struct LZ4_stream_t_internal LZ4_stream_t_internal;
struct LZ4_stream_t_internal {
    uint32_t hashTable[LZ4_HASH_SIZE_U32];
    uint32_t currentOffset;
    uint16_t dirty;
    uint16_t tableType;
    const uint8_t* dictionary;
    const LZ4_stream_t_internal* dictCtx;
    uint32_t dictSize;
};

typedef struct {
    const uint8_t* externalDict;
    size_t extDictSize;
    const uint8_t* prefixEnd;
    size_t prefixSize;
} LZ4_streamDecode_t_internal;

#else

typedef struct LZ4_stream_t_internal LZ4_stream_t_internal;
struct LZ4_stream_t_internal {
    unsigned int hashTable[LZ4_HASH_SIZE_U32];
    unsigned int currentOffset;
    unsigned short dirty;
    unsigned short tableType;
    const unsigned char* dictionary;
    const LZ4_stream_t_internal* dictCtx;
    unsigned int dictSize;
};

typedef struct {
    const unsigned char* externalDict;
    const unsigned char* prefixEnd;
    size_t extDictSize;
    size_t prefixSize;
} LZ4_streamDecode_t_internal;

#endif

/*! LZ4_stream_t :
 *  information structure to track an LZ4 stream.
 *  LZ4_stream_t can also be created using LZ4_createStream(), which is recommended.
 *  The structure definition can be convenient for static allocation
 *  (on stack, or as part of larger structure).
 *  Init this structure with LZ4_initStream() before first use.
 *  note : only use this definition in association with static linking !
 *         this definition is not API/ABI safe, and may change in a future version.
 */
#define LZ4_STREAMSIZE_U64 ((1 << (LZ4_MEMORY_USAGE-3)) + 4 + ((sizeof(void*)==16) ? 4 : 0) /*AS-400*/ )
#define LZ4_STREAMSIZE     (LZ4_STREAMSIZE_U64 * sizeof(unsigned long long))
union LZ4_stream_u {
    unsigned long long table[LZ4_STREAMSIZE_U64];
    LZ4_stream_t_internal internal_donotuse;
};  /* previously typedef'd to LZ4_stream_t */

/*! LZ4_initStream() : v1.9.0+
 *  An LZ4_stream_t structure must be initialized at least once.
 *  This is automatically done when invoking LZ4_createStream(),
 *  but it's not when the structure is simply declared on stack (for example).
 *
 *  Use LZ4_initStream() to properly initialize a newly declared LZ4_stream_t.
 *  It can also initialize any arbitrary buffer of sufficient size,
 *  and will @return a pointer of proper type upon initialization.
 *
 *  Note : initialization fails if size and alignment conditions are not respected.
 *         In which case, the function will @return NULL.
 *  Note2: An LZ4_stream_t structure guarantees correct alignment and size.
 *  Note3: Before v1.9.0, use LZ4_resetStream() instead
 */
LZ4LIB_API LZ4_stream_t* LZ4_initStream (void* buffer, size_t size);


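/* A minimal sketch of the static-allocation usage documented above: an
 * LZ4_stream_t on the stack, initialized with LZ4_initStream() instead of the
 * heap allocation performed by LZ4_createStream(). */
inline int compress_on_stack(const char* src, int srcSize, char* dst, int dstCapacity)
{
    LZ4_stream_t state;
    LZ4_stream_t* const s = LZ4_initStream(&state, sizeof(state));
    if (s == NULL) return 0;   /* size/alignment conditions not met */
    return LZ4_compress_fast_continue(s, src, dst, srcSize, dstCapacity, 1);
}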
/*! LZ4_streamDecode_t :
 *  information structure to track an LZ4 stream during decompression.
 *  init this structure using LZ4_setStreamDecode() before first use.
 *  note : only use in association with static linking !
 *         this definition is not API/ABI safe,
 *         and may change in a future version !
 */
#define LZ4_STREAMDECODESIZE_U64 (4 + ((sizeof(void*)==16) ? 2 : 0) /*AS-400*/ )
#define LZ4_STREAMDECODESIZE     (LZ4_STREAMDECODESIZE_U64 * sizeof(unsigned long long))
union LZ4_streamDecode_u {
    unsigned long long table[LZ4_STREAMDECODESIZE_U64];
    LZ4_streamDecode_t_internal internal_donotuse;
};  /* previously typedef'd to LZ4_streamDecode_t */



/*-************************************
 *  Obsolete Functions
 **************************************/

/*! Deprecation warnings
 *
 *  Deprecated functions make the compiler generate a warning when invoked.
 *  This is meant to invite users to update their source code.
 *  Should deprecation warnings be a problem, it is generally possible to disable them,
 *  typically with -Wno-deprecated-declarations for gcc
 *  or _CRT_SECURE_NO_WARNINGS in Visual.
 *
 *  Another method is to define LZ4_DISABLE_DEPRECATE_WARNINGS
 *  before including the header file.
 */
#ifdef LZ4_DISABLE_DEPRECATE_WARNINGS
#  define LZ4_DEPRECATED(message)   /* disable deprecation warnings */
#else
#  define LZ4_GCC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__)
#  if defined (__cplusplus) && (__cplusplus >= 201402)   /* C++14 or greater */
#    define LZ4_DEPRECATED(message) [[deprecated(message)]]
#  elif (LZ4_GCC_VERSION >= 405) || defined(__clang__)
#    define LZ4_DEPRECATED(message) __attribute__((deprecated(message)))
#  elif (LZ4_GCC_VERSION >= 301)
#    define LZ4_DEPRECATED(message) __attribute__((deprecated))
#  elif defined(_MSC_VER)
#    define LZ4_DEPRECATED(message) __declspec(deprecated(message))
#  else
#    pragma message("WARNING: You need to implement LZ4_DEPRECATED for this compiler")
#    define LZ4_DEPRECATED(message)
#  endif
#endif /* LZ4_DISABLE_DEPRECATE_WARNINGS */

/* Obsolete compression functions */
LZ4_DEPRECATED("use LZ4_compress_default() instead")       LZ4LIB_API int LZ4_compress (const char* src, char* dest, int srcSize);
LZ4_DEPRECATED("use LZ4_compress_default() instead")       LZ4LIB_API int LZ4_compress_limitedOutput (const char* src, char* dest, int srcSize, int maxOutputSize);
LZ4_DEPRECATED("use LZ4_compress_fast_extState() instead") LZ4LIB_API int LZ4_compress_withState (void* state, const char* source, char* dest, int inputSize);
LZ4_DEPRECATED("use LZ4_compress_fast_extState() instead") LZ4LIB_API int LZ4_compress_limitedOutput_withState (void* state, const char* source, char* dest, int inputSize, int maxOutputSize);
LZ4_DEPRECATED("use LZ4_compress_fast_continue() instead") LZ4LIB_API int LZ4_compress_continue (LZ4_stream_t* LZ4_streamPtr, const char* source, char* dest, int inputSize);
LZ4_DEPRECATED("use LZ4_compress_fast_continue() instead") LZ4LIB_API int LZ4_compress_limitedOutput_continue (LZ4_stream_t* LZ4_streamPtr, const char* source, char* dest, int inputSize, int maxOutputSize);

/* Obsolete decompression functions */
LZ4_DEPRECATED("use LZ4_decompress_fast() instead") LZ4LIB_API int LZ4_uncompress (const char* source, char* dest, int outputSize);
LZ4_DEPRECATED("use LZ4_decompress_safe() instead") LZ4LIB_API int LZ4_uncompress_unknownOutputSize (const char* source, char* dest, int isize, int maxOutputSize);

/* Obsolete streaming functions; degraded functionality; do not use!
 *
 * In order to perform streaming compression, these functions depended on data
 * that is no longer tracked in the state. They have been preserved as well as
 * possible: using them will still produce a correct output. However, they don't
 * actually retain any history between compression calls. The compression ratio
 * achieved will therefore be no better than compressing each chunk
 * independently.
 */
LZ4_DEPRECATED("Use LZ4_createStream() instead") LZ4LIB_API void* LZ4_create (char* inputBuffer);
LZ4_DEPRECATED("Use LZ4_createStream() instead") LZ4LIB_API int   LZ4_sizeofStreamState(void);
LZ4_DEPRECATED("Use LZ4_resetStream() instead")  LZ4LIB_API int   LZ4_resetStreamState(void* state, char* inputBuffer);
LZ4_DEPRECATED("Use LZ4_saveDict() instead")     LZ4LIB_API char* LZ4_slideInputBuffer (void* state);

/* Obsolete streaming decoding functions */
LZ4_DEPRECATED("use LZ4_decompress_safe_usingDict() instead") LZ4LIB_API int LZ4_decompress_safe_withPrefix64k (const char* src, char* dst, int compressedSize, int maxDstSize);
LZ4_DEPRECATED("use LZ4_decompress_fast_usingDict() instead") LZ4LIB_API int LZ4_decompress_fast_withPrefix64k (const char* src, char* dst, int originalSize);

/*! LZ4_decompress_fast() : **unsafe!**
 *  These functions used to be faster than LZ4_decompress_safe(),
 *  but that has changed, and they are now slower than LZ4_decompress_safe().
 *  This is because LZ4_decompress_fast() doesn't know the input size,
 *  and therefore must progress more cautiously in the input buffer to not read beyond the end of the block.
 *  On top of that, `LZ4_decompress_fast()` is not protected vs malformed or malicious inputs, making it a security liability.
 *  As a consequence, LZ4_decompress_fast() is strongly discouraged, and deprecated.
 *
 *  The last remaining LZ4_decompress_fast() specificity is that
 *  it can decompress a block without knowing its compressed size.
 *  Such functionality could be achieved in a more secure manner,
 *  by also providing the maximum size of the input buffer,
 *  but it would require new prototypes, and adaptation of the implementation to this new use case.
 *
 *  Parameters:
 *  originalSize : is the uncompressed size to regenerate.
 *                 `dst` must be already allocated, and its size must be >= 'originalSize' bytes.
 * @return : number of bytes read from the source buffer (== compressed size).
 *           The function expects to finish at block's end exactly.
 *           If the source stream is detected malformed, the function stops decoding and returns a negative result.
 *  note : LZ4_decompress_fast*() requires originalSize. Thanks to this information, it never writes past the output buffer.
 *         However, since it doesn't know its 'src' size, it may read an unknown amount of input, past input buffer bounds.
 *         Also, since match offsets are not validated, match reads from 'src' may underflow too.
 *         These issues never happen if input (compressed) data is correct.
 *         But they may happen if input data is invalid (error or intentional tampering).
 *         As a consequence, use these functions in trusted environments with trusted data **only**.
 */

LZ4_DEPRECATED("This function is deprecated and unsafe. Consider using LZ4_decompress_safe() instead")
LZ4LIB_API int LZ4_decompress_fast (const char* src, char* dst, int originalSize);
LZ4_DEPRECATED("This function is deprecated and unsafe. Consider using LZ4_decompress_safe_continue() instead")
LZ4LIB_API int LZ4_decompress_fast_continue (LZ4_streamDecode_t* LZ4_streamDecode, const char* src, char* dst, int originalSize);
LZ4_DEPRECATED("This function is deprecated and unsafe. Consider using LZ4_decompress_safe_usingDict() instead")
LZ4LIB_API int LZ4_decompress_fast_usingDict (const char* src, char* dst, int originalSize, const char* dictStart, int dictSize);

/*! LZ4_resetStream() :
 *  An LZ4_stream_t structure must be initialized at least once.
 *  This is done with LZ4_initStream(), or LZ4_resetStream().
 *  Consider switching to LZ4_initStream(),
 *  as invoking LZ4_resetStream() will trigger deprecation warnings in the future.
 */
LZ4LIB_API void LZ4_resetStream (LZ4_stream_t* streamPtr);


#endif /* LZ4_H_98237428734687 */


#if defined (__cplusplus)
}
#endif
include/dataset_creation/dataset_manipulation/bytes_to_file.h
ADDED
@@ -0,0 +1,27 @@
#pragma once

#include <string>
#include <fstream>
#include "../../../libraries/protobuf/build/midi.pb.h"
#include "../compression/lz4.h"


namespace dataset_manipulation {
	// Appends serialized byte strings for a dataset split into a midi::Dataset
	// protobuf and writes it to disk, alongside a separate header file.
	class BytesToFile {
	private:
		std::string filepath_;
		std::string header_filepath_;
		std::fstream file_stream_;
		std::fstream header_file_stream_;
		midi::Dataset dataset_split_protobuf_;
		int flush_count_;
		bool can_write;

	public:
		BytesToFile(std::string external_filepath_);
		void enableWrite();
		void appendBytesToFileStream(std::string& bytes_as_string, size_t split_id);
		void writeFile();
		void close();
	};
}
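// A minimal usage sketch inferred from the declaration above (the actual
// behaviour lives in the corresponding .cpp). The output path "dataset.bytes"
// and the serialize_piece() helper are hypothetical.
#include "dataset_creation/dataset_manipulation/bytes_to_file.h"

std::string serialize_piece();   // hypothetical: one midi::Piece serialized to bytes

void write_dataset_sketch() {
	dataset_manipulation::BytesToFile writer("dataset.bytes");
	writer.enableWrite();                        // writing appears gated on can_write
	std::string bytes = serialize_piece();
	writer.appendBytesToFileStream(bytes, 0);    // split_id 0, e.g. the train split
	writer.writeFile();
	writer.close();
}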
libraries/midifile
ADDED
@@ -0,0 +1 @@
Subproject commit 838c62c4a13245ced8e13a84e6c2a1994664acd5
libraries/protobuf/CMakeLists.txt
ADDED
@@ -0,0 +1,28 @@
cmake_minimum_required(VERSION 3.8)

project(midigpt_proto)

set(PROTO_DEF
    src/enum.proto
    src/midi.proto
    src/midi_internal.proto
    src/track_type.proto
    src/feature_extraction.proto)

find_package(Protobuf REQUIRED)

protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS
    ${PROTO_DEF}
    PROTOC_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR} # it's the default, but it does not hurt to be explicit here...
)

add_library(midigpt_proto
    ${PROTO_SRCS}
    ${PROTO_HDRS})

target_include_directories(midigpt_proto
    PUBLIC
        ${Protobuf_INCLUDE_DIRS}
        ${CMAKE_CURRENT_BINARY_DIR} # for generated protobuf files
)
target_link_libraries(midigpt_proto ${Protobuf_LIBRARIES})
libraries/protobuf/CMakeSettings.json
ADDED
@@ -0,0 +1,29 @@
{
  "configurations": [
    {
      "name": "x64-Debug",
      "generator": "Ninja",
      "configurationType": "Debug",
      "inheritEnvironments": [ "msvc_x64_x64" ],
      "buildRoot": "${projectDir}\\out\\build\\${name}",
      "installRoot": "${projectDir}\\out\\install\\${name}",
      "cmakeCommandArgs": "",
      "buildCommandArgs": "",
      "ctestCommandArgs": ""
    },
    {
      "name": "WSL-GCC-Debug",
      "generator": "Ninja",
      "configurationType": "Debug",
      "buildRoot": "${projectDir}\\out\\build\\${name}",
      "installRoot": "${projectDir}\\out\\install\\${name}",
      "cmakeExecutable": "cmake",
      "cmakeCommandArgs": "",
      "buildCommandArgs": "",
      "ctestCommandArgs": "",
      "inheritEnvironments": [ "linux_x64" ],
      "wslPath": "${defaultWSLPath}",
      "variables": []
    }
  ]
}
libraries/protobuf/include/proto_library.h
ADDED
@@ -0,0 +1,5 @@
#pragma once

namespace midigptProto {
	void testMmmProto();
}
libraries/protobuf/src/enum.proto
ADDED
@@ -0,0 +1,389 @@
syntax = "proto2";

package midi;

enum TOKEN_TYPE {
  TOKEN_PIECE_START = 0;
  TOKEN_NOTE_ONSET = 1;
  TOKEN_PITCH = 3;
  TOKEN_VELOCITY = 5;
  TOKEN_TIME_ABSOLUTE_POS = 7;
  TOKEN_INSTRUMENT = 8;
  TOKEN_BAR = 9;
  TOKEN_BAR_END = 10;
  TOKEN_TRACK = 11;
  TOKEN_TRACK_END = 12;
  TOKEN_DRUM_TRACK = 13;
  TOKEN_FILL_IN = 14;
  TOKEN_FILL_IN_PLACEHOLDER = 15;
  TOKEN_FILL_IN_START = 16;
  TOKEN_FILL_IN_END = 17;
  TOKEN_VELOCITY_LEVEL = 19;
  TOKEN_GENRE = 20;
  TOKEN_DENSITY_LEVEL = 21;
  TOKEN_TIME_SIGNATURE = 22;
  TOKEN_NOTE_DURATION = 26;
  TOKEN_AV_POLYPHONY = 27;
  TOKEN_MIN_POLYPHONY = 28;
  TOKEN_MAX_POLYPHONY = 29;
  TOKEN_MIN_NOTE_DURATION = 30;
  TOKEN_MAX_NOTE_DURATION = 31;
  TOKEN_NUM_BARS = 32;
  TOKEN_MIN_POLYPHONY_HARD = 33;
  TOKEN_MAX_POLYPHONY_HARD = 34;
  TOKEN_MIN_NOTE_DURATION_HARD = 35;
  TOKEN_MAX_NOTE_DURATION_HARD = 36;
  TOKEN_BAR_LEVEL_ONSET_DENSITY = 40;
  TOKEN_BAR_LEVEL_ONSET_POLYPHONY_MIN = 41;
  TOKEN_BAR_LEVEL_ONSET_POLYPHONY_MAX = 42;
  TOKEN_TRACK_LEVEL_ONSET_DENSITY = 43;
  TOKEN_TRACK_LEVEL_ONSET_POLYPHONY_MIN = 44;
  TOKEN_TRACK_LEVEL_ONSET_POLYPHONY_MAX = 45;
  TOKEN_TRACK_LEVEL_ONSET_DENSITY_MIN = 46;
  TOKEN_TRACK_LEVEL_ONSET_DENSITY_MAX = 47;
  TOKEN_TRACK_LEVEL_PITCH_RANGE_MIN = 48;
  TOKEN_TRACK_LEVEL_PITCH_RANGE_MAX = 49;
  TOKEN_CONTAINS_NOTE_DURATION_THIRTY_SECOND = 59;
  TOKEN_CONTAINS_NOTE_DURATION_SIXTEENTH = 60;
  TOKEN_CONTAINS_NOTE_DURATION_EIGHTH = 61;
  TOKEN_CONTAINS_NOTE_DURATION_QUARTER = 62;
  TOKEN_CONTAINS_NOTE_DURATION_HALF = 63;
  TOKEN_CONTAINS_NOTE_DURATION_WHOLE = 64;
  TOKEN_DELTA = 67;
  TOKEN_DELTA_DIRECTION = 68;
  TOKEN_NONE = 72;
}

enum ATTRIBUTE_CONTROL_TYPE {
  ATTRIBUTE_CONTROL_NOTE_DENSITY = 0;
  ATTRIBUTE_CONTROL_TRACK_LEVEL_ONSET_POLYPHONY = 1;
  ATTRIBUTE_CONTROL_TRACK_LEVEL_ONSET_DENSITY = 2;
  ATTRIBUTE_CONTROL_PITCH_RANGE = 3;
  ATTRIBUTE_CONTROL_GENRE = 4;
  ATTRIBUTE_CONTROL_POLYPHONY_QUANTILE = 5;
  ATTRIBUTE_CONTROL_NOTE_DURATION_QUANTILE = 6;
  ATTRIBUTE_CONTROL_BAR_LEVEL_ONSET_DENSITY = 7;
  ATTRIBUTE_CONTROL_BAR_LEVEL_ONSET_POLYPHONY = 8;
  ATTRIBUTE_CONTROL_TRACK_LEVEL_NOTE_DURATION = 9;
  ATTRIBUTE_CONTROL_END = 10;
}

enum GenreMusicmap {
  GENRE_MUSICMAP_ANY = 0;
  GENRE_MUSICMAP_ALTERNATIVE_ROCK = 1;
  GENRE_MUSICMAP_AMBIENT = 2;
  GENRE_MUSICMAP_BLUES = 3;
  GENRE_MUSICMAP_BREAKBEAT = 4;
  GENRE_MUSICMAP_CLASSICAL = 5;
  GENRE_MUSICMAP_CLASSIC_ROCK = 6;
  GENRE_MUSICMAP_CONTEMPORARY_ROCK = 7;
  GENRE_MUSICMAP_COUNTRY = 8;
  GENRE_MUSICMAP_DRUM_N_BASS = 9;
  GENRE_MUSICMAP_FOLK = 10;
  GENRE_MUSICMAP_GOSPEL = 11;
  GENRE_MUSICMAP_HARDCORE_PUNK = 12;
  GENRE_MUSICMAP_HARDCORE_TECHNO = 13;
  GENRE_MUSICMAP_HEAVY_METAL = 14;
  GENRE_MUSICMAP_HIP_HOP = 15;
  GENRE_MUSICMAP_HOUSE = 16;
  GENRE_MUSICMAP_INDUSTRIAL = 17;
  GENRE_MUSICMAP_JAZZ = 18;
  GENRE_MUSICMAP_LATIN = 19;
  GENRE_MUSICMAP_POP = 20;
  GENRE_MUSICMAP_PUNK = 21;
  GENRE_MUSICMAP_PUNK_ROCK = 22;
  GENRE_MUSICMAP_RANDB = 23;
  GENRE_MUSICMAP_REGGAE = 24;
  GENRE_MUSICMAP_ROCK_N_ROLL = 25;
  GENRE_MUSICMAP_TECHNO = 26;
  GENRE_MUSICMAP_TRANCE = 27;
  GENRE_MUSICMAP_UTILITY = 28;
  GENRE_MUSICMAP_WORLD = 29;
  GENRE_MUSICMAP_NONE = 30;
};

enum GM_CATEGORY {
  GM_CATEGORY_MONO = 0;
  GM_CATEGORY_POLY = 1;
  GM_CATEGORY_SOUND_FX = 2;
  GM_CATEGORY_PERC = 3;
};

enum GM_TYPE {
  any = 0;
  piano = 1;
  chromatic_perc = 2;
  organ = 3;
  guitar = 4;
  bass = 5;
  strings = 6;
  ensemble = 7;
  brass = 8;
  reed = 9;
  pipe = 10;
  synth_lead = 11;
  synth_pad = 12;
  synth_effects = 13;
  ethnic = 14;
  percussive = 15;
  sound_fx = 16;
  no_drums = 17;
  drums = 18;
  acoustic_grand_piano = 19;
  bright_acoustic_piano = 20;
  electric_grand_piano = 21;
  honky_tonk_piano = 22;
  electric_piano_1 = 23;
  electric_piano_2 = 24;
  harpsichord = 25;
  clavi = 26;
  celesta = 27;
  glockenspiel = 28;
  music_box = 29;
  vibraphone = 30;
  marimba = 31;
  xylophone = 32;
  tubular_bells = 33;
  dulcimer = 34;
  drawbar_organ = 35;
  percussive_organ = 36;
  rock_organ = 37;
  church_organ = 38;
  reed_organ = 39;
  accordion = 40;
  harmonica = 41;
  tango_accordion = 42;
  acoustic_guitar_nylon = 43;
  acoustic_guitar_steel = 44;
  electric_guitar_jazz = 45;
  electric_guitar_clean = 46;
  electric_guitar_muted = 47;
  overdriven_guitar = 48;
  distortion_guitar = 49;
  guitar_harmonics = 50;
  acoustic_bass = 51;
  electric_bass_finger = 52;
  electric_bass_pick = 53;
  fretless_bass = 54;
  slap_bass_1 = 55;
  slap_bass_2 = 56;
  synth_bass_1 = 57;
  synth_bass_2 = 58;
  violin = 59;
  viola = 60;
  cello = 61;
  contrabass = 62;
  tremolo_strings = 63;
  pizzicato_strings = 64;
  orchestral_harp = 65;
  timpani = 66;
  string_ensemble_1 = 67;
  string_ensemble_2 = 68;
  synth_strings_1 = 69;
  synth_strings_2 = 70;
  choir_aahs = 71;
  voice_oohs = 72;
  synth_voice = 73;
  orchestra_hit = 74;
  trumpet = 75;
  trombone = 76;
  tuba = 77;
  muted_trumpet = 78;
  french_horn = 79;
  brass_section = 80;
  synth_brass_1 = 81;
  synth_brass_2 = 82;
  soprano_sax = 83;
  alto_sax = 84;
  tenor_sax = 85;
  baritone_sax = 86;
  oboe = 87;
  english_horn = 88;
  bassoon = 89;
  clarinet = 90;
  piccolo = 91;
  flute = 92;
  recorder = 93;
  pan_flute = 94;
  blown_bottle = 95;
  shakuhachi = 96;
  whistle = 97;
  ocarina = 98;
  lead_1_square = 99;
  lead_2_sawtooth = 100;
  lead_3_calliope = 101;
  lead_4_chiff = 102;
  lead_5_charang = 103;
  lead_6_voice = 104;
  lead_7_fifths = 105;
  lead_8_bass__lead = 106;
  pad_1_new_age = 107;
  pad_2_warm = 108;
  pad_3_polysynth = 109;
  pad_4_choir = 110;
  pad_5_bowed = 111;
  pad_6_metallic = 112;
  pad_7_halo = 113;
  pad_8_sweep = 114;
  fx_1_rain = 115;
  fx_2_soundtrack = 116;
  fx_3_crystal = 117;
  fx_4_atmosphere = 118;
  fx_5_brightness = 119;
  fx_6_goblins = 120;
  fx_7_echoes = 121;
  fx_8_sci_fi = 122;
  sitar = 123;
  banjo = 124;
  shamisen = 125;
  koto = 126;
  kalimba = 127;
  bag_pipe = 128;
  fiddle = 129;
  shanai = 130;
  tinkle_bell = 131;
  agogo = 132;
  steel_drums = 133;
  woodblock = 134;
  taiko_drum = 135;
  melodic_tom = 136;
  synth_drum = 137;
  reverse_cymbal = 138;
  guitar_fret_noise = 139;
  breath_noise = 140;
  seashore = 141;
  bird_tweet = 142;
  telephone_ring = 143;
  helicopter = 144;
  applause = 145;
  gunshot = 146;
  drum_0 = 147;
  drum_1 = 148;
  drum_2 = 149;
  drum_3 = 150;
  drum_4 = 151;
  drum_5 = 152;
  drum_6 = 153;
  drum_7 = 154;
  drum_8 = 155;
  drum_9 = 156;
  drum_10 = 157;
  drum_11 = 158;
  drum_12 = 159;
  drum_13 = 160;
  drum_14 = 161;
  drum_15 = 162;
  drum_16 = 163;
  drum_17 = 164;
  drum_18 = 165;
  drum_19 = 166;
  drum_20 = 167;
  drum_21 = 168;
  drum_22 = 169;
  drum_23 = 170;
  drum_24 = 171;
  drum_25 = 172;
  drum_26 = 173;
  drum_27 = 174;
  drum_28 = 175;
  drum_29 = 176;
  drum_30 = 177;
  drum_31 = 178;
  drum_32 = 179;
  drum_33 = 180;
  drum_34 = 181;
  drum_35 = 182;
  drum_36 = 183;
  drum_37 = 184;
  drum_38 = 185;
  drum_39 = 186;
  drum_40 = 187;
  drum_41 = 188;
  drum_42 = 189;
  drum_43 = 190;
  drum_44 = 191;
  drum_45 = 192;
  drum_46 = 193;
  drum_47 = 194;
  drum_48 = 195;
  drum_49 = 196;
  drum_50 = 197;
  drum_51 = 198;
  drum_52 = 199;
  drum_53 = 200;
  drum_54 = 201;
  drum_55 = 202;
  drum_56 = 203;
  drum_57 = 204;
  drum_58 = 205;
  drum_59 = 206;
  drum_60 = 207;
  drum_61 = 208;
  drum_62 = 209;
  drum_63 = 210;
  drum_64 = 211;
  drum_65 = 212;
  drum_66 = 213;
  drum_67 = 214;
  drum_68 = 215;
  drum_69 = 216;
  drum_70 = 217;
  drum_71 = 218;
  drum_72 = 219;
  drum_73 = 220;
  drum_74 = 221;
  drum_75 = 222;
  drum_76 = 223;
  drum_77 = 224;
  drum_78 = 225;
  drum_79 = 226;
  drum_80 = 227;
  drum_81 = 228;
  drum_82 = 229;
  drum_83 = 230;
  drum_84 = 231;
  drum_85 = 232;
  drum_86 = 233;
  drum_87 = 234;
  drum_88 = 235;
  drum_89 = 236;
  drum_90 = 237;
  drum_91 = 238;
  drum_92 = 239;
  drum_93 = 240;
  drum_94 = 241;
  drum_95 = 242;
  drum_96 = 243;
  drum_97 = 244;
  drum_98 = 245;
  drum_99 = 246;
  drum_100 = 247;
  drum_101 = 248;
  drum_102 = 249;
  drum_103 = 250;
  drum_104 = 251;
  drum_105 = 252;
  drum_106 = 253;
  drum_107 = 254;
  drum_108 = 255;
  drum_109 = 256;
  drum_110 = 257;
  drum_111 = 258;
  drum_112 = 259;
  drum_113 = 260;
  drum_114 = 261;
  drum_115 = 262;
  drum_116 = 263;
  drum_117 = 264;
  drum_118 = 265;
  drum_119 = 266;
  drum_120 = 267;
  drum_121 = 268;
  drum_122 = 269;
  drum_123 = 270;
  drum_124 = 271;
  drum_125 = 272;
  drum_126 = 273;
  drum_127 = 274;
};
libraries/protobuf/src/feature_extraction.proto
ADDED
@@ -0,0 +1,92 @@
syntax = "proto2";

package midi;

message PitchRange {
  optional int32 instrument = 1;
  optional int32 min = 2;
  optional int32 max = 3;
}

message MetricDepth {
  optional string filepath = 1;
  optional int32 track_num = 2;
  optional int32 instrument = 3;
  optional bool is_drum = 4;
  optional bool has_time_signatures = 5;
  repeated int32 metric_depth = 6;
  optional int32 tpq = 7;
}

message MedianMetricDepth {
  optional string filepath = 1;
  optional int32 track_num = 2;
  optional int32 instrument = 3;
  optional bool is_drum = 4;
  optional bool has_time_signatures = 5;
  optional int32 median_metric_depth = 6;
  optional int32 tpq = 7;
}

message MostFrequentMetricDepth {
  optional string filepath = 1;
  optional int32 track_num = 2;
  optional int32 instrument = 3;
  optional bool is_drum = 4;
  optional bool has_time_signatures = 5;
  optional int32 most_frequent_metric_depth = 6;
  optional int32 tpq = 7;
}

message DownbeatProportion {
  optional string filepath = 1;
  optional int32 track_num = 2;
  optional int32 instrument = 3;
  optional bool is_drum = 4;
  optional float downbeat_proportion = 5;
}

message AlignedMetricDepth {
  optional string filepath = 1;
  optional int32 track_num = 2;
  optional int32 instrument = 3;
  optional bool is_drum = 4;

  optional int32 aligned_offset = 5;
}

message SimultaneousOnset {
  optional string filepath = 1;
  optional int32 track_num = 2;
  optional int32 instrument = 3;
  optional bool is_drum = 4;

  optional int32 simultaneous_onset_count = 5;
}

message BeatStability {
  optional string filepath = 1;
  //optional int32 track_num = 2;
  //optional int32 instrument = 3;
  //optional bool is_drum = 4;

  optional float beat_stability_stdev = 5;
  optional float beat_stability_median = 6;
}

message DrumPresence {
  optional string filepath = 1;
  optional float drum_presence = 2;
}

message Features {
  repeated PitchRange pitch_range = 1;
  repeated MetricDepth metric_depth = 2;
  repeated DownbeatProportion downbeat_proportion = 3;
  repeated AlignedMetricDepth aligned_metric_depth = 4;
  repeated SimultaneousOnset simultaneous_onset = 5;
  repeated BeatStability beat_stability = 6;
  repeated DrumPresence drum_presence = 7;
  repeated MedianMetricDepth median_metric_depth = 8;
  repeated MostFrequentMetricDepth most_frequent_metric_depth = 9;
}
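// A minimal C++ sketch of populating the Features container above through the
// protobuf-generated API (feature_extraction.pb.h). All values are
// illustrative only.
#include "feature_extraction.pb.h"

midi::Features make_example_features() {
	midi::Features features;
	midi::PitchRange* pr = features.add_pitch_range();
	pr->set_instrument(0);
	pr->set_min(48);                   // lowest MIDI pitch observed
	pr->set_max(84);                   // highest MIDI pitch observed
	midi::DrumPresence* dp = features.add_drum_presence();
	dp->set_filepath("example.mid");   // hypothetical file
	dp->set_drum_presence(0.5f);
	return features;
}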
libraries/protobuf/src/midi.proto
ADDED
@@ -0,0 +1,429 @@
/*

# Introduction

To generate material, three protobuf messages must be supplied to the midigpt::sample function. The specification for each of these messages (Piece, Status and SampleParam) is outlined in this document. Any field which starts with internal_ should be ignored, as these are for internal use only. Or if you really know what you are doing ;). There are examples of midi::Piece, midi::Status and midi::SampleParam objects in the docs folder. For working examples, please consult the testing suite, which is found in MMM_API/src/midigpt_api/test/unit.cpp.

# Functionality Overview

| Functionality | Scope | Description |
| ----------- | ----------- | ----------- |
| Velocity | Always Enabled | 32 levels of loudness for individual notes. |
| Instrument | Track | The General MIDI instrument (i.e. timbre). |
| Max Polyphony Hard-Limit | Track | A hard limit on the number of simultaneously sounding notes. |
| Note Duration (upper/lower soft bounds) | Non-Drum Tracks | Tells the model what the 15th (lower) and 85th (upper) quantiles of the note duration (i.e. Quarter, Whole) distribution should be. |
| Polyphony (upper/lower soft bounds) | Non-Drum Tracks | Tells the model what the 15th (lower) and 85th (upper) quantiles of the polyphony distribution should be. |
| Density (10 levels) | Drum Tracks | Tells the model the number of notes per bar to produce. |
| Auto-regressive Sampling Mode | Track | When enabled, bars are always sampled in chronological order. |
| Time Signature | Bar | A unique time signature can be specified for each bar. |
| Temperature | per API call | A higher value increases the entropy of the generated output. Temperature=1 applies no modification to the probabilities produced by the model. |
| Context size (model_dim) | per API call | The number of bars that the model can process in one API call. |


# Parameter Constraints and Considerations

There are two sampling methods: autoregressive generation, where we progressively sample musical material forwards in time on each track; and conditional generation (bar-infilling), where generated material is conditioned on past and future material.

Note that a single call to the model midigpt::sample() may involve both autoregressive and conditional generation, as these can be specified on a per-track basis. These constraints are outlined below.

## Sample Param Constraints

1. tracks_per_step :
    - must be on the range [1,number of tracks in piece]

2. bars_per_step :
    - must be on the range [1,model_dim]
    - for conditional generation it is ill-advised for the user to have bars_per_step == model_dim, as this means generation will not be conditioned on any bars

3. shuffle :
    - this only applies in cases where one or more tracks are conditionally generated (i.e. resample = False && 1+ selected_bars = True)

4. percentage :
    - this only applies in cases where one or more tracks are conditionally generated (i.e. resample = False && 1+ selected_bars = True)

## Status Constraints

1. density :
    - this control only applies to drum tracks. This works with both infilling and autoregressive mode.

2. note duration / polyphony :
    - this control only applies to non-drum tracks. This works with both infilling and autoregressive mode.

3. autoregressive :
    - you can only enable autoregressive mode (resample = True) when all the bars are selected in a track.
    - note you may have autoregressive disabled when all bars are selected in a track

4. ignore :
    - bars which have 1+ selected_bars = True may not be ignored, as they are needed to condition the generation


# Protobuf Specification

*/
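// A C++ sketch of assembling the three inputs described above, using the
// classes that protoc generates from this file (midi.pb.h). The Status and
// SampleParam schemas are specified further on; the SampleParam field names
// used here (tracks_per_step, bars_per_step) follow the constraint notes
// above but are not verified against the full schema, and the midigpt::sample
// call is shown only schematically.
#include "midi.pb.h"

void sample_inputs_sketch() {
	midi::Piece piece;              // the musical material (tracks / bars / events)
	piece.set_resolution(12);       // default quantization: 12 steps per quarter note
	midi::Status status;            // per-track directives (selected bars, resample, ...)
	midi::SampleParam params;       // per-call parameters
	params.set_tracks_per_step(1);  // must be on [1, number of tracks in piece]
	params.set_bars_per_step(2);    // must be on [1, model_dim]
	// midigpt::sample(piece, status, params);  // schematic; see the testing suite for real usage
}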
syntax = "proto2";
|
64 |
+
|
65 |
+
import "track_type.proto";
|
66 |
+
import "enum.proto";
|
67 |
+
import "midi_internal.proto";
|
68 |
+
import "feature_extraction.proto";
|
69 |
+
|
70 |
+
package midi;
|
71 |
+
|
72 |
+
/*
|
73 |
+
Specify the minimum or maximum amount of polyphony using these values. Using POLYPHONY_ANY lets the model choose the level of polyphony.
|
74 |
+
*/
|
75 |
+
enum PolyphonyLevel {
|
76 |
+
POLYPHONY_ANY = 0;
|
77 |
+
POLYPHONY_ONE = 1;
|
78 |
+
POLYPHONY_TWO = 2;
|
79 |
+
POLYPHONY_THREE = 3;
|
80 |
+
POLYPHONY_FOUR = 4;
|
81 |
+
POLYPHONY_FIVE = 5;
|
82 |
+
POLYPHONY_SIX = 6;
|
83 |
+
}
|
84 |
+
|
85 |
+
/*
|
86 |
+
Specify the minimum or maximum bounds for note-duration using these values. Using DURATION_ANY lets the model choose the bounds for note duration.
|
87 |
+
*/
|
88 |
+
enum NoteDurationLevel {
|
89 |
+
DURATION_ANY = 0;
|
90 |
+
DURATION_THIRTY_SECOND = 1;
|
91 |
+
DURATION_SIXTEENTH = 2;
|
92 |
+
DURATION_EIGHTH = 3;
|
93 |
+
DURATION_QUARTER = 4;
|
94 |
+
DURATION_HALF = 5;
|
95 |
+
DURATION_WHOLE = 6;
|
96 |
+
}
|
97 |
+
|
98 |
+
/*
|
99 |
+
Specify the minimum or maximum amount of note density using these values. Using DENSITY ANY lets the model choose the level of density.
|
100 |
+
*/
|
101 |
+
enum DensityLevel {
|
102 |
+
DENSITY_ANY = 0;
|
103 |
+
DENSITY_ONE = 1;
|
104 |
+
DENSITY_TWO = 2;
|
105 |
+
DENSITY_THREE = 3;
|
106 |
+
DENSITY_FOUR = 4;
|
107 |
+
DENSITY_FIVE = 5;
|
108 |
+
DENSITY_SIX = 6;
|
109 |
+
DENSITY_SEVEN = 7;
|
110 |
+
DENSITY_EIGHT = 8;
|
111 |
+
DENSITY_NINE = 9;
|
112 |
+
DENSITY_TEN = 10;
|
113 |
+
}
|
114 |
+
|
115 |
+
// specify the levels for bar level onset density
|
116 |
+
enum BarLevelOnsetDensityLevel {
|
117 |
+
BAR_LEVEL_ONSET_DENSITY_ANY = 0;
|
118 |
+
BAR_LEVEL_ONSET_DENSITY_ZERO = 1;
|
119 |
+
BAR_LEVEL_ONSET_DENSITY_ONE = 2;
|
120 |
+
BAR_LEVEL_ONSET_DENSITY_TWO = 3;
|
121 |
+
BAR_LEVEL_ONSET_DENSITY_THREE = 4;
|
122 |
+
BAR_LEVEL_ONSET_DENSITY_FOUR = 5;
|
123 |
+
BAR_LEVEL_ONSET_DENSITY_FIVE = 6;
|
124 |
+
BAR_LEVEL_ONSET_DENSITY_SIX = 7;
|
125 |
+
BAR_LEVEL_ONSET_DENSITY_SEVEN = 8;
|
126 |
+
BAR_LEVEL_ONSET_DENSITY_EIGHT = 9;
|
127 |
+
BAR_LEVEL_ONSET_DENSITY_NINE = 10;
|
128 |
+
BAR_LEVEL_ONSET_DENSITY_TEN = 11;
|
129 |
+
BAR_LEVEL_ONSET_DENSITY_ELEVEN = 12;
|
130 |
+
BAR_LEVEL_ONSET_DENSITY_TWELVE = 13;
|
131 |
+
BAR_LEVEL_ONSET_DENSITY_THIRTEEN = 14;
|
132 |
+
BAR_LEVEL_ONSET_DENSITY_FOURTEEN = 15;
|
133 |
+
BAR_LEVEL_ONSET_DENSITY_FIFTEEN = 16;
|
134 |
+
BAR_LEVEL_ONSET_DENSITY_SIXTEEN = 17;
|
135 |
+
}
|
136 |
+
|
137 |
+
// specify the levels for bar level onset polyphony
|
138 |
+
enum BarLevelOnsetPolyphonyLevel {
|
139 |
+
BAR_LEVEL_ONSET_POLYPHONY_ANY = 0;
|
140 |
+
BAR_LEVEL_ONSET_POLYPHONY_ONE = 1;
|
141 |
+
BAR_LEVEL_ONSET_POLYPHONY_TWO = 2;
|
142 |
+
BAR_LEVEL_ONSET_POLYPHONY_THREE = 3;
|
143 |
+
BAR_LEVEL_ONSET_POLYPHONY_FOUR = 4;
|
144 |
+
BAR_LEVEL_ONSET_POLYPHONY_FIVE = 5;
|
145 |
+
BAR_LEVEL_ONSET_POLYPHONY_SIX = 6;
|
146 |
+
}
|
147 |
+
|
148 |
+
enum SilenceProportionLevel {
|
149 |
+
SILENCE_PROPORTION_LEVEL_ANY = 0;
|
150 |
+
SILENCE_PROPORTION_LEVEL_ONE = 1;
|
151 |
+
SILENCE_PROPORTION_LEVEL_TWO = 2;
|
152 |
+
SILENCE_PROPORTION_LEVEL_THREE = 3;
|
153 |
+
SILENCE_PROPORTION_LEVEL_FOUR = 4;
|
154 |
+
SILENCE_PROPORTION_LEVEL_FIVE = 5;
|
155 |
+
SILENCE_PROPORTION_LEVEL_SIX = 6;
|
156 |
+
SILENCE_PROPORTION_LEVEL_SEVEN = 7;
|
157 |
+
SILENCE_PROPORTION_LEVEL_EIGHT = 8;
|
158 |
+
SILENCE_PROPORTION_LEVEL_NINE = 9;
|
159 |
+
SILENCE_PROPORTION_LEVEL_TEN = 10;
|
160 |
+
}
|
161 |
+
|
162 |
+
enum BooleanLevel {
|
163 |
+
BOOLEAN_ANY = 0;
|
164 |
+
BOOLEAN_FALSE = 1;
|
165 |
+
BOOLEAN_TRUE = 2;
|
166 |
+
}
|
167 |
+
|
168 |
+
/*
|
169 |
+
The Event Message is used to represent a MIDI note onset or offset.
|
170 |
+
*/
|
171 |
+
message Event {
|
172 |
+
/*
|
173 |
+
The time of the event (either a note onset or note offset) relative to the current bar in quantized steps. Currently, most model quantize each quarter note beat into 12 subdivisions. As a result, if the event happens an eighth note after the start of the bar, this value would be 6. If the event occurs three quarter notes after the start of the bar, this value would be 3 * 12 = 36.
|
174 |
+
*/
|
175 |
+
optional int32 time = 1 [(minval) = 0, (maxval) = 1000000];
|
176 |
+
/*
|
177 |
+
The MIDI velocity. This value must be 0 for note off messages.
|
178 |
+
*/
|
179 |
+
optional int32 velocity = 2 [(minval) = 0, (maxval) = 127];
|
180 |
+
/*
|
181 |
+
The MIDI pitch value of on the range [0,128).
|
182 |
+
*/
|
183 |
+
optional int32 pitch = 3 [(minval) = 0, (maxval) = 127];
|
184 |
+
|
185 |
+
optional int32 internal_instrument = 4;
|
186 |
+
optional int32 internal_track_type = 10;
|
187 |
+
optional int32 internal_duration = 11;
|
188 |
+
optional int32 delta = 12;
|
189 |
+
}
|
190 |
+
|
191 |
+
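// A C++ sketch of the Event timing arithmetic documented above, at the default
// resolution of 12 steps per quarter note: an onset an eighth note into the
// bar lands at step 6, and its offset a quarter note later at step 18.
#include "midi.pb.h"

void add_note_sketch(midi::Piece* piece) {
	midi::Event* onset = piece->add_events();
	onset->set_time(6);       // eighth note after the start of the bar
	onset->set_pitch(60);     // middle C
	onset->set_velocity(96);  // non-zero velocity => note onset
	midi::Event* offset = piece->add_events();
	offset->set_time(18);     // 6 + 12 steps = one quarter note later
	offset->set_pitch(60);
	offset->set_velocity(0);  // velocity must be 0 for a note-off
}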
/*
|
192 |
+
The Bar message specifies the events occuring in a bar.
|
193 |
+
*/
|
194 |
+
message Bar {
|
195 |
+
/*
|
196 |
+
A list of integers, which are simply the indices of the messages found in the Piece.events repeated message. Note offsets which occur at the end of the bar (i.e. event.time = 48 with a time signature of 4/4 and piece.resolution of 12) should be included in the current bar rather than the next bar. In other words, no note offsets should even have an event.time = 0, as these note offset events would belong in the previous bar.
|
197 |
+
*/
|
198 |
+
repeated int32 events = 1 [(minval) = 0, (maxval) = 2147483647];
|
199 |
+
/*
|
200 |
+
Numerator for the time-signature of the bar. Note that while time signatures can vary from bar to bar, they cannot vary from track to track. In other words if the second bar in track 0 has a time signature of 4/4, the second bar in track 1 must also have a time signature of 4/4.
|
201 |
+
*/
|
202 |
+
optional int32 ts_numerator = 8 [(minval) = 1, (maxval) = 1000000];
|
203 |
+
/*
|
204 |
+
Denominator for the time-signature of the bar.
|
205 |
+
*/
|
206 |
+
optional int32 ts_denominator = 9 [(minval) = 1, (maxval) = 1000000];
|
207 |
+
|
208 |
+
optional float internal_beat_length = 5;
|
209 |
+
optional bool internal_has_notes = 3;
|
210 |
+
repeated ContinuousFeature internal_feature = 10;
|
211 |
+
repeated BarFeatures internal_features = 11; // why isn't this just called bar features
|
212 |
+
}

/*
The Track message contains a list of bars, and specifies the instrument and track_type.
*/
message Track {
  /*
  A list of bars. Note that each track must have the same number of bars.
  */
  repeated Bar bars = 1;
  /*
  The MIDI instrument number for the track.
  */
  optional int32 instrument = 3 [(minval) = 0, (maxval) = 139];
  /* 127 original instruments, with the drum instrument separated into 12 individual drum instruments (TR-808) and the rest of the drums in a single instrument */
  /*
  This must be a value in the TRACK_TYPE enum. In most cases, using STANDARD_TRACK and STANDARD_DRUM_TRACK will suffice to denote a non-drum instrument track and a drum track respectively.
  */
  optional TRACK_TYPE track_type = 5;

  repeated TRACK_TYPE internal_train_types = 6;
  repeated TrackFeatures internal_features = 7; // why isn't this called track features
}

/*
The Piece message specifies the actual musical material in a track-separated event-based format, specifying the note onsets and offsets for each bar in each track.
*/
message Piece {
  /*
  Organizes MIDI events into tracks and bars. In short, each track contains a list of bars, which in turn contain a list of event indices (corresponding to the repeated events field in the Piece).
  */
  repeated Track tracks = 1;
  /*
  A list of MIDI events which the tracks and bars reference.
  */
  repeated Event events = 2;
  /*
  The time resolution used to quantize / discretize musical material. Unless otherwise instructed, this should be set to 12.
  */
  optional int32 resolution = 3 [(minval) = 1, (maxval) = 24];
  /*
  Optionally, the tempo can be specified. However, this is not taken into consideration by the model.
  */
  //optional int32 tempo = 4 [(minval) = 1, (maxval) = 1000000];
  optional int32 tempo = 4;

  repeated int32 internal_valid_segments = 7;
  repeated uint32 internal_valid_tracks = 8;
  optional int32 internal_segment_length = 12;
  repeated ValidTrack internal_valid_tracks_v2 = 13;
  repeated GenreData internal_genre_data = 14;
  optional MetadataLabels internal_metadata_labels = 15;
  repeated PieceFeatures internal_features = 16;

  optional int32 internal_ticks_per_quarter = 5;
  optional bool internal_has_time_signatures = 6;
}

/*
The StatusBar message specifies per-bar information for generation.
*/
message StatusBar {
  optional int32 ts_numerator = 1;
  optional int32 ts_denominator = 2;
  optional BarLevelOnsetDensityLevel onset_density = 3;
  optional BarLevelOnsetPolyphonyLevel onset_polyphony_min = 4;
  optional BarLevelOnsetPolyphonyLevel onset_polyphony_max = 5;
}

/*
The StatusTrack message specifies per-track information for generation.
*/
message StatusTrack {
  /*
  The index of a track in the Piece message. For a track to be seen by the model, it must be referenced by a StatusTrack message via the track_id field. Tracks that are not referenced by a StatusTrack message will not be considered by the model.
  */
  optional int32 track_id = 1 [(minval) = 0, (maxval) = 1000000];
  /*
  This must be a value in the TRACK_TYPE enum. This should be equivalent to the TRACK_TYPE specified in the corresponding Track, unless you are giving the model the option to choose either a drum or instrument track. In this case use the STANDARD_BOTH value here, and the TRACK_TYPE in the piece will be ignored.
  */
  optional TRACK_TYPE track_type = 2;
  /*
  This must be a value in the GM_TYPE enum. It specifies the set of possible instruments that the model may choose from. The mapping between GM_TYPE and instrument numbers can be found in src/midigpt_api/enum/gm.h. For example, using midi::GM_TYPE::piano will allow the model to use any piano instrument.
  */
  optional GM_TYPE instrument = 3;
  /*
  A list of boolean values which specifies whether a bar is to be generated (true) or conditioned on (false). This must be the same length as the number of bars in the corresponding Track message.
  */
  repeated bool selected_bars = 5;
  /*
  Indicates whether or not to use auto-regressive sampling. Note that you can only use auto-regressive sampling when each value in selected_bars is true (i.e. the entire track is being generated). Note that you do not have to use auto-regressive sampling when selected_bars is all true.
  */
  optional bool autoregressive = 6;
  /*
  This indicates that the track should be ignored. The model will not be conditioned on this track, and it will in no way affect the generated outcome.
  */
  optional bool ignore = 7;

  optional DensityLevel density = 4;
  optional PolyphonyLevel min_polyphony_q = 10;
  optional PolyphonyLevel max_polyphony_q = 11;
  optional NoteDurationLevel min_note_duration_q = 12;
  optional NoteDurationLevel max_note_duration_q = 13;

  optional BarLevelOnsetPolyphonyLevel onset_polyphony_min = 20;
  optional BarLevelOnsetPolyphonyLevel onset_polyphony_max = 21;
  optional BarLevelOnsetDensityLevel onset_density = 22;
  optional BarLevelOnsetDensityLevel onset_density_min = 23;
  optional BarLevelOnsetDensityLevel onset_density_max = 24;
  optional int32 min_pitch = 25 [(minval) = 0, (maxval) = 127];
  optional int32 max_pitch = 26 [(minval) = 0, (maxval) = 127];
  optional GenreMusicmap genre = 28;
  optional SilenceProportionLevel silence_proportion_min = 29;
  optional SilenceProportionLevel silence_proportion_max = 30;
  optional DensityLevel note_density_level = 31;

  optional BooleanLevel contains_note_duration_thirty_second = 36;
  optional BooleanLevel contains_note_duration_sixteenth = 37;
  optional BooleanLevel contains_note_duration_eighth = 38;
  optional BooleanLevel contains_note_duration_quarter = 39;
  optional BooleanLevel contains_note_duration_half = 40;
  optional BooleanLevel contains_note_duration_whole = 41;

  /*
  Sets a hard limit on the polyphony on this track. This is implemented by keeping a record of all the currently sounding notes, and preventing the model from generating note-onset tokens when the limit is reached.
  */
  optional int32 polyphony_hard_limit = 16 [(minval) = 0, (maxval) = 100];
  /*
  Allows the entropy of generation to be adjusted. When temperature=1, the probability distributions output by the model are unaltered. When temperature<1 the probability distribution is increasingly biased towards the most probable tokens. With a very small temperature value this would be equivalent to argmax sampling. When temperature>1 the probability distribution moves towards a random uniform distribution. It is recommended to keep this value close to 1 in most cases.
  */
  optional float temperature = 17 [(fminval) = 0.5, (fmaxval) = 2.0];

  repeated int32 internal_ts_numerators = 14;
  repeated int32 internal_ts_denominators = 15;
  optional string internal_genre = 9;

  repeated StatusBar bars = 19;
}

/*
The Status message specifies which bars or tracks are to be generated/conditioned on, and provides extra information about conditioning such as instrument, density, polyphony and note-duration.
*/
message Status {
  repeated StatusTrack tracks = 1;

  /*
  For microtiming generation, the last sampling step must be decoded using the delta resolution.
  */
  optional bool decode_final = 2;
  optional bool full_resolution = 3;
}

/*
The HyperParam message specifies hyper-parameters for generation.
*/
message HyperParam {
  /*
  For multi-step generation (typically employed when the entire piece is too large to be considered by the model simultaneously) this parameter specifies the number of tracks that are generated in each step.
  */
  optional int32 tracks_per_step = 1 [(minval) = 1, (maxval) = 12];
  /*
  For multi-step generation this parameter specifies the number of bars that are generated in each step. This value should be set in relation to model_dim. If bars_per_step = model_dim, then there will be no horizontal conditioning, which will typically produce inferior results. A good rule of thumb is to use bars_per_step == model_dim / 2.
  */
  optional int32 bars_per_step = 2 [(minval) = 1, (maxval) = 8];
  /*
  The size of the model. In most cases this will be 4.
  */
  optional int32 model_dim = 3 [(minval) = 1, (maxval) = 8];
  /*
  The percentage of the selected material (selected bars in the Status message) that will be generated.
  */
  optional int32 percentage = 5 [(minval) = 1, (maxval) = 100];
  /*
  The number of outputs to be generated. Currently we only support batch_size=1. With multi-step sampling it is likely more efficient to simply make several calls in series.
  */
  optional int32 batch_size = 7 [(minval) = 1, (maxval) = 1];
  /*
  Allows the entropy of generation to be adjusted. When temperature=1, the probability distributions output by the model are unaltered. When temperature<1 the probability distribution is increasingly biased towards the most probable tokens. With a very small temperature value this would be equivalent to argmax sampling. When temperature>1 the probability distribution moves towards a random uniform distribution. It is recommended to keep this value close to 1 in most cases.
  */
  optional float temperature = 6 [(fminval) = 0.5, (fmaxval) = 2.0];
  /*
  This parameter turns per-track temperature control on and off.
  */
  optional bool use_per_track_temperature = 17;
  /*
  The max number of tokens to generate before terminating generation. Can be used to avoid memory overload. When this value is set to zero it is ignored, and no limitations are set on the number of generated tokens.
  */
  optional int32 max_steps = 13 [(minval) = 0, (maxval) = 2048];

  optional int32 polyphony_hard_limit = 14 [(minval) = 0, (maxval) = 100];
  /*
  When shuffle=true the generation steps are randomly ordered. For obvious reasons, auto-regressive sampling cannot be used with shuffle=true, as it would cease to be auto-regressive.
  */
  optional bool shuffle = 4;
  /*
  Mainly for debugging purposes.
  */
  optional bool verbose = 8;
  /*
  The path to the ckpt, which should either be an absolute path or relative to the executable.
  */
  optional string ckpt = 9;
  /*
  Controls the probability of masking the top k.
  */
  optional float mask_top_k = 10 [(fminval) = 0, (fmaxval) = 1.];
  /*
  Controls the stochastic seed for reproducibility.
  */
  optional int32 sampling_seed = 11;

  optional bool internal_skip_preprocess = 12;
  optional bool internal_disable_masking = 16;

}
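
Note: to make the message relationships above concrete, here is a minimal Python sketch that builds a one-bar Piece, selects that bar for generation via Status, and fills in a HyperParam following the bars_per_step rule of thumb. It assumes the .proto files have been compiled with protoc and that the generated modules are importable as midi_pb2 and track_type_pb2; the module names and the ckpt path are assumptions, not something this commit pins down.

import midi_pb2 as midi
import track_type_pb2 as track_type  # assumed module name for the TRACK_TYPE enum

piece = midi.Piece()
piece.resolution = 12              # 12 steps per quarter note, as recommended

# One note: onset at the bar start, offset an eighth note later (step 6).
on = piece.events.add()
on.time, on.velocity, on.pitch = 0, 96, 60
off = piece.events.add()
off.time, off.velocity, off.pitch = 6, 0, 60   # velocity 0 marks the note off

track = piece.tracks.add()
track.instrument = 0               # MIDI instrument number
track.track_type = track_type.STANDARD_TRACK

bar = track.bars.add()
bar.ts_numerator, bar.ts_denominator = 4, 4
bar.events.extend([0, 1])          # indices into piece.events

# Ask the model to regenerate that bar.
status = midi.Status()
st = status.tracks.add()
st.track_id = 0
st.track_type = track_type.STANDARD_TRACK
st.selected_bars.append(True)      # one entry per bar in the corresponding Track
st.autoregressive = True           # allowed because every selected_bars entry is true
st.temperature = 1.0

param = midi.HyperParam()
param.model_dim = 4                # usually 4, per the comment above
param.bars_per_step = 2            # rule of thumb: model_dim / 2
param.tracks_per_step = 1
param.batch_size = 1               # only batch_size=1 is supported
param.temperature = 1.0
param.max_steps = 2048
param.ckpt = "models/model.zip"    # placeholder path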
libraries/protobuf/src/midi_internal.proto
ADDED
@@ -0,0 +1,125 @@
syntax = "proto2";

import "enum.proto";
import "google/protobuf/descriptor.proto";

package midi;

extend google.protobuf.FieldOptions {
  optional int32 maxval = 50001;
  optional int32 minval = 50002;
  optional float fmaxval = 50003;
  optional float fminval = 50004;
}

message ContinuousFeature {
  optional float av_polyphony = 1;
  optional float note_duration = 3;
  optional float note_duration_norm = 4;
}

message BarFeatures {
  optional int32 onset_density = 1;
  optional int32 onset_polyphony_min = 2;
  optional int32 onset_polyphony_max = 3;
}

message TrackLevelAttributeControlDistributions {
  repeated int32 polyphony_quantile = 1;
  repeated int32 note_duration_quantile = 2;
  repeated int32 note_density = 3;
  repeated int32 onset_polyphony = 4;
  repeated int32 onset_density = 5;
  repeated int32 note_duration = 8;
}

message TrackFeatures {
  optional int32 min_pitch = 1;
  optional int32 max_pitch = 2;
  optional float av_polyphony = 3;
  optional int32 note_density = 4;
  optional int32 note_density_v2 = 5;
  optional int32 max_polyphony = 6;
  optional bool should_prune = 7;
  optional int32 order = 8;
  optional float note_duration = 9;
  optional string genre_str = 10;
  optional int32 min_polyphony_q = 11;
  optional int32 max_polyphony_q = 12;
  optional int32 min_note_duration_q = 13;
  optional int32 max_note_duration_q = 14;
  repeated int32 polyphony_distribution = 15;
  optional float note_density_value = 16;

  optional int32 min_polyphony_hard = 18;
  optional int32 max_polyphony_hard = 19;
  optional int32 min_note_duration_hard = 20;
  optional int32 max_note_duration_hard = 21;

  optional int32 onset_polyphony_min = 24;
  optional int32 onset_polyphony_max = 25;
  optional int32 onset_density = 26;
  optional int32 onset_density_min = 27;
  optional int32 onset_density_max = 28;
  repeated int32 duration_distribution = 29;

  optional int32 genre = 32;
  optional int32 note_density_level = 35;

  optional int32 contains_note_duration_thirty_second = 40;
  optional int32 contains_note_duration_sixteenth = 41;
  optional int32 contains_note_duration_eighth = 42;
  optional int32 contains_note_duration_quarter = 43;
  optional int32 contains_note_duration_half = 44;
  optional int32 contains_note_duration_whole = 45;

  optional TrackLevelAttributeControlDistributions attribute_control_distributions = 30; // store them all here

}

message PieceFeatures {
  optional string genre = 1;
}

message Note {
  optional int32 start = 1;
  optional int32 end = 2;
  optional int32 pitch = 3;
  optional int32 tick_delta = 4;
}

message ValidTrack {
  repeated int32 tracks = 1;
}

message Item {
  optional uint64 start = 1;
  optional uint64 end = 2;
  optional uint64 src_size = 3;
}

message Dataset {
  repeated Item train = 1;
  repeated Item valid = 2;
  repeated Item test = 3;
}

message ModelMetadata {
  optional string encoder = 1;
  optional int32 num_layers = 2;
  optional int32 num_heads = 3;
  optional int32 num_hidden = 4;
  optional int32 model_dim = 5;
  optional bool new_state = 6;
}

message GenreData {
  optional string discogs = 1;
  optional string lastfm = 2;
  optional string tagtraum = 3;
}

message MetadataLabels {
  optional GenreMusicmap genre = 1;
  optional int32 nomml = 6;
}
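
Note: the minval/maxval extensions above are ordinary FieldOptions, so they can be read back through the protobuf descriptor API to validate values before building a Piece. A minimal sketch, assuming the generated modules are named midi_pb2 and midi_internal_pb2 (module names are assumptions):

import midi_pb2 as midi
import midi_internal_pb2 as midi_internal

# Look up the declared range of Event.velocity ([0, 127]) at runtime.
field = midi.Event.DESCRIPTOR.fields_by_name["velocity"]
opts = field.GetOptions()
lo = opts.Extensions[midi_internal.minval]
hi = opts.Extensions[midi_internal.maxval]
assert lo == 0 and hi == 127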
libraries/protobuf/src/track_type.proto
ADDED
@@ -0,0 +1,12 @@
syntax = "proto2";

package midi;

enum TRACK_TYPE {
  AUX_DRUM_TRACK = 8;
  AUX_INST_TRACK = 9;
  STANDARD_TRACK = 10;
  STANDARD_DRUM_TRACK = 11;
  STANDARD_BOTH = 12;
  NUM_TRACK_TYPES = 16;
}
libraries/pybind11
ADDED
@@ -0,0 +1 @@
Subproject commit 5ccb9e412d8974e8eeac1b061d9077ac0bd365e1
libraries/torch/CMakeLists.txt
ADDED
@@ -0,0 +1,26 @@
cmake_minimum_required(VERSION 3.8)

project(midigpt_torch)

set(SRCS
    src/torch_library.cpp
    "include/torch_library.h")

set(CMAKE_PREFIX_PATH "${CMAKE_CURRENT_SOURCE_DIR}/../libtorch/")
find_package(Torch REQUIRED)

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

add_library(midigpt_torch
    ${SRCS})

target_link_libraries(midigpt_torch PRIVATE "${TORCH_LIBRARIES}")

if (MSVC)
  # Copy the Torch DLLs next to the built library on Windows.
  file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
  add_custom_command(TARGET midigpt_torch
                     POST_BUILD
                     COMMAND ${CMAKE_COMMAND} -E copy_if_different
                     ${TORCH_DLLS}
                     $<TARGET_FILE_DIR:midigpt_torch>)
endif (MSVC)
libraries/torch/CMakeSettings.json
ADDED
@@ -0,0 +1,29 @@
{
  "configurations": [
    {
      "name": "x64-Debug",
      "generator": "Ninja",
      "configurationType": "Debug",
      "inheritEnvironments": [ "msvc_x64_x64" ],
      "buildRoot": "${projectDir}\\out\\build\\${name}",
      "installRoot": "${projectDir}\\out\\install\\${name}",
      "cmakeCommandArgs": "",
      "buildCommandArgs": "",
      "ctestCommandArgs": ""
    },
    {
      "name": "WSL-GCC-Debug",
      "generator": "Ninja",
      "configurationType": "Debug",
      "buildRoot": "${projectDir}\\out\\build\\${name}",
      "installRoot": "${projectDir}\\out\\install\\${name}",
      "cmakeExecutable": "cmake",
      "cmakeCommandArgs": "",
      "buildCommandArgs": "",
      "ctestCommandArgs": "",
      "inheritEnvironments": [ "linux_x64" ],
      "wslPath": "${defaultWSLPath}",
      "variables": []
    }
  ]
}
libraries/torch/include/torch_library.h
ADDED
@@ -0,0 +1,5 @@
#pragma once

namespace midigptTorch {
  void testMmmTorch();
}
libraries/torch/src/torch_library.cpp
ADDED
@@ -0,0 +1,10 @@
#include <iostream>
#include "../../libtorch/include/torch/csrc/api/include/torch/torch.h"

namespace midigptTorch {
  void testMmmTorch() {
    std::cout << "midigptTorch test function called" << std::endl;
    torch::Tensor tensor = torch::rand({ 2, 3 });
    std::cout << tensor << std::endl;
  }
}
midigpt_setup_helper.sh
ADDED
@@ -0,0 +1,182 @@
#!/bin/bash

clone=no
replace=no
test_train=no
inference=no
mac=no
python_path=/Library/Frameworks/Python.framework/Versions/3.8/bin
og=no

usage="
$(basename "$0") [-h] [-n] [-c] [-i] [-m] [-r] [-d directory]
-- Script for setting up and testing the MIDI-GPT repository

where:
    -h  Show this help text
    -n  Test the training script imports
    -c  Clone the github MIDI-GPT repository
    -i  Set up the repository for inference
    -d  Provide the directory name where the repo is/will be cloned
    -r  Replace the directory if it already exists
    -m  If on MacOS CPU
"


OPTSTRING=":cnhiomrd:k:"

while getopts ${OPTSTRING} opt; do
  case ${opt} in
    h)
      echo "${usage}"
      exit 0
      ;;
    i)
      inference=yes
      ;;
    n)
      test_train=yes
      ;;
    m)
      mac=yes
      ;;
    r)
      replace=yes
      ;;
    c)
      clone=yes
      ;;
    d)
      repo=${OPTARG}
      ;;
    :)
      echo "Option -${OPTARG} requires an argument"
      exit 1
      ;;
    ?)
      echo "Invalid option: -${OPTARG}"
      exit 1
      ;;
  esac
done

if test "${clone}" = "yes"
then
    echo "Cloning MIDI-GPT"
fi

echo "In directory: ${repo}"

if test "${replace}" = "yes"
then
    if [[ -d ${repo} ]]
    then
        echo "Directory ${repo} already exists, removing it"
        rm -rf ${repo}
    fi
fi

mkdir -p ${repo}
cd ${repo}

echo "Loading modules"

if test "${clone}" = "yes"
then
    if [[ -d MIDI-GPT ]] || [[ -d ENV ]]
    then
        echo "MIDI-GPT or ENV directories already exist"
        exit 1
    fi
    if test "${og}" = "yes"
    then
        {
            git clone https://www.github.com/Metacreation-Lab/MIDI-GPT.git
        } || {
            echo "Cloning failed"
            exit 1
        }
    else
        {
            git clone https://www.github.com/Metacreation-Lab/MIDI-GPT.git
        } || {
            echo "Cloning failed"
            exit 1
        }
    fi

    ${python_path}/virtualenv --no-download ./ENV

else
    if ! [[ -d MIDI-GPT ]]
    then
        echo "MIDI-GPT doesn't exist, try cloning the repository with the -c option"
        exit 1
    fi
fi

{
    source ./ENV/bin/activate
} || {
    echo "ENV virtual environment doesn't exist"
    exit 1
}

echo "pip installs"

pip install --no-index --upgrade pip
pip install torch==1.13.0
pip install transformers==4.26.1

cd MIDI-GPT
if test "${og}" = "no"
then
    git checkout main
fi

echo "Starting python library build"

{ if test "${inference}" = "yes"
then
    echo "Building for inference"
    if test "${mac}" = "yes"
    then
        echo "On MacOS CPU"
        bash create_python_library.sh --mac_os
    else
        echo "On Compute Canada"
        bash create_python_library.sh --compute_canada
    fi
else
    echo "Building for training only"
    bash create_python_library.sh --test_build --compute_canada --no_torch
fi
} || {
    echo "Build failed"
    exit 1
}

if test "${test_train}" = "yes"
then

    cd ../

    deactivate

    echo "Activating environment"

    source $PWD/venv/bin/activate
    cd $PWD/MIDI-GPT/python_scripts

    echo "Testing training script"

    python3 -c "import train"

    echo "Import tests done"

fi

echo "Finished"
models/model.zip
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2370946e0b45c440972ef3099be17f27cd7157b7a40ce9fd24851f2c855e4d85
size 77523617
pip_requirements/common_requirements.txt
ADDED
@@ -0,0 +1 @@
pretty_midi
pip_requirements/create_dataset_requirements.txt
ADDED
@@ -0,0 +1,4 @@
attrs==23.1.0+computecanada
jsonlines==3.1.0+computecanada
numpy==1.24.2+computecanada
tqdm==4.46.1
pip_requirements/inference_requirements.txt
ADDED
@@ -0,0 +1 @@
torch==2.0.0
pip_requirements/train_requirements.txt
ADDED
@@ -0,0 +1,135 @@
anyio==3.6.2+computecanada
arff==0.9+computecanada
argon2-cffi==21.3.0+computecanada
argon2-cffi-bindings==21.2.0+computecanada
asttokens==2.2.1+computecanada
async-generator==1.10+computecanada
attrs==23.1.0+computecanada
backcall==0.2.0+computecanada
backports-abc==0.5+computecanada
backports-shutil-get-terminal-size==1.0.0+computecanada
bcrypt==4.0.1+computecanada
beautifulsoup4==4.11.2+computecanada
bitstring==4.0.1+computecanada
bleach==6.0.0+computecanada
certifi==2022.12.7+computecanada
cffi==1.15.1+computecanada
chardet==5.1.0+computecanada
charset-normalizer==3.0.1+computecanada
comm==0.1.2+computecanada
contourpy==1.0.7+computecanada
cryptography==39.0.1+computecanada
cycler==0.11.0+computecanada
Cython==0.29.33+computecanada
deap==1.3.3+computecanada
debugpy==1.6.6+computecanada
decorator==5.1.1+computecanada
defusedxml==0.7.1+computecanada
dnspython==2.3.0+computecanada
ecdsa==0.18.0+computecanada
entrypoints==0.4+computecanada
executing==1.2.0+computecanada
fastjsonschema==2.16.2+computecanada
filelock==3.12.2
fonttools==4.38.0+computecanada
fsspec==2023.6.0
funcsigs==1.0.2+computecanada
huggingface-hub==0.15.1
idna==3.4+computecanada
importlib-metadata==5.2.0+computecanada
importlib-resources==5.12.0+computecanada
ipykernel==6.21.2+computecanada
ipython==8.10.0+computecanada
ipython-genutils==0.2.0+computecanada
ipywidgets==8.0.4+computecanada
jedi==0.18.2+computecanada
Jinja2==3.1.2+computecanada
jsonlines==3.1.0+computecanada
jsonschema==4.17.3+computecanada
jupyter-client==8.0.3+computecanada
jupyter-core==5.2.0+computecanada
jupyter-events==0.6.3+computecanada
jupyter-server==2.3.0+computecanada
jupyter-server-terminals==0.4.4+computecanada
jupyterlab-pygments==0.2.2+computecanada
jupyterlab-widgets==3.0.5+computecanada
kiwisolver==1.4.4+computecanada
lockfile==0.12.2+computecanada
MarkupSafe==2.1.2+computecanada
matplotlib==3.7.0+computecanada
matplotlib-inline==0.1.6+computecanada
mistune==2.0.5+computecanada
mock==5.0.1+computecanada
mpmath==1.2.1+computecanada
nbclassic==0.5.2+computecanada
nbclient==0.7.2+computecanada
nbconvert==7.2.9+computecanada
nbformat==5.7.3+computecanada
nest-asyncio==1.5.6+computecanada
netaddr==0.8.0+computecanada
netifaces==0.11.0+computecanada
nose==1.3.7+computecanada
notebook==6.5.2+computecanada
notebook-shim==0.2.2+computecanada
numpy==1.24.2+computecanada
packaging==23.0+computecanada
pandas==1.5.3+computecanada
pandocfilters==1.5.0+computecanada
paramiko==3.0.0+computecanada
parso==0.8.3+computecanada
path==16.6.0+computecanada
path.py==12.5.0+computecanada
pathlib2==2.3.7+computecanada
paycheck==1.0.2+computecanada
pbr==5.11.1+computecanada
pexpect==4.8.0+computecanada
pickleshare==0.7.5+computecanada
Pillow==9.4.0+computecanada
pkgutil-resolve-name==1.3.10+computecanada
platformdirs==2.5.2+computecanada
prometheus-client==0.16.0+computecanada
prompt-toolkit==3.0.37+computecanada
protobuf==4.23.3
psutil==5.9.4+computecanada
ptyprocess==0.7.0+computecanada
pure-eval==0.2.2+computecanada
pycparser==2.21+computecanada
Pygments==2.14.0+computecanada
PyNaCl==1.5.0+computecanada
pyparsing==3.0.9+computecanada
pyrsistent==0.19.3+computecanada
python-dateutil==2.8.2+computecanada
python-json-logger==2.0.7+computecanada
pytz==2022.7.1+computecanada
PyYAML==6.0+computecanada
pyzmq==25.0.0+computecanada
regex==2022.10.31+computecanada
requests==2.28.2+computecanada
rfc3339-validator==0.1.4+computecanada
rfc3986-validator==0.1.1+computecanada
scipy==1.10.1+computecanada
send2trash==1.8.0+computecanada
simplegeneric==0.8.1+computecanada
singledispatch==4.0.0+computecanada
six==1.16.0+computecanada
sniffio==1.3.0+computecanada
soupsieve==2.4+computecanada
stack-data==0.6.2+computecanada
sympy==1.11.1+computecanada
tensorboardX==2.2+computecanada
terminado==0.17.1+computecanada
testpath==0.6.0+computecanada
tinycss2==1.2.1+computecanada
tokenizers==0.13.2+computecanada
torch==1.13.1+computecanada
tornado==6.2+computecanada
tqdm==4.46.1
traitlets==5.9.0+computecanada
transformers==4.26.1+computecanada
typing-extensions==4.5.0+computecanada
urllib3==1.26.14+computecanada
wcwidth==0.2.6+computecanada
webencodings==0.5.1+computecanada
websocket-client==1.5.1+computecanada
widgetsnbextension==4.0.5+computecanada
zipp==3.14.0+computecanada
python_scripts/config/bert.json
ADDED
@@ -0,0 +1,7 @@
{
    "hidden_size" : 512,
    "num_hidden_layers" : 6,
    "num_attention_heads" : 8,
    "intermediate_size" : 2048,
    "max_position_embeddings" : 2048
}
python_scripts/config/bert_tiny.json
ADDED
@@ -0,0 +1,7 @@
{
    "hidden_size" : 128,
    "num_hidden_layers" : 2,
    "num_attention_heads" : 2,
    "intermediate_size" : 128,
    "max_position_embeddings" : 2048
}
python_scripts/config/gpt2.json
ADDED
@@ -0,0 +1,7 @@
{
    "n_positions": 2048,
    "n_ctx": 2048,
    "n_layer": 6,
    "n_head": 8,
    "n_embd": 512
}
python_scripts/config/gpt2_tiny.json
ADDED
@@ -0,0 +1,7 @@
{
    "n_positions": 2048,
    "n_ctx": 2048,
    "n_layer": 2,
    "n_head": 2,
    "n_embd": 128
}
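
Note: these JSON files are consumed through Hugging Face's from_json_file, exactly as convert.py and custom_models.py below do. A minimal sketch (run from python_scripts/ so the relative path resolves):

from transformers import GPT2Config

config = GPT2Config().from_json_file("config/gpt2_tiny.json")
print(config.n_layer, config.n_head, config.n_embd)  # 2 2 128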
python_scripts/convert.py
ADDED
@@ -0,0 +1,204 @@
# take a trained pytorch model and convert it
import os
import sys
sys.path.append(os.path.dirname(os.getcwd()) + "/python_lib")
print( os.path.dirname(os.getcwd()) + "/python_lib" )
import midigpt
import time
import json
import numpy as np
import torch
import torch.quantization
from transformers import GPT2LMHeadModel, GPT2Config
from transformers.modeling_utils import Conv1D

from custom_models import *

from torch import nn

class QuantWrapper(nn.Module):
    def __init__(self, module):
        super(QuantWrapper, self).__init__()
        qconfig = module.qconfig if hasattr(module, 'qconfig') else None
        self.add_module('quant', torch.quantization.QuantStub(qconfig))
        self.add_module('dequant', torch.quantization.DeQuantStub())
        self.add_module('module', module)
        self.train(module.training)

    def forward(self, X, P):
        X = self.quant(X)
        P = self.quant(P)
        O = self.module(X,P)
        return self.dequant(O)

def _conv1d_to_linear(module):
    in_size, out_size = module.weight.shape
    linear = torch.nn.Linear(in_size, out_size)
    linear.weight.data = module.weight.data.T.contiguous()
    linear.bias.data = module.bias.data
    return linear

def conv1d_to_linear(model):
    # replace every Conv1D with an equivalent nn.Linear so that dynamic
    # quantization and pruning (which target nn.Linear) can be applied
    for name in list(model._modules):
        module = model._modules[name]
        if isinstance(module, Conv1D):
            linear = _conv1d_to_linear(module)
            model._modules[name] = linear
        else:
            conv1d_to_linear(module)

def score_model(model):
    targets = np.load("target.npz")["data"]


def time_model(model):
    start = time.time()
    pkv = None
    for _ in range(1000):
        input_ids = torch.ones(1,1).type(torch.LongTensor)
        outputs = model(input_ids, past_key_values=pkv)
        pkv = outputs[1]
    print("BATCH TIME : {}".format(time.time() - start))

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

def quantize_model(model):
    conv1d_to_linear(model)
    model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8)
    return model

def static_quantize_model(model):
    conv1d_to_linear(model)
    model.qconfig = torch.quantization.default_qconfig
    torch.quantization.prepare(model, inplace=True)
    torch.quantization.convert(model, inplace=True)

    return model

def prune_model(model):
    import torch.nn.utils.prune as prune

    conv1d_to_linear(model)
    parameters_to_prune = []
    for _,module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name="weight", amount=.8)
            prune.remove(module, "weight")

        for _,submodule in module.named_modules():
            if isinstance(submodule, torch.nn.Linear):
                prune.l1_unstructured(submodule, name="weight", amount=.8)
                prune.remove(submodule, "weight")

    return model

def inject_metadata(path, metadata_path, encoder, new_state):
    model = torch.jit.load(path)
    with open(metadata_path, "r") as f:
        metadata = json.load(f)
    metadata["encoder"] = encoder
    metadata["new_state"] = new_state
    extra_files = torch._C.ExtraFilesMap()
    extra_files['metadata.json'] = json.dumps(metadata)
    out_path = os.path.splitext(path)[0] + "_WMETA.pt"
    torch.jit.save(model, out_path, _extra_files=extra_files)

def convert(model, path, quantize=False, prune=False, force=False, control=False, ckpt_path=None, encoderX=None):
    if not os.path.exists(path) or force:
        model.eval()
        if quantize:
            model = quantize_model(model)
        if prune:
            model = prune_model(model)
        print_size_of_model(model)
        example_input = torch.zeros(1,300).type(torch.LongTensor)
        example_control = torch.zeros(1,300,3).type(torch.FloatTensor)
        if control:
            outputs = model(input_ids=example_input, control_ids=example_control, past_key_values=None)
            print(len(outputs[1]))
            traced_script_module = torch.jit.trace(model, [example_input,example_control,outputs[1]])
        else:
            outputs = model(input_ids=example_input)
            traced_script_module = torch.jit.trace(model, [example_input, outputs[1]])

        num_layers = len(outputs[1])
        _,num_heads,_,num_hidden = outputs[1][0][0].detach().numpy().shape
        encoder = encoderX

        model_metadata = {
            "encoder" : encoder,
            "num_heads" : num_heads,
            "num_hidden" : num_hidden,
            "num_layers" : num_layers,
            "model_dim" : -1,
            "new_state" : True
        }

        print(model_metadata)

        extra_files = {}
        extra_files['metadata.json'] = json.dumps(model_metadata)
        torch.jit.save(
            traced_script_module, path, _extra_files=extra_files)


class GPT2LMHeadModelWMeta(GPT2LMHeadModel):
    def extra_repr(self):
        return "trent is the man"

if __name__ == "__main__":

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--output", type=str, default="")
    parser.add_argument("--metadata_path", type=str, default="")
    parser.add_argument("--config", type=str, default="")
    parser.add_argument("--encoder", type=str, default="NONE")
    parser.add_argument("--init", action="store_true")
    parser.add_argument("--inject", action="store_true")
    parser.add_argument("--new_state", action="store_true")
    parser.add_argument("--quantize", action="store_true")
    parser.add_argument("--prune", action="store_true")
    parser.add_argument("--control", action="store_true")

    args = parser.parse_args()


    if args.inject:
        assert len(args.metadata_path)
        inject_metadata(
            args.ckpt_path, args.metadata_path, args.encoder, True if args.new_state else False)

    else:
        assert len(args.output)
        if args.init:
            encoder_mode = midigpt.getEncoderType(args.encoder)
            assert encoder_mode is not midigpt.ENCODER_TYPE.NO_ENCODER
            encoder = midigpt.getEncoder(encoder_mode)
            vocab_size = encoder.vocab_size()
            if args.control:
                config = GPT2LMHeadModelContConfig().from_json_file(args.config)
                # encoder knows the size of the embedding
                config.n_control_dim = encoder.config.embed_dim
                model_cls = GPT2LMHeadModelCont

            else:
                config = GPT2Config().from_json_file(args.config)
                config.vocab_size = vocab_size
                model_cls = GPT2LMHeadModel

            model = model_cls(config)
        else:
            if args.control:
                model = GPT2LMHeadModelCont.from_pretrained(args.ckpt_path, torchscript=True)
            else:
                model = GPT2LMHeadModel.from_pretrained(args.ckpt_path, torchscript=True)

        convert(model, args.output, quantize=args.quantize, prune=args.prune, control=args.control, ckpt_path=args.ckpt_path, encoderX=args.encoder)
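
Note: a minimal sketch of driving the conversion above from Python rather than the CLI. The checkpoint path is a placeholder, and "TRACK_ENCODER" is simply the encoder name that create_dataset.py uses by default; importing convert also requires the built python_lib to be on the path.

from transformers import GPT2LMHeadModel
from convert import convert

# Trace a fine-tuned checkpoint to TorchScript with dynamic int8
# quantization, mirroring `--quantize` on the command line.
model = GPT2LMHeadModel.from_pretrained("path/to/ckpt", torchscript=True)
convert(model, "model_traced.pt", quantize=True, encoderX="TRACK_ENCODER")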
python_scripts/create_dataset.py
ADDED
@@ -0,0 +1,226 @@
import os
import sys
import glob
import json
import csv
import numpy as np
from tqdm import tqdm
from multiprocessing import Pool

from utils import *

sys.path.append(os.path.dirname(os.getcwd()) + "/python_lib")
import midigpt

def worker(args):
    path,sid,labels,nomml,tcjson,encoding = args
    tc = midigpt.TrainConfig()
    tc.from_json(tcjson)
    labels["nomml"] = nomml

    encoder_mode = midigpt.getEncoderType(encoding)
    assert encoder_mode is not midigpt.ENCODER_TYPE.NO_ENCODER
    encoder = midigpt.getEncoder(encoder_mode)

    try:
        return sid, midigpt.midi_to_json_bytes(path,tc,json.dumps(labels))
    except Exception as e:
        print(e)
        return None,None

def load_json(path):
    if not os.path.exists(path):
        return {}
    with open(path, "r") as f:
        return json.load(f)

DEFAULT_LABELS = {
    "genre": "GENRE_MUSICMAP_ANY",
    "valence_spotify": -1,
    "energy_spotify": -1,
    "danceability_spotify": -1,
    "tension": []
}

DATA_TYPES = [
    "Drum",
    "Drum+Music",
    "Music-No-Drum"
]

def load_metadata_labels(genre_data_path, spotify_data_path, tension_data_path):
    data = {}
    genre_data = load_json(genre_data_path)
    spotify_data = load_json(spotify_data_path)
    tension_data = load_json(tension_data_path)
    md5s = list(set(list(genre_data.keys()) + list(spotify_data.keys()) + list(tension_data.keys())))
    for md5 in md5s:
        data[md5] = {}
        if md5 in spotify_data:
            data[md5]["valence_spotify"] = np.mean(spotify_data[md5]["valence"])
            data[md5]["energy_spotify"] = np.mean(spotify_data[md5]["energy"])
            data[md5]["danceability_spotify"] = np.mean(spotify_data[md5]["danceability"])
        else:
            for k,v in DEFAULT_LABELS.items():
                data[md5][k] = v
        data[md5]["genre"] = genre_data.get(md5, DEFAULT_LABELS["genre"])
        data[md5]["tension"] = tension_data.get(md5, DEFAULT_LABELS["tension"])
    return data

if __name__ == "__main__":

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--output", type=str, required=True)
    parser.add_argument("--num_bars", type=int, default=4)
    parser.add_argument("--expressive", action="store_true")
    parser.add_argument("--ignore_score", type=bool, default=0)
    parser.add_argument("--nthreads", type=int, default=8)
    parser.add_argument("--max_size", type=int, default=-1)
    parser.add_argument("--genre_data", type=str, default="")
    parser.add_argument("--spotify_data", type=str, default="")
    parser.add_argument("--tension_data", type=str, default="")
    parser.add_argument("--encoding", type=str, default="TRACK_ENCODER")
    parser.add_argument("--resolution", type=int, default=12)
    parser.add_argument("--delta_resolution", type=int, default=1920)
    parser.add_argument("--metadata", type=str, required=True)
    parser.add_argument("--type", type=str, default="Drum+Music")
    parser.add_argument("--test", type=str, default="no")
    args = parser.parse_args()

    args.ignore_score = bool(args.ignore_score)
    if args.test != "no":
        test_script = True
    else:
        test_script = False

    assert args.type in DATA_TYPES
    args.type = "-" + args.type + "-"

    os.system("taskset -p 0xffff %d" % os.getpid())

    # multi thread approach takes about 2 minutes
    pool = Pool(args.nthreads)
    output = os.path.splitext(args.output)[0]
    ss=""
    if args.max_size > 0:
        ss=f"_MAX_{args.max_size}"
    if args.expressive:
        output += "/{}_NUM_BARS={}_RESOLUTION_{}_DELTA_{}{}.arr".format(args.encoding,args.num_bars,args.resolution, args.delta_resolution,ss)
    else:
        output += "/{}_NUM_BARS={}_RESOLUTION_{}{}.arr".format(args.encoding,args.num_bars,args.resolution,ss)
    print(output)
    if not test_script:
        jag = midigpt.BytesToFile(output)


    paths = list(glob.glob(args.data_dir + "/**/*.mid", recursive=True))

    import random
    import time
    random.seed(int(time.time()))

    tc = midigpt.TrainConfig()
    tc.num_bars = args.num_bars
    tc.use_microtiming = args.expressive
    tc.resolution = args.resolution
    tc.delta_resolution = args.delta_resolution
    tc = tc.to_json()
    print(tc)

    paths_exp = []
    sids_exp = []
    paths_non_exp = []
    sids_non_exp = []
    paths_all = []
    sids_all = []
    nomml_alls = []
    nomml_scores = []

    try:
        with open(args.metadata) as meta:
            reader = csv.DictReader(meta, delimiter=',')
            for row in tqdm(reader):
                path = row["filepath"]
                nomml = int(row["medianMetricDepth"])
                if (".mid" in path and args.type in path):
                    if "-Train-" in path:
                        group = 0
                    elif "-Val-" in path:
                        group = 1
                    elif "-Test-" in path:
                        group = 2
                    else:
                        raise RuntimeError("data format incorrect")
                    if (nomml < 12):
                        paths_non_exp.append(os.path.join(args.data_dir,path))
                        sids_non_exp.append(group)
                        nomml_scores.append(nomml)
                    else:
                        paths_exp.append(os.path.join(args.data_dir,path))
                        sids_exp.append(group)
                    paths_all.append(os.path.join(args.data_dir,path))
                    sids_all.append(group)
                    nomml_alls.append(nomml)

    except:
        paths_all = list(glob.glob(args.data_dir + "/**/*.mid", recursive=True))
        for path in paths_all:
            if "-train-" in path:
                sids_all.append(0)
            elif "-valid-" in path:
                sids_all.append(1)
            elif "-test-" in path:
                sids_all.append(2)
            else:
                raise RuntimeError("data format incorrect")

    nomml_vals = []
    if args.expressive:
        if args.ignore_score:
            paths = paths_exp
            sids = sids_exp
            nomml_vals = [12 for _ in sids]
        else:
            paths = paths_all
            sids = sids_all
            nomml_vals = nomml_alls
    else:
        paths = paths_all
        sids = sids_all
        nomml_vals = nomml_alls

    metadata_label_data = load_metadata_labels(args.genre_data, args.spotify_data, args.tension_data)
    metadata_labels = [metadata_label_data.get(os.path.splitext(os.path.basename(p))[0],DEFAULT_LABELS) for p in paths]
    print("LOADED {} METADATA LABELS".format(len(metadata_labels)))

    tcs = [tc for _ in paths]
    encoding = [args.encoding for _ in paths]
    inputs = list(zip(paths,sids,metadata_labels,nomml_vals,tcs,encoding))
    random.shuffle(inputs)

    for k,v in DEFAULT_LABELS.items():
        print("{} FILES HAVE {} METADATA".format(sum([m[k] != v for m in metadata_labels]),k))

    if args.max_size > 0:
        inputs = inputs[:args.max_size]

    if not test_script:
        total_count = 0
        success_count = 0
        pool = Pool(args.nthreads)
        progress_bar = tqdm(pool.imap_unordered(worker, inputs), total=len(inputs))
        for sid,b in progress_bar:
            if b is not None and len(b):
                jag.append_bytes_to_file_stream(b,sid)
                success_count += 1
            total_count += 1
            status_str = "{}/{}".format(success_count,total_count)
            progress_bar.set_description(status_str)
        jag.close()
    else:
        print("Test successful")
        sys.exit(0)
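
Note: a minimal sketch of the per-file conversion that worker() performs, for a single MIDI file. The path is a placeholder and the labels reuse DEFAULT_LABELS; the midigpt python_lib must already be built.

import json
import midigpt
from create_dataset import DEFAULT_LABELS

tc = midigpt.TrainConfig()
tc.num_bars = 4
tc.resolution = 12

labels = dict(DEFAULT_LABELS)
labels["nomml"] = 12  # median metric depth, as worker() attaches it

# Returns the compressed byte representation that gets appended to the .arr dataset.
b = midigpt.midi_to_json_bytes("example.mid", tc, json.dumps(labels))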
python_scripts/custom_models.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import *
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch
|
4 |
+
|
5 |
+
class GPT2Encoder(GPT2PreTrainedModel):
|
6 |
+
def __init__(self, config):
|
7 |
+
super().__init__(config)
|
8 |
+
self.transformer = GPT2Model(config)
|
9 |
+
self.score = nn.Linear(config.n_embd, 128, bias=False)
|
10 |
+
self.init_weights()
|
11 |
+
self.model_parallel = False
|
12 |
+
self.device_map = None
|
13 |
+
|
14 |
+
negative_importance = torch.tensor(5.33).float()
|
15 |
+
+        negative_threshold = torch.tensor(4.).float()
+        entropy_importance = torch.tensor(0.05).float()
+        self.register_buffer('negative_importance', negative_importance)
+        self.register_buffer('negative_threshold', negative_threshold)
+        self.register_buffer('entropy_importance', entropy_importance)
+
+    def forward(
+        self,
+        input_ids=None,
+        past_key_values=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        sequence_lengths=None,
+    ):
+        assert sequence_lengths is not None
+        outputs = self.transformer(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = outputs[0]
+        logits = self.score(hidden_states)
+        # take the score at the last real token of each sequence
+        return logits[range(input_ids.shape[0]), sequence_lengths - 1]
+
+class GPT2LMHeadModelContConfig(GPT2Config):
+    def __init__(
+        self,
+        n_control_embd=64,
+        n_control_dim=3,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.n_control_embd = n_control_embd
+        self.n_control_dim = n_control_dim
+
+class GPT2LMHeadModelCont(GPT2LMHeadModel):
+    def __init__(self, config):
+        super().__init__(config)
+        # split the embedding width between token and control embeddings
+        token_embd = config.n_embd - config.n_control_embd
+        self.wte = nn.Embedding(config.vocab_size, token_embd)
+        self.ctrle = nn.Linear(config.n_control_dim, config.n_control_embd)
+
+    def forward(
+        self,
+        input_ids=None,
+        control_ids=None,
+        past_key_values=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        sequence_lengths=None
+    ):
+        shape = control_ids.shape
+        input_shape = (shape[0] * shape[1], shape[2])
+        output_shape = (shape[0], shape[1], self.config.n_control_embd)
+        control_embd = self.ctrle(torch.reshape(control_ids, input_shape))
+        control_embd = torch.reshape(control_embd, output_shape)
+        token_embd = self.wte(input_ids)
+        # concatenate token and control embeddings along the feature axis
+        inputs_embeds = torch.cat([token_embd, control_embd], axis=-1)
+        return super().forward(
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            labels=labels
+        )
+
+if __name__ == "__main__":
+
+    batch_size = 3
+
+    config = GPT2Config().from_json_file("config/gpt2_tiny.json")
+
+    #model = GPT2LMHeadModelCont(config)
+    model = GPT2Encoder(config)
+
+    kwargs = {
+        "input_ids" : torch.randint(config.vocab_size, size=(batch_size,100)),
+        "labels" : torch.randint(config.vocab_size, size=(batch_size,100)),
+        #"control_ids" : torch.randint(CONTROL_VOCAB_SIZE, size=(1,100))
+        # tensor rather than list so that sequence_lengths - 1 works in forward
+        "sequence_lengths" : torch.tensor([99, 90, 90])
+    }
+
+    out = model.forward(**kwargs)
+    print(out.shape)
python_scripts/data_split.py
ADDED
@@ -0,0 +1,54 @@
+from os import listdir
+from os.path import isfile, join
+from shutil import move
+import random
+import sys
+from tqdm import tqdm
+
+if __name__ == "__main__":
+
+    root = sys.argv[1]
+    new_root = sys.argv[2]
+
+    # Get all files in the dataset root
+    print("Getting MIDI files")
+    onlyfiles = [f for f in listdir(root)]
+    n = len(onlyfiles)
+    print("Num files: " + str(n))
+
+    # Generate random test/train/valid indices
+
+    idx = [i for i in range(n)]
+    random.shuffle(idx)  # shuffles in place; the return value is None
+
+    train_len = int(0.8 * n)
+    test_len = int(0.1 * n)
+    valid_len = n - train_len - test_len
+
+    train_idx = idx[:train_len]
+    test_idx = idx[train_len:test_len + train_len]
+    valid_idx = idx[test_len + train_len:]
+
+    # Move files to their respective folders
+
+    o = 0
+
+    print('Splitting Train Set')
+    for i in tqdm(range(train_len)):
+        move(join(root, onlyfiles[train_idx[i]]), join(new_root, "train", onlyfiles[train_idx[i]]))
+        o += 1
+
+    print('Splitting Test Set')
+    for i in tqdm(range(test_len)):
+        move(join(root, onlyfiles[test_idx[i]]), join(new_root, "test", onlyfiles[test_idx[i]]))
+        o += 1
+
+    print('Splitting Validation Set')
+    for i in tqdm(range(valid_len)):
+        move(join(root, onlyfiles[valid_idx[i]]), join(new_root, "valid", onlyfiles[valid_idx[i]]))
+        o += 1
+
+    print("Success Test: " + str(o == n))
+
+    print(join(new_root, "valid", onlyfiles[valid_idx[100]]))
+    print(isfile(join(new_root, "valid", onlyfiles[valid_idx[100]])))
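
The split above is re-randomized on every run. A seeded sketch (not part of the
commit; the file list is hypothetical) makes the assignment reproducible:

    import random

    files = ["a.mid", "b.mid", "c.mid", "d.mid", "e.mid"]   # hypothetical file list
    random.seed(42)                 # fixed seed => identical split on every run
    idx = list(range(len(files)))
    random.shuffle(idx)             # shuffles in place; returns None
    train_len = int(0.8 * len(files))
    print([files[i] for i in idx[:train_len]])   # training subset
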
python_scripts/losses.py
ADDED
@@ -0,0 +1,37 @@
+import torch
+
+def standard_loss(self, model, inputs, return_outputs=False):
+    outputs = model(**inputs)
+    if self.args.past_index >= 0:
+        self._past = outputs[self.args.past_index]
+    loss = outputs[0].mean()
+    print("loss : ", loss)
+    return (loss, outputs) if return_outputs else loss
+
+def hinge_cost(m, a, b):
+    dist = m - torch.sqrt(torch.sum((a - b)**2, axis=1))
+    return torch.mean(torch.clamp(dist, 0, float('inf'))**2)
+
+def sim_metric_loss(self, model, inputs, return_outputs=False):
+
+    # single pass version: the batch is stacked as [x_p; x_n; y_p; y_n]
+    batch_size = len(inputs["input_ids"]) // 4
+    outputs = model(**inputs)
+    x_p = outputs[:batch_size]
+    x_n = outputs[batch_size:2*batch_size]
+    y_p = outputs[2*batch_size:3*batch_size]
+    y_n = outputs[3*batch_size:]
+
+    model_attr = model
+    if isinstance(model, torch.nn.DataParallel):
+        model_attr = model.module
+
+    cost_p = torch.mean(torch.sum((x_p - y_p)**2, axis=1))
+    cost_n = model_attr.negative_importance * hinge_cost(
+        model_attr.negative_threshold, x_n, y_n)
+    cost_e = model_attr.entropy_importance * torch.mean(
+        torch.sum(x_p**2, axis=1) + torch.sum(y_p**2, axis=1))
+    loss = cost_p + cost_n + cost_e
+
+    print(loss)
+    return (loss, None) if return_outputs else loss
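
A small numeric check of hinge_cost (not part of the commit): negative pairs closer
than the margin m are penalized quadratically, pairs farther than m contribute zero:

    import torch
    from losses import hinge_cost

    m = torch.tensor(4.0)
    a = torch.zeros(2, 8)                    # two 8-d embeddings at the origin
    b = torch.zeros(2, 8)
    b[1, 0] = 5.0                            # pair 0 at distance 0, pair 1 at distance 5
    print(hinge_cost(m, a, b))               # mean of [(4-0)^2, 0] = 8.0
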
python_scripts/train.py
ADDED
@@ -0,0 +1,224 @@
+import sys
+import os
+sys.path.append(os.path.dirname(os.getcwd()) + "/python_lib")
+import midigpt
+
+from transformers import *
+
+import json
+import time
+import torch
+
+import datetime
+import numpy as np
+from tqdm import tqdm
+
+from subprocess import check_output
+
+from losses import sim_metric_loss, standard_loss
+from custom_models import *
+from train_dataset import *
+from transformers import Trainer, TrainingArguments
+from callbacks import MemoryUsageCallback, ProfilerCallback
+
+if __name__ == "__main__":
+
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--arch", type=str, required=True)
+    parser.add_argument("--config", type=str, required=True)
+    parser.add_argument("--encoding", type=str, required=True)
+    parser.add_argument("--dataset", type=str, required=True)
+    parser.add_argument("--pad_value", type=int, default=-100)
+
+    parser.add_argument("--expressive", action="store_true")
+    parser.add_argument("--num_bars", type=int, default=4)
+    parser.add_argument("--min_tracks", type=int, default=2)
+    parser.add_argument("--max_tracks", type=int, default=12)
+    parser.add_argument("--max_seq_len", type=int, default=2048)
+    parser.add_argument("--no_max_length", type=int, default=0)
+    parser.add_argument("--resolution", type=int, default=12)
+    parser.add_argument("--delta_resolution", type=int, default=1920)
+    parser.add_argument("--abs_pos_vocab_size", type=int, default=196)
+    parser.add_argument("--delta_vocab_size", type=int, default=96)
+
+    parser.add_argument("--ngpu", type=int, default=4)
+    parser.add_argument("--accum_steps", type=int, default=1)
+    parser.add_argument("--batch_size", type=int, default=32)
+    parser.add_argument("--batches_per_epoch", type=int, default=1000)
+    parser.add_argument("--lr", type=float, default=1e-4)
+
+    parser.add_argument("--overwrite", type=int, default=1)
+    parser.add_argument("--save_steps", type=int, default=5000)
+    parser.add_argument("--log_steps", type=int, default=100)
+    parser.add_argument("--step", type=int, default=0)
+    parser.add_argument("--label", type=str, default="version3")
+    parser.add_argument("--profiler_steps", type=int, default=50)
+
+    parser.add_argument("--dry", action="store_true")
+    parser.add_argument("--metric", action="store_true")
+
+    parser.add_argument("--ckpt", type=str, default="")
+    parser.add_argument("--ckpt_num", type=int, default=5000)
+    parser.add_argument("--output", type=str, default="")
+    parser.add_argument("--log", type=str, default="")
+
+    parser.add_argument("--test_only", action="store_true")
+    parser.add_argument("--memory_metrics", action="store_true")
+
+    args = parser.parse_args()
+    args.expressive = (args.encoding == "EXPRESSIVE_ENCODER") and args.expressive
+
+    dataset_cls = CustomDataset
+    loss_fn = standard_loss
+
+    np.random.seed(int(time.time()))
+
+    # determine vocab size
+    date_str = datetime.datetime.now().strftime('%b_%d_%H_%M')
+    encoder_mode = midigpt.getEncoderType(args.encoding)
+    assert encoder_mode is not midigpt.ENCODER_TYPE.NO_ENCODER
+    encoder = midigpt.getEncoder(encoder_mode)
+    if args.expressive:
+        encoder.set_scheme(args.resolution, args.delta_resolution, args.delta_vocab_size, args.abs_pos_vocab_size)
+    vocab_size = encoder.vocab_size()
+
+    current_git_commit_hash = check_output(["git", "rev-parse", "HEAD"], text=True).strip()
+
+    load_checkpoint = False
+    if args.ckpt == "":
+        name = "_".join([args.encoding, args.arch, args.label, date_str, "num_bars", str(args.num_bars), str(args.max_tracks), "GIT_HASH", current_git_commit_hash])
+    else:
+        name = args.ckpt
+        load_checkpoint = True
+
+    if args.dry:
+        # inspect batches forever instead of training
+        while True:
+            dataset = dataset_cls(split_id=0, is_training=True, **vars(args))
+            for batch in tqdm(dataset, smoothing=0):
+                np_inputs = batch["input_ids"].detach().numpy()
+                print([encoder.pretty(t) for t in np_inputs[0][:100]])
+                print({k: v.shape for k, v in batch.items()})
+
+    if os.getenv("SLURM_TMPDIR") is not None:
+        # we are on compute canada and should attempt to copy
+        # the dataset to tmpdir for faster access
+        from shutil import copyfile
+        tmpdir = os.getenv("SLURM_TMPDIR")
+        dataset_path = os.path.join(tmpdir, os.path.basename(args.dataset))
+        if not os.path.exists(dataset_path):
+            copyfile(args.dataset, dataset_path)
+            copyfile(args.dataset + ".header", dataset_path + ".header")
+        args.dataset = dataset_path
+
+    # setup datasets
+    train_dataset = dataset_cls(split_id=0, is_training=True, **vars(args))
+    eval_dataset = dataset_cls(split_id=2, is_training=False, overload_batches_per_epoch=1, **vars(args))
+    Trainer.get_train_dataloader = lambda *_args, **_kwargs: train_dataset
+    Trainer.get_eval_dataloader = lambda *_args, **_kwargs: eval_dataset
+    Trainer.compute_loss = loss_fn
+
+    print("MODEL NAME : " + name)
+    print("VOCAB SIZE : " + str(vocab_size))
+    print("ARGS : " + json.dumps(vars(args), indent=4))
+    print("MODEL CONFIG : " + json.dumps(json.load(open(args.config, "r")), indent=4))
+    print("ENCODER CONFIG : " + json.dumps(encoder.config.ToJson(), indent=4))
+
+    logging_dir = os.path.join(args.log, "{}".format(name))
+    output_dir = os.path.join(args.output, "checkpoints/{}".format(name))
+
+    print("LOGGING PATH : " + logging_dir)
+    print("OUTPUT PATH : " + output_dir)
+
+    os.makedirs(logging_dir, exist_ok=True)
+    os.makedirs(output_dir, exist_ok=True)
+
+    # =================================================================
+    # model selection
+
+    if args.arch == "gpt2":
+        config = GPT2Config().from_json_file(args.config)
+        model_cls = GPT2LMHeadModel
+    elif args.arch == "xl":
+        config = TransfoXLConfig().from_json_file(args.config)
+        model_cls = TransfoXLLMHeadModel
+    elif args.arch == "metric":
+        config = GPT2Config().from_json_file(args.config)
+        model_cls = GPT2Encoder
+    elif args.arch == "control":
+        config = GPT2LMHeadModelContConfig().from_json_file(args.config)
+        # the encoder knows the size of the embedding
+        config.n_control_dim = encoder.config.embed_dim
+        model_cls = GPT2LMHeadModelCont
+    elif args.arch == "bert":
+        config = BertConfig().from_json_file(args.config)
+        model_cls = BertForMaskedLM
+    else:
+        raise NotImplementedError
+
+    config.vocab_size = vocab_size
+    print("MODEL CONFIG : " + str(config))
+
+    if len(args.ckpt.strip()) == 0:
+        print('Model initialization')
+        ckpt_path = None
+        model = model_cls(config)
+    else:
+        try:
+            print('Trying to load checkpoint')
+            ckpt_path = os.path.join(output_dir, f"checkpoint-{args.ckpt_num}")
+            model = model_cls.from_pretrained(ckpt_path)
+        except Exception as e:
+            print(e)
+            print('Returning to default model initialization')
+            ckpt_path = None  # don't resume from a checkpoint that failed to load
+            model = model_cls(config)
+
+    # =================================================================
+    # training
+
+    training_args = TrainingArguments(
+        logging_dir=logging_dir,
+        report_to="tensorboard",
+        output_dir=output_dir,
+        overwrite_output_dir=bool(args.overwrite),
+        num_train_epochs=(500000 / args.batches_per_epoch) * args.accum_steps,
+        logging_steps=args.log_steps,
+        save_steps=args.save_steps,
+        save_total_limit=None,
+        learning_rate=args.lr,
+        gradient_accumulation_steps=args.accum_steps,
+        per_device_train_batch_size=args.batch_size // args.ngpu // args.accum_steps,
+        per_device_eval_batch_size=args.batch_size // args.ngpu // args.accum_steps,
+        evaluation_strategy="epoch",
+        prediction_loss_only=True,
+        skip_memory_metrics=True
+    )
+
+    # The custom memory-metrics callbacks do not work reliably and slow
+    # training down by roughly 100x; only enable them when asked to.
+    if args.memory_metrics:
+        callbacks = [MemoryUsageCallback, ProfilerCallback]
+    else:
+        callbacks = []
+
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        data_collator=None,
+        train_dataset=None,
+        eval_dataset=None,
+        callbacks=callbacks
+    )
+
+    trainer.train_dataset = train_dataset
+    trainer.eval_dataset = eval_dataset
+
+    if not args.test_only:
+        trainer.train(ckpt_path)
+    else:
+        model = trainer._wrap_model(trainer.model)
+        model.save_pretrained(output_dir)
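
For reference, one possible invocation covering the required flags (the dataset
path is a placeholder, not part of the commit):

    python train.py --arch gpt2 --config config/gpt2_tiny.json \
        --encoding EXPRESSIVE_ENCODER --dataset /path/to/dataset \
        --ngpu 1 --batch_size 32 --label version3
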
python_scripts/train_dataset.py
ADDED
@@ -0,0 +1,96 @@
+#from transformers import Trainer, TrainingArguments
+
+import os
+import json
+import time
+import torch
+import tqdm
+#from torch.utils.data import Dataset
+
+import datetime
+import numpy as np
+
+import sys
+sys.path.append(os.path.dirname(os.getcwd()) + "/python_lib")
+import midigpt
+
+class CustomDataset:
+    def __init__(self, split_id=0, is_training=True, batch_size=32, dataset=None,
+                 num_bars=4, min_tracks=2, max_tracks=12, max_seq_len=2048,
+                 expressive=False, no_max_length=False, resolution=12, encoding=None,
+                 pad_value=-100, arch="gpt2", accum_steps=1, batches_per_epoch=1000,
+                 overload_batches_per_epoch=None, **kwargs):
+        # settings
+        self.is_training = is_training
+        self.batch_size = batch_size // accum_steps
+        self.split_id = split_id
+        self.max_seq_len = max_seq_len
+        self.batches_per_epoch = batches_per_epoch if overload_batches_per_epoch is None else overload_batches_per_epoch
+        self.dataset = list(range(self.batches_per_epoch)) # number of examples ??
+        self.pad_value = pad_value
+        self.arch = arch
+
+        # create dataloader
+        self.dataloader = midigpt.Jagged(dataset)
+        self.dataloader.set_num_bars(num_bars)
+        self.dataloader.set_min_tracks(min_tracks)
+        self.dataloader.set_max_tracks(max_tracks)
+        self.dataloader.set_max_seq_len(max_seq_len)
+        seed = np.random.randint(2**20)
+        self.dataloader.set_seed(seed)
+        self.encoder_mode = midigpt.getEncoderType(encoding)
+
+        # create train_config
+        self.tc = midigpt.TrainConfig()
+        self.tc.num_bars = num_bars
+        self.tc.min_tracks = min_tracks
+        self.tc.max_tracks = max_tracks
+        self.tc.use_microtiming = expressive
+        self.tc.no_max_length = no_max_length
+        self.tc.resolution = resolution
+
+        self.current = 0
+
+    def _get_batch(self):
+        input_ids, mask = self.dataloader.read_batch_v2(
+            self.batch_size, self.split_id, self.encoder_mode, self.tc)
+        input_ids = np.array(input_ids)
+        mask = np.array(mask)
+        labels = np.copy(input_ids)
+        # set masked tokens to pad_value (assumes padded input positions are 0)
+        labels += (1 - mask) * self.pad_value
+        batch = {
+            "input_ids" : torch.from_numpy(input_ids),
+            "attention_mask" : torch.from_numpy(mask),
+            "labels" : torch.from_numpy(labels)
+        }
+        if self.arch == "xl":
+            batch.pop("attention_mask")
+            assert np.all(np.sum(mask, axis=1) == self.max_seq_len)
+        if self.arch == "bert":
+            batch.pop("labels")
+        return batch
+
+    def _get_batch_test(self):
+        inputs = torch.ones((32, 800), dtype=torch.int64)
+        return {
+            "input_ids" : inputs,
+            "labels" : inputs
+        }
+
+    def __iter__(self):
+        self.current = 0
+        return self
+
+    def __next__(self):
+        self.current += 1
+        if self.current <= self.batches_per_epoch:
+            while True:
+                try:
+                    return self._get_batch()
+                except Exception as e:
+                    print("ERROR IN BATCHER : ", e)
+        raise StopIteration
+
+    def __len__(self):
+        return self.batches_per_epoch
+
+def pad(seqs, pad_value):
+    seqlens = np.array([len(seq) for seq in seqs])
+    maxlen = np.max(seqlens)
+    return np.array([np.pad(seq, (0, maxlen - len(seq)), mode="constant", constant_values=pad_value) for seq in seqs]), seqlens
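
A quick sketch of the pad() helper above (not part of the commit): it right-pads
ragged sequences to the longest one and also returns the original lengths:

    from train_dataset import pad

    padded, seqlens = pad([[1, 2, 3], [4]], pad_value=-100)
    print(padded)     # [[1 2 3] [4 -100 -100]]
    print(seqlens)    # [3 1]
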
python_scripts/utils.py
ADDED
@@ -0,0 +1,41 @@
+import json
+import jsonlines
+from multiprocessing import Pool
+from tqdm import tqdm
+
+def load_json(path):
+    with open(path, "r") as f:
+        return json.load(f)
+
+def dump_json(x, path):
+    with open(path, "w") as f:
+        json.dump(x, f, indent=4)
+
+def apply_func(x, func):
+    # func must return an (input, output) pair for the unpacking below
+    with Pool(8) as pool:
+        for inval, outval in tqdm(pool.imap_unordered(func, x), total=len(x)):
+            yield inval, outval
+
+def load_jsonl(path, max_items=None):
+    with jsonlines.open(path) as reader:
+        for ii, item in enumerate(reader):
+            if max_items is not None and ii >= max_items:
+                break
+            yield item
+
+def dump_jsonl(data, path):
+    assert isinstance(data, list)
+    with jsonlines.open(path, mode="w") as wr:
+        for item in tqdm(data, leave=False):
+            wr.write(item)
+
+class dump_jsonl_multistage:
+    def __init__(self, path, mode="a"):
+        self.wr = jsonlines.open(path, mode=mode, flush=True)
+    def add(self, item):
+        self.wr.write(item)
+    def extend(self, items):
+        for item in items:
+            self.add(item)
+    def close(self):
+        self.wr.close()
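
A usage sketch for dump_jsonl_multistage (not part of the commit; the output path
is a placeholder): it appends and flushes one record at a time, so partial results
survive a crash:

    from utils import dump_jsonl_multistage

    wr = dump_jsonl_multistage("/tmp/results.jsonl")   # hypothetical path
    wr.add({"id": 0, "score": 0.5})
    wr.extend([{"id": 1, "score": 0.7}, {"id": 2, "score": 0.9}])
    wr.close()
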
python_scripts_for_testing/midigpt_gen.mid
ADDED
Binary file (591 Bytes).
python_scripts_for_testing/mtest.mid
ADDED
Binary file (455 Bytes).
python_scripts_for_testing/pythoninferencetest.py
ADDED
@@ -0,0 +1,67 @@
+import sys, os
+sys.path.append(os.path.dirname(os.getcwd()) + "/python_lib")
+import midigpt
+import json
+import random
+
+if __name__ == "__main__":
+
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--midi", type=str, required=True)
+    parser.add_argument("--ckpt", type=str, required=True)
+    parser.add_argument("--out", type=str, default='')
+    args = parser.parse_args()
+
+    ckpt = args.ckpt
+    midi_input = args.midi
+    if args.out != '':
+        midi_dest = args.out
+    else:
+        midi_dest = os.path.join(os.path.split(midi_input)[0], 'midigpt_gen.mid')
+    e = midigpt.ExpressiveEncoder()
+    midi_json_input = json.loads(e.midi_to_json(midi_input))
+    valid_status = {
+        'tracks': [
+            {
+                'track_id': 0,
+                'temperature': 0.5,
+                'instrument': 'acoustic_grand_piano',
+                'density': 10,
+                'track_type': 10,
+                'ignore': False,
+                'selected_bars': [False, False, True, False],
+                'min_polyphony_q': 'POLYPHONY_ANY',
+                'max_polyphony_q': 'POLYPHONY_ANY',
+                'autoregressive': False,
+                'polyphony_hard_limit': 9
+            }
+        ]
+    }
+    parami = {
+        'tracks_per_step': 1,
+        'bars_per_step': 1,
+        'model_dim': 4,
+        'percentage': 100,
+        'batch_size': 1,
+        'temperature': 1.0,
+        'max_steps': 200,
+        'polyphony_hard_limit': 6,
+        'shuffle': True,
+        'verbose': True,
+        'ckpt': ckpt,
+        'sampling_seed': -1,
+        'mask_top_k': 0
+    }
+
+    piece = json.dumps(midi_json_input)
+    status = json.dumps(valid_status)
+    param = json.dumps(parami)
+    callbacks = midigpt.CallbackManager()
+    max_attempts = 3
+    midi_str = midigpt.sample_multi_step(piece, status, param, max_attempts, callbacks)
+    midi_str = midi_str[0]
+    midi_json = json.loads(midi_str)
+
+    e = midigpt.ExpressiveEncoder()
+    e.json_to_midi(midi_str, midi_dest)
setup.py
ADDED
@@ -0,0 +1,67 @@
+# CODE IN PROGRESS, NOT WORKING YET
+
+from setuptools import setup, Extension
+from setuptools.command.build_ext import build_ext
+import sys
+import setuptools
+
+__version__ = '0.0.1'
+
+class get_pybind_include(object):
+    """Helper class to determine the pybind11 include path"""
+
+    def __init__(self, user=False):
+        self.user = user
+
+    def __str__(self):
+        import pybind11
+        return pybind11.get_include(self.user)
+
+ext_modules = [
+    Extension(
+        'midigpt',
+        ['src/lib.cpp', 'src/common/data_structures/train_config.cpp',
+         'src/dataset_creation/compression/lz4.c',
+         'src/dataset_creation/dataset_manipulation/bytes_to_file.cpp'],
+        include_dirs=[
+            get_pybind_include(),
+            get_pybind_include(user=True)
+        ],
+        language='c++'
+    ),
+]
+
+class BuildExt(build_ext):
+    """A custom build extension for adding compiler-specific options."""
+    c_opts = {
+        'msvc': ['/EHsc'],
+        'unix': [],
+    }
+
+    if sys.platform == 'darwin':
+        c_opts['unix'] += ['-stdlib=libc++', '-mmacosx-version-min=10.7']
+
+    def build_extensions(self):
+        ct = self.compiler.compiler_type
+        opts = self.c_opts.get(ct, [])
+        if ct == 'unix':
+            opts.append('-DVERSION_INFO="%s"' % self.distribution.get_version())
+            opts.append('-std=c++20')
+        elif ct == 'msvc':
+            opts.append('/DVERSION_INFO=\\"%s\\"' % self.distribution.get_version())
+        for ext in self.extensions:
+            ext.extra_compile_args = opts
+        build_ext.build_extensions(self)
+
+setup(
+    name='midigpt',
+    version=__version__,
+    author='Jeff Ens, Rafael Arias',
+    author_email='[email protected]',
+    url='',
+    description='A Python wrapper for the midigpt project',
+    long_description='',
+    ext_modules=ext_modules,
+    install_requires=['pybind11>=2.5.0'],
+    cmdclass={'build_ext': BuildExt},
+    zip_safe=False,
+)
src/common/data_structures/encoder_config.h
ADDED
@@ -0,0 +1,95 @@
+#pragma once
+
+#include <cmath>      // added: std::round
+#include <map>
+#include <random>
+#include <set>        // added: std::set (multi_fill)
+#include <stdexcept>  // added: std::out_of_range / std::invalid_argument
+#include <string>     // added: std::string / std::to_string
+#include <tuple>
+#include <vector>
+
+namespace data_structures {
+  class EncoderConfig {
+  public:
+    EncoderConfig() {
+      both_in_one = false;
+      unquantized = false;
+      do_multi_fill = false;
+      use_velocity_levels = false;
+      use_microtiming = false;
+      transpose = 0;
+      resolution = 12;
+      decode_resolution = resolution;
+      decode_final = false;
+      delta_resolution = 1920;
+    }
+
+    std::map<std::string, std::string> ToJson() {
+      std::map<std::string, std::string> json_config;
+
+      json_config["both_in_one"] = std::to_string((int)both_in_one);
+      json_config["unquantized"] = std::to_string((int)unquantized);
+      json_config["do_multi_fill"] = std::to_string((int)do_multi_fill);
+      json_config["use_velocity_levels"] = std::to_string((int)use_velocity_levels);
+      json_config["use_microtiming"] = std::to_string((int)use_microtiming);
+      json_config["transpose"] = std::to_string(transpose);
+      json_config["resolution"] = std::to_string(resolution);
+      json_config["decode_resolution"] = std::to_string(decode_resolution);
+      json_config["decode_final"] = std::to_string((int)decode_final);
+      json_config["delta_resolution"] = std::to_string(delta_resolution);
+      return json_config;
+    }
+
+    void FromJson(const std::map<std::string, std::string>& json_config) {
+      try {
+        both_in_one = (bool)std::stoi(json_config.at("both_in_one"));
+        unquantized = (bool)std::stoi(json_config.at("unquantized"));
+        do_multi_fill = (bool)std::stoi(json_config.at("do_multi_fill"));
+        use_velocity_levels = (bool)std::stoi(json_config.at("use_velocity_levels"));
+        use_microtiming = (bool)std::stoi(json_config.at("use_microtiming"));
+        transpose = std::stoi(json_config.at("transpose"));
+        resolution = std::stoi(json_config.at("resolution"));
+        decode_resolution = std::stoi(json_config.at("decode_resolution"));
+        decode_final = (bool)std::stoi(json_config.at("decode_final"));
+        delta_resolution = std::stoi(json_config.at("delta_resolution"));
+      } catch (const std::out_of_range& e) {
+        throw std::invalid_argument("Missing required key in JSON config: " + std::string(e.what()));
+      } catch (const std::invalid_argument& e) {
+        throw std::invalid_argument("Invalid value type in JSON config: " + std::string(e.what()));
+      }
+    }
+
+    int delta_to_step(int delta, int res) {
+      if (!use_microtiming) {
+        return 0;
+      } else {
+        return (int)(delta * res / delta_resolution);
+      }
+    }
+
+    int step_to_delta(float step, int res) {
+      if (!use_microtiming) {
+        return 0;
+      } else {
+        return (int)std::round(delta_resolution * step / res);
+      }
+    }
+
+    int step_to_delta(int step, int res) {
+      if (!use_microtiming) {
+        return 0;
+      } else {
+        // promote to double so the division is not truncated before rounding
+        return (int)std::round((double)delta_resolution * step / res);
+      }
+    }
+
+    bool both_in_one;
+    bool unquantized;
+    bool do_multi_fill;
+    bool use_velocity_levels;
+    bool use_microtiming;
+    int transpose;
+    int resolution;
+    int decode_resolution;
+    bool decode_final;
+    int delta_resolution;
+    std::set<std::tuple<int, int>> multi_fill;
+  };
+}
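
The delta/step conversion above, traced in Python (not part of the commit) with
the constructor defaults: at resolution 12 and delta_resolution 1920, one step
spans 160 ticks:

    delta_resolution = 1920
    res = 12
    delta = 480                                  # hypothetical offset in ticks
    step = int(delta * res / delta_resolution)   # -> 3
    back = round(delta_resolution * step / res)  # -> 480
    print(step, back)
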
src/common/data_structures/token_sequence.h
ADDED
@@ -0,0 +1,38 @@
+#pragma once
+
+#include <memory>   // added: std::shared_ptr
+#include <vector>
+#include "midi.pb.h"
+#include "../encoder/representation.h"
+
+namespace data_structures {
+
+class TokenSequence {
+public:
+  TokenSequence(const std::shared_ptr<encoder::REPRESENTATION> &rep) {
+    bar_num = 0;
+    track_num = 0;
+  }
+  void push_back( int token ) {
+    tokens.push_back( token );
+  }
+
+  // parameter renamed from "tokens" to avoid shadowing the member of the same name
+  void insert( std::vector<int> &new_tokens ) {
+    for (auto token : new_tokens) {
+      push_back(token);
+    }
+  }
+
+  void on_track_start(midi::Piece *x, const std::shared_ptr<encoder::REPRESENTATION> &rep) {
+    track_num++;
+    bar_num = 0;
+  }
+  void on_bar_start(midi::Piece *x, const std::shared_ptr<encoder::REPRESENTATION> &rep) {
+    bar_num++;
+  }
+  int bar_num;
+  int track_num;
+
+  std::vector<int> tokens;
+};
+}
src/common/data_structures/track_type.h
ADDED
@@ -0,0 +1,20 @@
+#pragma once
+
+#include <map>   // added: std::map
+#include "midi.pb.h"
+
+// START OF NAMESPACE
+namespace data_structures {
+
+// inline so this header can be included in multiple translation units
+// without violating the one-definition rule (requires C++17)
+inline std::map<midi::TRACK_TYPE,bool> TRACK_TYPE_IS_DRUM = {
+  {midi::AUX_DRUM_TRACK, true},
+  {midi::AUX_INST_TRACK, false},
+  {midi::STANDARD_TRACK, false},
+  {midi::STANDARD_DRUM_TRACK, true},
+};
+
+inline bool is_drum_track(int tt) {
+  return TRACK_TYPE_IS_DRUM[static_cast<midi::TRACK_TYPE>(tt)];
+}
+
+}
+// END OF NAMESPACE
src/common/data_structures/train_config.cpp
ADDED
@@ -0,0 +1,46 @@
+#include "train_config.h"
+
+
+namespace data_structures {
+
+  TrainConfig::TrainConfig() {
+    num_bars = 4;
+    min_tracks = 1;
+    max_tracks = 12;
+    max_mask_percentage = 0.75;
+    use_microtiming = false;
+    microtiming = 0.9;
+    no_max_length = false;
+    resolution = 12;
+    delta_resolution = 1920;
+    decode_resolution = delta_resolution;
+  }
+
+  std::map<std::string, std::string> TrainConfig::ToJson() {
+    std::map<std::string, std::string> json_config;
+    json_config["num_bars"] = std::to_string(num_bars);
+    json_config["min_tracks"] = std::to_string(min_tracks);
+    json_config["max_tracks"] = std::to_string(max_tracks);
+    json_config["max_mask_percentage"] = std::to_string(max_mask_percentage);
+    json_config["use_microtiming"] = std::to_string((int)use_microtiming);
+    json_config["microtiming"] = std::to_string(microtiming);
+    json_config["no_max_length"] = std::to_string((int)no_max_length);
+    json_config["resolution"] = std::to_string(resolution);
+    json_config["decode_resolution"] = std::to_string(decode_resolution);
+    json_config["delta_resolution"] = std::to_string(delta_resolution);
+    return json_config;
+  }
+
+  void TrainConfig::FromJson(std::map<std::string, std::string>& json_config) {
+    num_bars = stoi(json_config["num_bars"]);
+    min_tracks = stoi(json_config["min_tracks"]);
+    max_tracks = stoi(json_config["max_tracks"]);
+    max_mask_percentage = stof(json_config["max_mask_percentage"]);
+    microtiming = stof(json_config["microtiming"]);  // was stoi: microtiming is a fraction
+    use_microtiming = (bool)stoi(json_config["use_microtiming"]);
+    no_max_length = (bool)stoi(json_config["no_max_length"]);
+    resolution = stoi(json_config["resolution"]);
+    decode_resolution = stoi(json_config["decode_resolution"]);
+    delta_resolution = stoi(json_config["delta_resolution"]);
+  }
+}