diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..8f8f838348bac47e383659ced939d1d6467dded1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,87 @@ +# .gitignore + +build.sh + +# feature_extraction stuff +python_experiment_scripts/**/*.json +python_experiment_scripts/**/*.csv +python_experiment_scripts/**/*.pdf +python_experiment_scripts/**/*.numbers +python_experiment_scripts/**/*.txt +python_experiment_scripts/examples/**/*.mid + +python_scripts_for_testing/examples/*.mid +python_scripts_for_testing/examples/ + +# Ignore compilation folders / dependencies +#midifile/ +libtorch/ +libtorch_cpu/ +#pybind11/ +build/ +python_lib/ +CMakeFiles/ +CMakeScripts/ +CMakeCache.txt +cmake_install.cmake +Debug/ + +# Ignore the libtorch directory +/libraries/libtorch +/libraries/libtorch_cpu +#/libraries/pybind11 +#/libraries/midifile + +/python_lib/ +/python_scripts/training_files + +# ignore dataset files +*.arr +*.arr.header + +# ignore dstore +.DS_Store + +# Ignore the model.pt file +**/*.pt +*.pt + +#ignore visual studio metadata +.vs/ +.vscode + +#ignore build directory +/out/ + +*.out + +#ignore random files +2539bytes.txt +CMakeLists_CPP.txt + +/libraries/protobuf/.vs/ +/libraries/protobuf/out/ + +/libraries/torch/.vs/ +/libraries/torch/out/ + +# exceptions +!python_experiment_scripts/onset_density_test.pdf +!python_experiment_scripts/onset_density_test.json +!python_experiment_scripts/onset_polyphony_test.json +!python_experiment_scripts/musicmap_genre_data.json + +!python_experiment_scripts/onset_polyphony_test_mono.json +!python_experiment_scripts/onset_polyphony_test_poly.json +!python_experiment_scripts/onset_polyphony_v_original_test_mono.json +!python_experiment_scripts/onset_polyphony_v_original_test_poly.json + +# pycache folders +*/__pycache__/* + +#training outputs +*/checkpoints/* +*/logs/* +*/midigpt.cpython-38-x86_64-linux-gnu.so + +notes/* \ No newline at end of file diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..e128f295d1735af98277bf72f1b8456461943889 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,6 @@ +[submodule "libraries/pybind11"] + path = libraries/pybind11 + url = https://github.com/pybind/pybind11.git +[submodule "libraries/midifile"] + path = libraries/midifile + url = https://github.com/craigsapp/midifile.git diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..7723fb1cacb0981d9c67a29728d7ad949af89da8 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,87 @@ +cmake_minimum_required(VERSION 3.8) + +SET(CMAKE_CXX_STANDARD 20) +SET(CMAKE_CXX_STANDARD_REQUIRED ON) +SET(CMAKE_POSITION_INDEPENDENT_CODE ON) +#we add the following line to fix a linkage issue between torch and midifile +#https://stackoverflow.com/questions/68922557/c-linker-error-undefined-reference-when-linking-package-libtorch-and-shared +#add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) + +project(midigpt) + +option(compute_canada "Build for Compute Canada" OFF) +option(mac_os "Build for Mac OS" OFF) +option(no_torch "No Torch" OFF) +option(no_pybind "No Pybind" OFF) +option(trace "Trace" OFF) + +#Find the necessary packages to be able to link the libraries correctly +find_package(Protobuf REQUIRED) +include_directories(${Protobuf_INCLUDE_DIRS}) + +if(no_torch) + add_definitions(-DNO_TORCH) +endif() + +if(NOT no_torch) + if(mac_os) + message("USING PYTHON PYTORCH INSTEAD") + else() + set(CMAKE_PREFIX_PATH "${CMAKE_CURRENT_SOURCE_DIR}/libraries/libtorch/") + endif() + find_package(Torch REQUIRED) + + # This is necessary to avoid a symbol linkage error https://github.com/pytorch/pytorch/issues/38122 + # https://github.com/DeepVAC/libdeepvac/blob/master/python/CMakeLists.txt + find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib") +endif() + +if(compute_canada) + include_directories("/cvmfs/soft.computecanada.ca/easybuild/software/2020/avx512/Core/python/3.8.2/include/python3.8") +endif() + +#Add the directories of libraries so the project can CMake them too +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/libraries/protobuf) +if(NOT no_torch) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/libraries/torch) +endif() +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/libraries/pybind11) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/libraries/midifile) + +#https://stackoverflow.com/questions/8934295/add-source-in-a-subdirectory-to-a-cmake-project/54285898#54285898 +#https://crascit.com/2016/01/31/enhanced-source-file-handling-with-target_sources/ + + +set(SRCS + src/common/data_structures/train_config.cpp + src/dataset_creation/compression/lz4.c + src/dataset_creation/dataset_manipulation/bytes_to_file.cpp + src/common/encoder/encoder_all.h + src/lib.cpp +) +PYBIND11_ADD_MODULE(midigpt ${SRCS}) + +#Adding include folders of libraries to our target so we can reference them with #include +#Add subdirectory adds those to main project so they can be CMAKEd. Include dirs allows us to reference functions in main. +target_include_directories(midigpt PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/libraries/protobuf/include) +if (NOT no_torch) + target_include_directories(midigpt PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/libraries/torch/include) +endif() +target_include_directories(midigpt PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/libraries/midifile/include) + +#Linking all the libraries +target_link_libraries(midigpt PRIVATE midigpt_proto) #Our protobuf custom library +target_link_libraries(midigpt PRIVATE midifile) +if (NOT no_torch) + target_link_libraries(midigpt PRIVATE midigpt_torch) #Our torch custom library + #This is necessary to avoid a symbol linkage error https://github.com/pytorch/pytorch/issues/38122 + target_link_libraries(midigpt PRIVATE "${TORCH_LIBRARIES}" ${TORCH_PYTHON_LIBRARY}) +endif() + +if (trace) + add_library(tracer STATIC src/trace.cpp) + target_link_libraries(midigpt PRIVATE tracer) + target_compile_options(midigpt PRIVATE -Wall -Wextra -Wpedantic -finstrument-functions) +elseif(NOT WIN32) + target_compile_options(midigpt PRIVATE -Wall -Wextra -Wpedantic) +endif() diff --git a/CMakeSettings.json b/CMakeSettings.json new file mode 100644 index 0000000000000000000000000000000000000000..36482804d2487e2ecd6cd71da1e600a649c77cf9 --- /dev/null +++ b/CMakeSettings.json @@ -0,0 +1,28 @@ +{ + "configurations": [ + { + "name": "x64-Debug", + "generator": "Ninja", + "configurationType": "Debug", + "inheritEnvironments": [ "msvc_x64_x64" ], + "buildRoot": "${projectDir}\\out\\build\\${name}", + "installRoot": "${projectDir}\\out\\install\\${name}", + "cmakeCommandArgs": "", + "buildCommandArgs": "", + "ctestCommandArgs": "" + }, + { + "name": "WSL-GCC-Debug", + "generator": "Ninja", + "configurationType": "Debug", + "buildRoot": "${projectDir}\\out\\build\\${name}", + "installRoot": "${projectDir}\\out\\install\\${name}", + "cmakeExecutable": "cmake", + "cmakeCommandArgs": "", + "buildCommandArgs": "", + "ctestCommandArgs": "", + "inheritEnvironments": [ "linux_x64" ], + "wslPath": "${defaultWSLPath}" + } + ] +} \ No newline at end of file diff --git a/README.md b/README.md index f41a2d3d8b93d61ce4499f906b314c83e52da1cb..bc96265665ddce7b5863afbbbf5db76cb4ba0ea6 100644 --- a/README.md +++ b/README.md @@ -9,3 +9,179 @@ short_description: MIDI-GPT-inference-docker --- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference + + +[![N|Solid](https://drive.google.com/uc?export=view&id=1u4xiWN3s0PAii8zn3-qxJ7wn35tBOypY)](https://metacreation.net/category/projects/) + +# MIDI-GPT Guide + +This is the repository for MIDI-GPT, a generative system based on the Transformer architecture that is designed for computer-assisted music composition workflows. This work was presented at the 39th Annual AAAI Conference in Philadelphia, USA in this [paper](https://arxiv.org/abs/2501.17011) + +# Using MIDI-GPT + +## Installation + +To successfully install the midigpt python library, use the script ```midigpt_setup_helper.sh```. You may first download this script on its own and run it, which will clone the repository and build the library. Below is an example of the usage: + +```sh +bash midigpt_setup_helper.sh -d midigpt_dir +``` + +>**Note:** Python 3.8 is required for the library +>**Note:** If you're building on mac, use the ```-m```argument. + +## Inference + +Once downloaded, MIDI-GPT is ready to use. ```python_scripts_for_testing/pythoninferencetest.py``` is an example of using MIDI-GPT. In summary, three objects need to be created before sampling: +- Piece: Load the MIDI file into a JSON representation of the MIDI piece +- Status: This dict indicates the sampling process that is desired (on which tracks, continuation/resampling/infilling, etc.) as well as attribute control values +- Param: This dict indicates sampling parameters such as temperature or number of generated bars per step + +You must provide an input MIDI file, the checkpoint model file and an optional output MIDI file. Our model is provided in the ```models/model.zip``` file. + +Then, using the ```midigpt``` Python API, call the sample function with these objects as arguments. After sampling, the result can then be converted and saved into a MIDI file. + +# Training MIDI-GPT + +Training the model was done on computing clusters on Compute Canada, therefore the training scripts are tailored to this platform but may easily be adapted to similar platforms. Training was done using the GigaMIDI dataset, first serialzed into a compressed file using ```create_dataset_compute_canada.sh``` and ```python_scripts/create_dataset.py```. The training was executed using the ```python_scripts/train.py```. Finally, the model weights file is converted from the training checkpoint using ```convert.py```. + +If you're unfamiliar with Compute Canada, make sure to check the introductory .md [here](). + +## Installation - Cedar and Niagara +0. You might want to allocate an interactive session with salloc: + +>**Note:** You DON'T need to do this in Niagara. + +```sh +salloc --time=3:0:0 --nodes 1 --cpus-per-task 32 --mem=128000 --account=user +``` + +1. First, make sure to clone the MMM_API into a folder in your CC machine: +```sh +https://github.com/Metacreation-Lab/MIDI-GPT +``` +2. Then we must load the standard environments and some dependencies: + +>**Note:** If you're building in Niagara, load this first: +```sh +module load CCEnv arch/avx512 +``` +Then proceed to load the rest (If you're in Cedar, start from here): +```sh +module load StdEnv/2020 +module load cmake/3.23.1 +module load gcc/11.3.0 +module load protobuf/3.12.3 +module load python/3.8.2 +``` +3. Then we must create an environment and activate it: +```sh +virtualenv --no-download ./ENV # ENV is the name of the environment +source ./ENV/bin/activate +pip install --no-index --upgrade pip + +# For training only +pip install torch==1.13.0+computecanada +pip install transformers==4.26.1+computecanada +``` +4. Finally, just call the bash script with the correct argument: +```sh +bash create_python_library.sh --test_build --compute_canada +``` +Or if you are planning to just train the model, add the argument excluding to torch library required only for inference: +```sh +bash create_python_library.sh --no_torch --compute_canada +``` +5. To test the library imports for training, run the train.py script by importing it: +```sh +cd python_scripts +python3 -c "import train" +``` +> **Note:** A helper script ```midigpt_setup_helper.sh``` does all these steps autmoatically (for training or inference). Download it individually and run it where you wish to clone the repository. +> **Note:** If you run the code without the --test_build flag, it will still compile and create the python library but it won't test it with the current model in production. +> **Note:** The other flag (--compute_canada) is necesary to build the code properly. + +That's it! + +Everything should get installed correctly in your python environment! If you log out and back in to CC make sure to activate the environment in which you installed the API. + +## Training + +### Dataset Building + +In order to train a new model, you must first build a dataset. You can upload the files you need using Globus (check the CC [guide]()). + +> **Note**: Remember that to copy from the shared folder to your own folders you must use absolute paths. + +The data should be organized in a way where all midi files are contained within three folders ```train```, ```test```, and ```valid```. Further directories can be used to organize the midi files as long as they are within these three directories. + +If your dataset is a single folder containing all the midi files, we provide a helper script that automatically slits the dataset to 80%-10%-10%. Simply modify ```data_split.sh``` to match your cas and run. + +Once you have the folder with the data, run the following command +```sh +sh create_dataset_compute_canada.sh --root_dir= --encoding= --data_dir= --output= +``` +where: +- `````` is the root folder where the midigpt repository folder is located +- `````` is the conder to use. We suggest using ```EXPRESSIVE_ENCODER``` +- `````` is the dataset folder containing the three ```train```, ```test```, and ```valid``` folders. +- `````` is the location of the ouptt ```.arr``` file. The resulting file while be ```_NUM_BARS=_RESOLUTION_.arr``` +>**Note:** If you are on Compute Canada, we suggest you run these commands through an sbatch job as they can take some time. + +### Training a Model + +To train a model, run the train.py file. Different lab members have managed to set the paths differently. What works for me is to use global paths. An example would be: +```sh +python train.py --arch gpt2 --config /home/user/scratch/TRAINING-master/config/gpt2_tiny.json --encoding EXPRESSIVE_ENCODER --ngpu 4 --dataset /home/user/scratch/test_NUM_BARS=4_OPZ_False.arr --batch_size 32 --label DELETE_ME +``` + +### Running Jobs + +To read the CC documentation, cick [here](https://docs.alliancecan.ca/wiki/Running_jobs). You can run small snippets of code to test things out without allocating any resources. However, to train a model or perform any time/resource consuming task, you must schedule a job. A list of different types of job scheduling will be added here. + +#### Interactive Jobs +You can start an interactive session on a compute node with salloc. +```sh +salloc --time=3:0:0 --nodes 1 --cpus-per-task 32 --mem=128000 --account=user +``` + +#### Scheduled jobs (use this for training) +For time-expensive tasks it is better to create a bash file and submit a job with sbatch: +```sh +sbatch simple_job.sh +``` + +Here is an example of the contents of a bash file to submit a midigpt training job: +```sh +#!/bin/bash +#SBATCH --gres=gpu:v100l:4 +#SBATCH --cpus-per-task=32 +#SBATCH --exclusive +#SBATCH --mem=0 +#SBATCH --time=2-23:00 +#SBATCH --account=user +#SBATCH --mail-user USERNAME@domain.org <---- MAKE SURE TO PUT YOUR EMAIL +#SBATCH --mail-type ALL +#SBATCH --output=CCLOG/FILENAME.out <---- MAKE SURE TO CHANGE THE NAME OF THE FILE + +source $SCRATCH/PY_3610/bin/activate <---- THIS IS THE DIRECTORY TO THE ENV WHERE YOU HAVE THE midigpt_api INSTALLED +cd $SCRATCH/MMM_TRAINING-master +module load StdEnv/2020 protobuf python/3.6.10 +source $SCRATCH/PY_3610/bin/activate <---- SAME HERE, MAKE SURE THE DIRECTORY IS PLACED CORRECTLY +python train.py --arch reformer --config /home/user/scratch/MMM_TRAINING-master/config/reformer.json --encoding EXPRESSIVE_ENCODER --ngpu 4 --dataset /home/user/scratch/dataset_NUM_BARS=4.arr --batch_size 32 --label DELETE_ME +``` + +In this case we are using 4 v1001 GPUs (**gres** argument) and we're asking for 2 days and 23 hours of time to run the job (**time** argument). + +#### Check jobs and eliminate session +To show all the users +```sh +who -u +``` + +To kill all the sessions +```sh +pkill -u username +``` + + diff --git a/create_dataset_compute_canada.sh b/create_dataset_compute_canada.sh new file mode 100644 index 0000000000000000000000000000000000000000..d8816732d8aa51fa246220ebdab975e46ce3e1f0 --- /dev/null +++ b/create_dataset_compute_canada.sh @@ -0,0 +1,77 @@ +#!/bin/bash +#SBATCH --cpus-per-task=32 +#SBATCH --time=10:00:00 + +root_dir="" # Model directory to replace MODEL_NAME +data_dir="" # Data directory to replace DATA_DIR +encoding="" # Encoding to replace ENCODING +output="" # Output to replace OUTPUT +metadata="" +test="no" +zip="no" +res="12" +max="-1" + +# Parse arguments +for arg in "$@" +do + case $arg in + --root_dir=*) + root_dir="${arg#*=}" + shift # Remove --root_dir= from processing + ;; + --metadata=*) + metadata="${arg#*=}" + shift # Remove --metadata= from processing + ;; + --res=*) + res="${arg#*=}" + shift # Remove --metadata= from processing + ;; + --test=*) + test="${arg#*=}" + shift # Remove --metadata= from processing + ;; + --data_dir=*) + data_dir="${arg#*=}" + shift # Remove --data_dir= from processing + ;; + --encoding=*) + encoding="${arg#*=}" + shift # Remove --encoding= from processing + ;; + --output=*) + output="${arg#*=}" + shift # Remove --output= from processing + ;; + --zip=*) + zip="${arg#*=}" + shift # Remove --output= from processing + ;; + --max=*) + max="${arg#*=}" + shift # Remove --output= from processing + ;; + esac +done + +module load CCEnv arch/avx512 +module load StdEnv/2020 +module load cmake/3.23.1 +module load gcc/11.3.0 +module load protobuf/3.12.3 +module load python/3.8.2 + +mkdir -p $root_dir/CCLOG +source $root_dir/venv/bin/activate + +cp $root_dir/MIDI-GPT/python_lib/midigpt.cpython-38-x86_64-linux-gnu.so $root_dir/MIDI-GPT/python_scripts + +python3 $root_dir/MIDI-GPT/python_scripts/create_dataset.py --nthreads 40 --max_size $max --data_dir $data_dir --encoding $encoding --output $output --metadata $metadata --test $test --expressive --resolution $res + +if [[ "$zip" == "yes" ]] +then + cd $output + cd ../ + zip -r EXPRESSIVE_GIGAMIDI_24_1920.zip $output +fi \ No newline at end of file diff --git a/create_python_library.sh b/create_python_library.sh new file mode 100644 index 0000000000000000000000000000000000000000..55474feff1b0f9fed0cb6e1a8a997649a7f10351 --- /dev/null +++ b/create_python_library.sh @@ -0,0 +1,232 @@ +#!/bin/bash + +compute_canada=false +mac_os=false +cpu=false +no_torch=false +niagara=false +env_name="venv" +parent_dir="midigpt_workspace" # default parent directory name +cuda=false +trace=false + + +# Parse arguments +for ((i=1;i<=$#;i++)); do + case ${!i} in + --trace) + trace=true + ;; + --compute_canada) + compute_canada=true + ;; + --cpu) + cpu=true + ;; + --mac_os) + mac_os=true + ;; + --no_torch) + no_torch=true + ;; + --niagara) + niagara=true + ;; + --env_name) + i=$((i+1)) + env_name=${!i} + ;; + -n=*|--name=*) # new parent directory name option + parent_dir="${!i#*=}" + shift + ;; + --cuda) + if $no_torch; then + echo "Cannot use --cuda and --no_torch at the same time." + exit 1 + fi + cuda=true + ;; + *) + echo "Unknown option ${!i}" + exit 1 + ;; + esac +done + + +# Get the current directory name +dir_name=$(basename `pwd`) + +# Get the parent directory name +parent_dir_name=$(basename $(dirname `pwd`)) + +if [ "$parent_dir_name" != "$parent_dir" ]; then + + # Go to the parent directory + cd .. + + # Create the new parent directory + mkdir $parent_dir + + # Move the old directory into the new parent directory + mv $dir_name $parent_dir/ + + # Change to the old directory, which is now inside the parent directory + cd $parent_dir/$dir_name +fi + +# Load modules if we are in compute_canada and niagara +if $compute_canada; then + if $niagara; then + module load CCEnv arch/avx512 + fi + module load StdEnv/2020 + module load cmake/3.23.1 + module load gcc/11.3.0 + module load protobuf/3.12.3 + module load python/3.8.2 + + mkdir ../CCLOG +fi + +# Environment creation +if [[ -n ../$env_name ]]; then + if [[ -d ../$env_name ]]; then + echo "Environment $env_name already exists, activating it..." + else + echo "Environment $env_name does not exist, creating it..." + if $compute_canada; then + virtualenv ../$env_name + else + python3 -m venv ../$env_name + fi + fi +fi + +source ../$env_name/bin/activate + +# Install requirements +pip install -r pip_requirements/common_requirements.txt + +if $compute_canada; then + pip install -r pip_requirements/create_dataset_requirements.txt +fi +if $mac_os; then + pip install -r pip_requirements/inference_requirements.txt +fi + +if $compute_canada && ! $niagara; then # anf if no torch + pip install -r pip_requirements/train_requirements.txt +fi + +#deactivate + +# Set CMake flags based on command line arguments +cmake_flags="" +if $compute_canada; then + cmake_flags="$cmake_flags -Dcompute_canada=ON" +fi + +if $no_torch; then + cmake_flags="$cmake_flags -Dno_torch=ON" +fi + +if $trace; then + cmake_flags="$cmake_flags -Dtrace=ON" +fi + +# Code to check if libtorch and pybind11 are already downloaded +if ! $no_torch; then + libtorch_path="libraries/libtorch" + libtorch_url="https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.0.0%2Bcpu.zip" + if $cuda; then + libtorch_url="https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.0.0%2Bcu118.zip" + fi +fi + +pybind11_path="libraries/pybind11" +midifile_path="libraries/midifile" + + +pybind11_url="https://github.com/pybind/pybind11.git" +midifile_url="https://github.com/craigsapp/midifile" + +if ! $no_torch; then + if $mac_os; then + libtorch_url="https://download.pytorch.org/libtorch/cpu/libtorch-macos-2.0.1.zip" + fi + + if $cpu; then + libtorch_url="https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.0.0%2Bcpu.zip" + fi + + # Check if libtorch folder exists and is not empty + if [ ! -d "$libtorch_path" ] || [ -z "$(ls -A "$libtorch_path")" ]; then + echo "libtorch folder does not exist or is empty. Downloading and extracting..." + mkdir -p "$libtorch_path" + curl -L "$libtorch_url" -o libtorch.zip + unzip -q libtorch.zip -d libraries/ + rm libtorch.zip + echo "libtorch downloaded and extracted." + else + echo "libtorch folder exists and is not empty. No need to download." + fi +fi + +# Check if pybind11 folder exists and is not empty +if [ ! -d "$pybind11_path" ] || [ -z "$(ls -A "$pybind11_path")" ]; then + echo "pybind11 folder does not exist or is empty. Cloning the repository..." + mkdir -p libraries + git clone "$pybind11_url" "$pybind11_path" + echo "pybind11 downloaded." + cd libraries/pybind11 + git reset --hard 5ccb9e4 + cd ../../ + echo "pybind11 reset to working build" +else + echo "pybind11 folder exists and is not empty. No need to download." +fi + +# Check if midifile folder exists and is not empty +if [ ! -d "$midifile_path" ] || [ -z "$(ls -A "$midifile_path")" ]; then + echo "midifile folder does not exist or is empty. Cloning the repository..." + mkdir -p libraries + git clone "$midifile_url" "$midifile_path" + echo "midifile downloaded." + cd libraries/midifile + git reset --hard 838c62c + cd ../../ + echo "midifile reset to working build" +else + echo "midifile folder exists and is not empty. No need to download." +fi + +# Middle section of the script to build the python library +rm -rf ./python_lib +mkdir ./python_lib +rm -rf ./libraries/protobuf/build +mkdir ./libraries/protobuf/build + +cd ./libraries/protobuf/src +protoc --cpp_out ../build *.proto +cd ../../.. + +cd ./python_lib + +if $mac_os; then + cmake $cmake_flags .. -Dmac_os=ON -DCMAKE_PREFIX_PATH=$(python3 -c 'import torch;print(torch.utils.cmake_prefix_path)') +else + cmake $cmake_flags .. +fi +make +python3 -c "import midigpt; print('midigpt python library built successfully')" + +cd .. +if $compute_canada; then + dos2unix create_dataset_compute_canada.sh + dos2unix train_dataset.sh +fi +cd ./python_lib + +cd .. diff --git a/data_split.sh b/data_split.sh new file mode 100644 index 0000000000000000000000000000000000000000..bcc470172904897d79716287d5bbacc84af902f4 --- /dev/null +++ b/data_split.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +root=$SCRATCH/datasets/GigaMIDI_Cleaned/Cleaned_Ver_EP_Class-GigaMIDI/Cleaned_GigaMIDI +new_root=$SCRATCH/datasets/GigaMIDI_Cleaned/Cleaned_Ver_EP_Class-GigaMIDI/Cleaned_GigaMIDI_Split +parent=$SCRATCH/workspace_train/parent_dir + +mkdir -p $new_root +cd $new_root +mkdir -p train +mkdir -p test +mkdir -p valid + +cd $SCRATCH + +source $parent/venv/bin/activate +python $parent/MIDI-GPT/python_scripts/data_split.py $root $new_root \ No newline at end of file diff --git a/include/dataset_creation/compression/lz4.h b/include/dataset_creation/compression/lz4.h new file mode 100644 index 0000000000000000000000000000000000000000..deca3b455319813d071e2305fba3da854c4a8613 --- /dev/null +++ b/include/dataset_creation/compression/lz4.h @@ -0,0 +1,764 @@ +/* + * LZ4 - Fast LZ compression algorithm + * Header File + * Copyright (C) 2011-present, Yann Collet. + + BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php) + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following disclaimer + in the documentation and/or other materials provided with the + distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + You can contact the author at : + - LZ4 homepage : http://www.lz4.org + - LZ4 source repository : https://github.com/lz4/lz4 +*/ +#if defined (__cplusplus) +extern "C" { +#endif + +#ifndef LZ4_H_2983827168210 +#define LZ4_H_2983827168210 + +/* --- Dependency --- */ +#include /* size_t */ + + +/** + Introduction + + LZ4 is lossless compression algorithm, providing compression speed >500 MB/s per core, + scalable with multi-cores CPU. It features an extremely fast decoder, with speed in + multiple GB/s per core, typically reaching RAM speed limits on multi-core systems. + + The LZ4 compression library provides in-memory compression and decompression functions. + It gives full buffer control to user. + Compression can be done in: + - a single step (described as Simple Functions) + - a single step, reusing a context (described in Advanced Functions) + - unbounded multiple steps (described as Streaming compression) + + lz4.h generates and decodes LZ4-compressed blocks (doc/lz4_Block_format.md). + Decompressing such a compressed block requires additional metadata. + Exact metadata depends on exact decompression function. + For the typical case of LZ4_decompress_safe(), + metadata includes block's compressed size, and maximum bound of decompressed size. + Each application is free to encode and pass such metadata in whichever way it wants. + + lz4.h only handle blocks, it can not generate Frames. + + Blocks are different from Frames (doc/lz4_Frame_format.md). + Frames bundle both blocks and metadata in a specified manner. + Embedding metadata is required for compressed data to be self-contained and portable. + Frame format is delivered through a companion API, declared in lz4frame.h. + The `lz4` CLI can only manage frames. +*/ + +/*^*************************************************************** +* Export parameters +*****************************************************************/ +/* +* LZ4_DLL_EXPORT : +* Enable exporting of functions when building a Windows DLL +* LZ4LIB_VISIBILITY : +* Control library symbols visibility. +*/ +#ifndef LZ4LIB_VISIBILITY +# if defined(__GNUC__) && (__GNUC__ >= 4) +# define LZ4LIB_VISIBILITY __attribute__ ((visibility ("default"))) +# else +# define LZ4LIB_VISIBILITY +# endif +#endif +#if defined(LZ4_DLL_EXPORT) && (LZ4_DLL_EXPORT==1) +# define LZ4LIB_API __declspec(dllexport) LZ4LIB_VISIBILITY +#elif defined(LZ4_DLL_IMPORT) && (LZ4_DLL_IMPORT==1) +# define LZ4LIB_API __declspec(dllimport) LZ4LIB_VISIBILITY /* It isn't required but allows to generate better code, saving a function pointer load from the IAT and an indirect jump.*/ +#else +# define LZ4LIB_API LZ4LIB_VISIBILITY +#endif + +/*------ Version ------*/ +#define LZ4_VERSION_MAJOR 1 /* for breaking interface changes */ +#define LZ4_VERSION_MINOR 9 /* for new (non-breaking) interface capabilities */ +#define LZ4_VERSION_RELEASE 2 /* for tweaks, bug-fixes, or development */ + +#define LZ4_VERSION_NUMBER (LZ4_VERSION_MAJOR *100*100 + LZ4_VERSION_MINOR *100 + LZ4_VERSION_RELEASE) + +#define LZ4_LIB_VERSION LZ4_VERSION_MAJOR.LZ4_VERSION_MINOR.LZ4_VERSION_RELEASE +#define LZ4_QUOTE(str) #str +#define LZ4_EXPAND_AND_QUOTE(str) LZ4_QUOTE(str) +#define LZ4_VERSION_STRING LZ4_EXPAND_AND_QUOTE(LZ4_LIB_VERSION) + +LZ4LIB_API int LZ4_versionNumber (void); /**< library version number; useful to check dll version */ +LZ4LIB_API const char* LZ4_versionString (void); /**< library version std::string; useful to check dll version */ + + +/*-************************************ +* Tuning parameter +**************************************/ +/*! + * LZ4_MEMORY_USAGE : + * Memory usage formula : N->2^N Bytes (examples : 10 -> 1KB; 12 -> 4KB ; 16 -> 64KB; 20 -> 1MB; etc.) + * Increasing memory usage improves compression ratio. + * Reduced memory usage may improve speed, thanks to better cache locality. + * Default value is 14, for 16KB, which nicely fits into Intel x86 L1 cache + */ +#ifndef LZ4_MEMORY_USAGE +# define LZ4_MEMORY_USAGE 14 +#endif + + +/*-************************************ +* Simple Functions +**************************************/ +/*! LZ4_compress_default() : + * Compresses 'srcSize' bytes from buffer 'src' + * into already allocated 'dst' buffer of size 'dstCapacity'. + * Compression is guaranteed to succeed if 'dstCapacity' >= LZ4_compressBound(srcSize). + * It also runs faster, so it's a recommended setting. + * If the function cannot compress 'src' into a more limited 'dst' budget, + * compression stops *immediately*, and the function result is zero. + * In which case, 'dst' content is undefined (invalid). + * srcSize : max supported value is LZ4_MAX_INPUT_SIZE. + * dstCapacity : size of buffer 'dst' (which must be already allocated) + * @return : the number of bytes written into buffer 'dst' (necessarily <= dstCapacity) + * or 0 if compression fails + * Note : This function is protected against buffer overflow scenarios (never writes outside 'dst' buffer, nor read outside 'source' buffer). + */ +LZ4LIB_API int LZ4_compress_default(const char* src, char* dst, int srcSize, int dstCapacity); + +/*! LZ4_decompress_safe() : + * compressedSize : is the exact complete size of the compressed block. + * dstCapacity : is the size of destination buffer (which must be already allocated), presumed an upper bound of decompressed size. + * @return : the number of bytes decompressed into destination buffer (necessarily <= dstCapacity) + * If destination buffer is not large enough, decoding will stop and output an error code (negative value). + * If the source stream is detected malformed, the function will stop decoding and return a negative result. + * Note 1 : This function is protected against malicious data packets : + * it will never writes outside 'dst' buffer, nor read outside 'source' buffer, + * even if the compressed block is maliciously modified to order the decoder to do these actions. + * In such case, the decoder stops immediately, and considers the compressed block malformed. + * Note 2 : compressedSize and dstCapacity must be provided to the function, the compressed block does not contain them. + * The implementation is free to send / store / derive this information in whichever way is most beneficial. + * If there is a need for a different format which bundles together both compressed data and its metadata, consider looking at lz4frame.h instead. + */ +LZ4LIB_API int LZ4_decompress_safe (const char* src, char* dst, int compressedSize, int dstCapacity); + + +/*-************************************ +* Advanced Functions +**************************************/ +#define LZ4_MAX_INPUT_SIZE 0x7E000000 /* 2 113 929 216 bytes */ +#define LZ4_COMPRESSBOUND(isize) ((unsigned)(isize) > (unsigned)LZ4_MAX_INPUT_SIZE ? 0 : (isize) + ((isize)/255) + 16) + +/*! LZ4_compressBound() : + Provides the maximum size that LZ4 compression may output in a "worst case" scenario (input data not compressible) + This function is primarily useful for memory allocation purposes (destination buffer size). + Macro LZ4_COMPRESSBOUND() is also provided for compilation-time evaluation (stack memory allocation for example). + Note that LZ4_compress_default() compresses faster when dstCapacity is >= LZ4_compressBound(srcSize) + inputSize : max supported value is LZ4_MAX_INPUT_SIZE + return : maximum output size in a "worst case" scenario + or 0, if input size is incorrect (too large or negative) +*/ +LZ4LIB_API int LZ4_compressBound(int inputSize); + +/*! LZ4_compress_fast() : + Same as LZ4_compress_default(), but allows selection of "acceleration" factor. + The larger the acceleration value, the faster the algorithm, but also the lesser the compression. + It's a trade-off. It can be fine tuned, with each successive value providing roughly +~3% to speed. + An acceleration value of "1" is the same as regular LZ4_compress_default() + Values <= 0 will be replaced by ACCELERATION_DEFAULT (currently == 1, see lz4.c). +*/ +LZ4LIB_API int LZ4_compress_fast (const char* src, char* dst, int srcSize, int dstCapacity, int acceleration); + + +/*! LZ4_compress_fast_extState() : + * Same as LZ4_compress_fast(), using an externally allocated memory space for its state. + * Use LZ4_sizeofState() to know how much memory must be allocated, + * and allocate it on 8-bytes boundaries (using `malloc()` typically). + * Then, provide this buffer as `void* state` to compression function. + */ +LZ4LIB_API int LZ4_sizeofState(void); +LZ4LIB_API int LZ4_compress_fast_extState (void* state, const char* src, char* dst, int srcSize, int dstCapacity, int acceleration); + + +/*! LZ4_compress_destSize() : + * Reverse the logic : compresses as much data as possible from 'src' buffer + * into already allocated buffer 'dst', of size >= 'targetDestSize'. + * This function either compresses the entire 'src' content into 'dst' if it's large enough, + * or fill 'dst' buffer completely with as much data as possible from 'src'. + * note: acceleration parameter is fixed to "default". + * + * *srcSizePtr : will be modified to indicate how many bytes where read from 'src' to fill 'dst'. + * New value is necessarily <= input value. + * @return : Nb bytes written into 'dst' (necessarily <= targetDestSize) + * or 0 if compression fails. +*/ +LZ4LIB_API int LZ4_compress_destSize (const char* src, char* dst, int* srcSizePtr, int targetDstSize); + + +/*! LZ4_decompress_safe_partial() : + * Decompress an LZ4 compressed block, of size 'srcSize' at position 'src', + * into destination buffer 'dst' of size 'dstCapacity'. + * Up to 'targetOutputSize' bytes will be decoded. + * The function stops decoding on reaching this objective, + * which can boost performance when only the beginning of a block is required. + * + * @return : the number of bytes decoded in `dst` (necessarily <= dstCapacity) + * If source stream is detected malformed, function returns a negative result. + * + * Note : @return can be < targetOutputSize, if compressed block contains less data. + * + * Note 2 : this function features 2 parameters, targetOutputSize and dstCapacity, + * and expects targetOutputSize <= dstCapacity. + * It effectively stops decoding on reaching targetOutputSize, + * so dstCapacity is kind of redundant. + * This is because in a previous version of this function, + * decoding operation would not "break" a sequence in the middle. + * As a consequence, there was no guarantee that decoding would stop at exactly targetOutputSize, + * it could write more bytes, though only up to dstCapacity. + * Some "margin" used to be required for this operation to work properly. + * This is no longer necessary. + * The function nonetheless keeps its signature, in an effort to not break API. + */ +LZ4LIB_API int LZ4_decompress_safe_partial (const char* src, char* dst, int srcSize, int targetOutputSize, int dstCapacity); + + +/*-********************************************* +* Streaming Compression Functions +***********************************************/ +typedef union LZ4_stream_u LZ4_stream_t; /* incomplete type (defined later) */ + +LZ4LIB_API LZ4_stream_t* LZ4_createStream(void); +LZ4LIB_API int LZ4_freeStream (LZ4_stream_t* streamPtr); + +/*! LZ4_resetStream_fast() : v1.9.0+ + * Use this to prepare an LZ4_stream_t for a new chain of dependent blocks + * (e.g., LZ4_compress_fast_continue()). + * + * An LZ4_stream_t must be initialized once before usage. + * This is automatically done when created by LZ4_createStream(). + * However, should the LZ4_stream_t be simply declared on stack (for example), + * it's necessary to initialize it first, using LZ4_initStream(). + * + * After init, start any new stream with LZ4_resetStream_fast(). + * A same LZ4_stream_t can be re-used multiple times consecutively + * and compress multiple streams, + * provided that it starts each new stream with LZ4_resetStream_fast(). + * + * LZ4_resetStream_fast() is much faster than LZ4_initStream(), + * but is not compatible with memory regions containing garbage data. + * + * Note: it's only useful to call LZ4_resetStream_fast() + * in the context of streaming compression. + * The *extState* functions perform their own resets. + * Invoking LZ4_resetStream_fast() before is redundant, and even counterproductive. + */ +LZ4LIB_API void LZ4_resetStream_fast (LZ4_stream_t* streamPtr); + +/*! LZ4_loadDict() : + * Use this function to reference a static dictionary into LZ4_stream_t. + * The dictionary must remain available during compression. + * LZ4_loadDict() triggers a reset, so any previous data will be forgotten. + * The same dictionary will have to be loaded on decompression side for successful decoding. + * Dictionary are useful for better compression of small data (KB range). + * While LZ4 accept any input as dictionary, + * results are generally better when using Zstandard's Dictionary Builder. + * Loading a size of 0 is allowed, and is the same as reset. + * @return : loaded dictionary size, in bytes (necessarily <= 64 KB) + */ +LZ4LIB_API int LZ4_loadDict (LZ4_stream_t* streamPtr, const char* dictionary, int dictSize); + +/*! LZ4_compress_fast_continue() : + * Compress 'src' content using data from previously compressed blocks, for better compression ratio. + * 'dst' buffer must be already allocated. + * If dstCapacity >= LZ4_compressBound(srcSize), compression is guaranteed to succeed, and runs faster. + * + * @return : size of compressed block + * or 0 if there is an error (typically, cannot fit into 'dst'). + * + * Note 1 : Each invocation to LZ4_compress_fast_continue() generates a new block. + * Each block has precise boundaries. + * Each block must be decompressed separately, calling LZ4_decompress_*() with relevant metadata. + * It's not possible to append blocks together and expect a single invocation of LZ4_decompress_*() to decompress them together. + * + * Note 2 : The previous 64KB of source data is __assumed__ to remain present, unmodified, at same address in memory ! + * + * Note 3 : When input is structured as a double-buffer, each buffer can have any size, including < 64 KB. + * Make sure that buffers are separated, by at least one byte. + * This construction ensures that each block only depends on previous block. + * + * Note 4 : If input buffer is a ring-buffer, it can have any size, including < 64 KB. + * + * Note 5 : After an error, the stream status is undefined (invalid), it can only be reset or freed. + */ +LZ4LIB_API int LZ4_compress_fast_continue (LZ4_stream_t* streamPtr, const char* src, char* dst, int srcSize, int dstCapacity, int acceleration); + +/*! LZ4_saveDict() : + * If last 64KB data cannot be guaranteed to remain available at its current memory location, + * save it into a safer place (char* safeBuffer). + * This is schematically equivalent to a memcpy() followed by LZ4_loadDict(), + * but is much faster, because LZ4_saveDict() doesn't need to rebuild tables. + * @return : saved dictionary size in bytes (necessarily <= maxDictSize), or 0 if error. + */ +LZ4LIB_API int LZ4_saveDict (LZ4_stream_t* streamPtr, char* safeBuffer, int maxDictSize); + + +/*-********************************************** +* Streaming Decompression Functions +* Bufferless synchronous API +************************************************/ +typedef union LZ4_streamDecode_u LZ4_streamDecode_t; /* tracking context */ + +/*! LZ4_createStreamDecode() and LZ4_freeStreamDecode() : + * creation / destruction of streaming decompression tracking context. + * A tracking context can be re-used multiple times. + */ +LZ4LIB_API LZ4_streamDecode_t* LZ4_createStreamDecode(void); +LZ4LIB_API int LZ4_freeStreamDecode (LZ4_streamDecode_t* LZ4_stream); + +/*! LZ4_setStreamDecode() : + * An LZ4_streamDecode_t context can be allocated once and re-used multiple times. + * Use this function to start decompression of a new stream of blocks. + * A dictionary can optionally be set. Use NULL or size 0 for a reset order. + * Dictionary is presumed stable : it must remain accessible and unmodified during next decompression. + * @return : 1 if OK, 0 if error + */ +LZ4LIB_API int LZ4_setStreamDecode (LZ4_streamDecode_t* LZ4_streamDecode, const char* dictionary, int dictSize); + +/*! LZ4_decoderRingBufferSize() : v1.8.2+ + * Note : in a ring buffer scenario (optional), + * blocks are presumed decompressed next to each other + * up to the moment there is not enough remaining space for next block (remainingSize < maxBlockSize), + * at which stage it resumes from beginning of ring buffer. + * When setting such a ring buffer for streaming decompression, + * provides the minimum size of this ring buffer + * to be compatible with any source respecting maxBlockSize condition. + * @return : minimum ring buffer size, + * or 0 if there is an error (invalid maxBlockSize). + */ +LZ4LIB_API int LZ4_decoderRingBufferSize(int maxBlockSize); +#define LZ4_DECODER_RING_BUFFER_SIZE(maxBlockSize) (65536 + 14 + (maxBlockSize)) /* for static allocation; maxBlockSize presumed valid */ + +/*! LZ4_decompress_*_continue() : + * These decoding functions allow decompression of consecutive blocks in "streaming" mode. + * A block is an unsplittable entity, it must be presented entirely to a decompression function. + * Decompression functions only accepts one block at a time. + * The last 64KB of previously decoded data *must* remain available and unmodified at the memory position where they were decoded. + * If less than 64KB of data has been decoded, all the data must be present. + * + * Special : if decompression side sets a ring buffer, it must respect one of the following conditions : + * - Decompression buffer size is _at least_ LZ4_decoderRingBufferSize(maxBlockSize). + * maxBlockSize is the maximum size of any single block. It can have any value > 16 bytes. + * In which case, encoding and decoding buffers do not need to be synchronized. + * Actually, data can be produced by any source compliant with LZ4 format specification, and respecting maxBlockSize. + * - Synchronized mode : + * Decompression buffer size is _exactly_ the same as compression buffer size, + * and follows exactly same update rule (block boundaries at same positions), + * and decoding function is provided with exact decompressed size of each block (exception for last block of the stream), + * _then_ decoding & encoding ring buffer can have any size, including small ones ( < 64 KB). + * - Decompression buffer is larger than encoding buffer, by a minimum of maxBlockSize more bytes. + * In which case, encoding and decoding buffers do not need to be synchronized, + * and encoding ring buffer can have any size, including small ones ( < 64 KB). + * + * Whenever these conditions are not possible, + * save the last 64KB of decoded data into a safe buffer where it can't be modified during decompression, + * then indicate where this data is saved using LZ4_setStreamDecode(), before decompressing next block. +*/ +LZ4LIB_API int LZ4_decompress_safe_continue (LZ4_streamDecode_t* LZ4_streamDecode, const char* src, char* dst, int srcSize, int dstCapacity); + + +/*! LZ4_decompress_*_usingDict() : + * These decoding functions work the same as + * a combination of LZ4_setStreamDecode() followed by LZ4_decompress_*_continue() + * They are stand-alone, and don't need an LZ4_streamDecode_t structure. + * Dictionary is presumed stable : it must remain accessible and unmodified during decompression. + * Performance tip : Decompression speed can be substantially increased + * when dst == dictStart + dictSize. + */ +LZ4LIB_API int LZ4_decompress_safe_usingDict (const char* src, char* dst, int srcSize, int dstCapcity, const char* dictStart, int dictSize); + +#endif /* LZ4_H_2983827168210 */ + + +/*^************************************* + * !!!!!! STATIC LINKING ONLY !!!!!! + ***************************************/ + +/*-**************************************************************************** + * Experimental section + * + * Symbols declared in this section must be considered unstable. Their + * signatures or semantics may change, or they may be removed altogether in the + * future. They are therefore only safe to depend on when the caller is + * statically linked against the library. + * + * To protect against unsafe usage, not only are the declarations guarded, + * the definitions are hidden by default + * when building LZ4 as a shared/dynamic library. + * + * In order to access these declarations, + * define LZ4_STATIC_LINKING_ONLY in your application + * before including LZ4's headers. + * + * In order to make their implementations accessible dynamically, you must + * define LZ4_PUBLISH_STATIC_FUNCTIONS when building the LZ4 library. + ******************************************************************************/ + +#ifdef LZ4_STATIC_LINKING_ONLY + +#ifndef LZ4_STATIC_3504398509 +#define LZ4_STATIC_3504398509 + +#ifdef LZ4_PUBLISH_STATIC_FUNCTIONS +#define LZ4LIB_STATIC_API LZ4LIB_API +#else +#define LZ4LIB_STATIC_API +#endif + + +/*! LZ4_compress_fast_extState_fastReset() : + * A variant of LZ4_compress_fast_extState(). + * + * Using this variant avoids an expensive initialization step. + * It is only safe to call if the state buffer is known to be correctly initialized already + * (see above comment on LZ4_resetStream_fast() for a definition of "correctly initialized"). + * From a high level, the difference is that + * this function initializes the provided state with a call to something like LZ4_resetStream_fast() + * while LZ4_compress_fast_extState() starts with a call to LZ4_resetStream(). + */ +LZ4LIB_STATIC_API int LZ4_compress_fast_extState_fastReset (void* state, const char* src, char* dst, int srcSize, int dstCapacity, int acceleration); + +/*! LZ4_attach_dictionary() : + * This is an experimental API that allows + * efficient use of a static dictionary many times. + * + * Rather than re-loading the dictionary buffer into a working context before + * each compression, or copying a pre-loaded dictionary's LZ4_stream_t into a + * working LZ4_stream_t, this function introduces a no-copy setup mechanism, + * in which the working stream references the dictionary stream in-place. + * + * Several assumptions are made about the state of the dictionary stream. + * Currently, only streams which have been prepared by LZ4_loadDict() should + * be expected to work. + * + * Alternatively, the provided dictionaryStream may be NULL, + * in which case any existing dictionary stream is unset. + * + * If a dictionary is provided, it replaces any pre-existing stream history. + * The dictionary contents are the only history that can be referenced and + * logically immediately precede the data compressed in the first subsequent + * compression call. + * + * The dictionary will only remain attached to the working stream through the + * first compression call, at the end of which it is cleared. The dictionary + * stream (and source buffer) must remain in-place / accessible / unchanged + * through the completion of the first compression call on the stream. + */ +LZ4LIB_STATIC_API void LZ4_attach_dictionary(LZ4_stream_t* workingStream, const LZ4_stream_t* dictionaryStream); + + +/*! In-place compression and decompression + * + * It's possible to have input and output sharing the same buffer, + * for highly contrained memory environments. + * In both cases, it requires input to lay at the end of the buffer, + * and decompression to start at beginning of the buffer. + * Buffer size must feature some margin, hence be larger than final size. + * + * |<------------------------buffer--------------------------------->| + * |<-----------compressed data--------->| + * |<-----------decompressed size------------------>| + * |<----margin---->| + * + * This technique is more useful for decompression, + * since decompressed size is typically larger, + * and margin is short. + * + * In-place decompression will work inside any buffer + * which size is >= LZ4_DECOMPRESS_INPLACE_BUFFER_SIZE(decompressedSize). + * This presumes that decompressedSize > compressedSize. + * Otherwise, it means compression actually expanded data, + * and it would be more efficient to store such data with a flag indicating it's not compressed. + * This can happen when data is not compressible (already compressed, or encrypted). + * + * For in-place compression, margin is larger, as it must be able to cope with both + * history preservation, requiring input data to remain unmodified up to LZ4_DISTANCE_MAX, + * and data expansion, which can happen when input is not compressible. + * As a consequence, buffer size requirements are much higher, + * and memory savings offered by in-place compression are more limited. + * + * There are ways to limit this cost for compression : + * - Reduce history size, by modifying LZ4_DISTANCE_MAX. + * Note that it is a compile-time constant, so all compressions will apply this limit. + * Lower values will reduce compression ratio, except when input_size < LZ4_DISTANCE_MAX, + * so it's a reasonable trick when inputs are known to be small. + * - Require the compressor to deliver a "maximum compressed size". + * This is the `dstCapacity` parameter in `LZ4_compress*()`. + * When this size is < LZ4_COMPRESSBOUND(inputSize), then compression can fail, + * in which case, the return code will be 0 (zero). + * The caller must be ready for these cases to happen, + * and typically design a backup scheme to send data uncompressed. + * The combination of both techniques can significantly reduce + * the amount of margin required for in-place compression. + * + * In-place compression can work in any buffer + * which size is >= (maxCompressedSize) + * with maxCompressedSize == LZ4_COMPRESSBOUND(srcSize) for guaranteed compression success. + * LZ4_COMPRESS_INPLACE_BUFFER_SIZE() depends on both maxCompressedSize and LZ4_DISTANCE_MAX, + * so it's possible to reduce memory requirements by playing with them. + */ + +#define LZ4_DECOMPRESS_INPLACE_MARGIN(compressedSize) (((compressedSize) >> 8) + 32) +#define LZ4_DECOMPRESS_INPLACE_BUFFER_SIZE(decompressedSize) ((decompressedSize) + LZ4_DECOMPRESS_INPLACE_MARGIN(decompressedSize)) /**< note: presumes that compressedSize < decompressedSize. note2: margin is overestimated a bit, since it could use compressedSize instead */ + +#ifndef LZ4_DISTANCE_MAX /* history window size; can be user-defined at compile time */ +# define LZ4_DISTANCE_MAX 65535 /* set to maximum value by default */ +#endif + +#define LZ4_COMPRESS_INPLACE_MARGIN (LZ4_DISTANCE_MAX + 32) /* LZ4_DISTANCE_MAX can be safely replaced by srcSize when it's smaller */ +#define LZ4_COMPRESS_INPLACE_BUFFER_SIZE(maxCompressedSize) ((maxCompressedSize) + LZ4_COMPRESS_INPLACE_MARGIN) /**< maxCompressedSize is generally LZ4_COMPRESSBOUND(inputSize), but can be set to any lower value, with the risk that compression can fail (return code 0(zero)) */ + +#endif /* LZ4_STATIC_3504398509 */ +#endif /* LZ4_STATIC_LINKING_ONLY */ + + + +#ifndef LZ4_H_98237428734687 +#define LZ4_H_98237428734687 + +/*-************************************************************ + * PRIVATE DEFINITIONS + ************************************************************** + * Do not use these definitions directly. + * They are only exposed to allow static allocation of `LZ4_stream_t` and `LZ4_streamDecode_t`. + * Accessing members will expose code to API and/or ABI break in future versions of the library. + **************************************************************/ +#define LZ4_HASHLOG (LZ4_MEMORY_USAGE-2) +#define LZ4_HASHTABLESIZE (1 << LZ4_MEMORY_USAGE) +#define LZ4_HASH_SIZE_U32 (1 << LZ4_HASHLOG) /* required as macro for static allocation */ + +#if defined(__cplusplus) || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */) +#include + +typedef struct LZ4_stream_t_internal LZ4_stream_t_internal; +struct LZ4_stream_t_internal { + uint32_t hashTable[LZ4_HASH_SIZE_U32]; + uint32_t currentOffset; + uint16_t dirty; + uint16_t tableType; + const uint8_t* dictionary; + const LZ4_stream_t_internal* dictCtx; + uint32_t dictSize; +}; + +typedef struct { + const uint8_t* externalDict; + size_t extDictSize; + const uint8_t* prefixEnd; + size_t prefixSize; +} LZ4_streamDecode_t_internal; + +#else + +typedef struct LZ4_stream_t_internal LZ4_stream_t_internal; +struct LZ4_stream_t_internal { + unsigned int hashTable[LZ4_HASH_SIZE_U32]; + unsigned int currentOffset; + unsigned short dirty; + unsigned short tableType; + const unsigned char* dictionary; + const LZ4_stream_t_internal* dictCtx; + unsigned int dictSize; +}; + +typedef struct { + const unsigned char* externalDict; + const unsigned char* prefixEnd; + size_t extDictSize; + size_t prefixSize; +} LZ4_streamDecode_t_internal; + +#endif + +/*! LZ4_stream_t : + * information structure to track an LZ4 stream. + * LZ4_stream_t can also be created using LZ4_createStream(), which is recommended. + * The structure definition can be convenient for static allocation + * (on stack, or as part of larger structure). + * Init this structure with LZ4_initStream() before first use. + * note : only use this definition in association with static linking ! + * this definition is not API/ABI safe, and may change in a future version. + */ +#define LZ4_STREAMSIZE_U64 ((1 << (LZ4_MEMORY_USAGE-3)) + 4 + ((sizeof(void*)==16) ? 4 : 0) /*AS-400*/ ) +#define LZ4_STREAMSIZE (LZ4_STREAMSIZE_U64 * sizeof(unsigned long long)) +union LZ4_stream_u { + unsigned long long table[LZ4_STREAMSIZE_U64]; + LZ4_stream_t_internal internal_donotuse; +} ; /* previously typedef'd to LZ4_stream_t */ + +/*! LZ4_initStream() : v1.9.0+ + * An LZ4_stream_t structure must be initialized at least once. + * This is automatically done when invoking LZ4_createStream(), + * but it's not when the structure is simply declared on stack (for example). + * + * Use LZ4_initStream() to properly initialize a newly declared LZ4_stream_t. + * It can also initialize any arbitrary buffer of sufficient size, + * and will @return a pointer of proper type upon initialization. + * + * Note : initialization fails if size and alignment conditions are not respected. + * In which case, the function will @return NULL. + * Note2: An LZ4_stream_t structure guarantees correct alignment and size. + * Note3: Before v1.9.0, use LZ4_resetStream() instead + */ +LZ4LIB_API LZ4_stream_t* LZ4_initStream (void* buffer, size_t size); + + +/*! LZ4_streamDecode_t : + * information structure to track an LZ4 stream during decompression. + * init this structure using LZ4_setStreamDecode() before first use. + * note : only use in association with static linking ! + * this definition is not API/ABI safe, + * and may change in a future version ! + */ +#define LZ4_STREAMDECODESIZE_U64 (4 + ((sizeof(void*)==16) ? 2 : 0) /*AS-400*/ ) +#define LZ4_STREAMDECODESIZE (LZ4_STREAMDECODESIZE_U64 * sizeof(unsigned long long)) +union LZ4_streamDecode_u { + unsigned long long table[LZ4_STREAMDECODESIZE_U64]; + LZ4_streamDecode_t_internal internal_donotuse; +} ; /* previously typedef'd to LZ4_streamDecode_t */ + + + +/*-************************************ +* Obsolete Functions +**************************************/ + +/*! Deprecation warnings + * + * Deprecated functions make the compiler generate a warning when invoked. + * This is meant to invite users to update their source code. + * Should deprecation warnings be a problem, it is generally possible to disable them, + * typically with -Wno-deprecated-declarations for gcc + * or _CRT_SECURE_NO_WARNINGS in Visual. + * + * Another method is to define LZ4_DISABLE_DEPRECATE_WARNINGS + * before including the header file. + */ +#ifdef LZ4_DISABLE_DEPRECATE_WARNINGS +# define LZ4_DEPRECATED(message) /* disable deprecation warnings */ +#else +# define LZ4_GCC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__) +# if defined (__cplusplus) && (__cplusplus >= 201402) /* C++14 or greater */ +# define LZ4_DEPRECATED(message) [[deprecated(message)]] +# elif (LZ4_GCC_VERSION >= 405) || defined(__clang__) +# define LZ4_DEPRECATED(message) __attribute__((deprecated(message))) +# elif (LZ4_GCC_VERSION >= 301) +# define LZ4_DEPRECATED(message) __attribute__((deprecated)) +# elif defined(_MSC_VER) +# define LZ4_DEPRECATED(message) __declspec(deprecated(message)) +# else +# pragma message("WARNING: You need to implement LZ4_DEPRECATED for this compiler") +# define LZ4_DEPRECATED(message) +# endif +#endif /* LZ4_DISABLE_DEPRECATE_WARNINGS */ + +/* Obsolete compression functions */ +LZ4_DEPRECATED("use LZ4_compress_default() instead") LZ4LIB_API int LZ4_compress (const char* src, char* dest, int srcSize); +LZ4_DEPRECATED("use LZ4_compress_default() instead") LZ4LIB_API int LZ4_compress_limitedOutput (const char* src, char* dest, int srcSize, int maxOutputSize); +LZ4_DEPRECATED("use LZ4_compress_fast_extState() instead") LZ4LIB_API int LZ4_compress_withState (void* state, const char* source, char* dest, int inputSize); +LZ4_DEPRECATED("use LZ4_compress_fast_extState() instead") LZ4LIB_API int LZ4_compress_limitedOutput_withState (void* state, const char* source, char* dest, int inputSize, int maxOutputSize); +LZ4_DEPRECATED("use LZ4_compress_fast_continue() instead") LZ4LIB_API int LZ4_compress_continue (LZ4_stream_t* LZ4_streamPtr, const char* source, char* dest, int inputSize); +LZ4_DEPRECATED("use LZ4_compress_fast_continue() instead") LZ4LIB_API int LZ4_compress_limitedOutput_continue (LZ4_stream_t* LZ4_streamPtr, const char* source, char* dest, int inputSize, int maxOutputSize); + +/* Obsolete decompression functions */ +LZ4_DEPRECATED("use LZ4_decompress_fast() instead") LZ4LIB_API int LZ4_uncompress (const char* source, char* dest, int outputSize); +LZ4_DEPRECATED("use LZ4_decompress_safe() instead") LZ4LIB_API int LZ4_uncompress_unknownOutputSize (const char* source, char* dest, int isize, int maxOutputSize); + +/* Obsolete streaming functions; degraded functionality; do not use! + * + * In order to perform streaming compression, these functions depended on data + * that is no longer tracked in the state. They have been preserved as well as + * possible: using them will still produce a correct output. However, they don't + * actually retain any history between compression calls. The compression ratio + * achieved will therefore be no better than compressing each chunk + * independently. + */ +LZ4_DEPRECATED("Use LZ4_createStream() instead") LZ4LIB_API void* LZ4_create (char* inputBuffer); +LZ4_DEPRECATED("Use LZ4_createStream() instead") LZ4LIB_API int LZ4_sizeofStreamState(void); +LZ4_DEPRECATED("Use LZ4_resetStream() instead") LZ4LIB_API int LZ4_resetStreamState(void* state, char* inputBuffer); +LZ4_DEPRECATED("Use LZ4_saveDict() instead") LZ4LIB_API char* LZ4_slideInputBuffer (void* state); + +/* Obsolete streaming decoding functions */ +LZ4_DEPRECATED("use LZ4_decompress_safe_usingDict() instead") LZ4LIB_API int LZ4_decompress_safe_withPrefix64k (const char* src, char* dst, int compressedSize, int maxDstSize); +LZ4_DEPRECATED("use LZ4_decompress_fast_usingDict() instead") LZ4LIB_API int LZ4_decompress_fast_withPrefix64k (const char* src, char* dst, int originalSize); + +/*! LZ4_decompress_fast() : **unsafe!** + * These functions used to be faster than LZ4_decompress_safe(), + * but it has changed, and they are now slower than LZ4_decompress_safe(). + * This is because LZ4_decompress_fast() doesn't know the input size, + * and therefore must progress more cautiously in the input buffer to not read beyond the end of block. + * On top of that `LZ4_decompress_fast()` is not protected vs malformed or malicious inputs, making it a security liability. + * As a consequence, LZ4_decompress_fast() is strongly discouraged, and deprecated. + * + * The last remaining LZ4_decompress_fast() specificity is that + * it can decompress a block without knowing its compressed size. + * Such functionality could be achieved in a more secure manner, + * by also providing the maximum size of input buffer, + * but it would require new prototypes, and adaptation of the implementation to this new use case. + * + * Parameters: + * originalSize : is the uncompressed size to regenerate. + * `dst` must be already allocated, its size must be >= 'originalSize' bytes. + * @return : number of bytes read from source buffer (== compressed size). + * The function expects to finish at block's end exactly. + * If the source stream is detected malformed, the function stops decoding and returns a negative result. + * note : LZ4_decompress_fast*() requires originalSize. Thanks to this information, it never writes past the output buffer. + * However, since it doesn't know its 'src' size, it may read an unknown amount of input, past input buffer bounds. + * Also, since match offsets are not validated, match reads from 'src' may underflow too. + * These issues never happen if input (compressed) data is correct. + * But they may happen if input data is invalid (error or intentional tampering). + * As a consequence, use these functions in trusted environments with trusted data **only**. + */ + +LZ4_DEPRECATED("This function is deprecated and unsafe. Consider using LZ4_decompress_safe() instead") +LZ4LIB_API int LZ4_decompress_fast (const char* src, char* dst, int originalSize); +LZ4_DEPRECATED("This function is deprecated and unsafe. Consider using LZ4_decompress_safe_continue() instead") +LZ4LIB_API int LZ4_decompress_fast_continue (LZ4_streamDecode_t* LZ4_streamDecode, const char* src, char* dst, int originalSize); +LZ4_DEPRECATED("This function is deprecated and unsafe. Consider using LZ4_decompress_safe_usingDict() instead") +LZ4LIB_API int LZ4_decompress_fast_usingDict (const char* src, char* dst, int originalSize, const char* dictStart, int dictSize); + +/*! LZ4_resetStream() : + * An LZ4_stream_t structure must be initialized at least once. + * This is done with LZ4_initStream(), or LZ4_resetStream(). + * Consider switching to LZ4_initStream(), + * invoking LZ4_resetStream() will trigger deprecation warnings in the future. + */ +LZ4LIB_API void LZ4_resetStream (LZ4_stream_t* streamPtr); + + +#endif /* LZ4_H_98237428734687 */ + + +#if defined (__cplusplus) +} +#endif \ No newline at end of file diff --git a/include/dataset_creation/dataset_manipulation/bytes_to_file.h b/include/dataset_creation/dataset_manipulation/bytes_to_file.h new file mode 100644 index 0000000000000000000000000000000000000000..44b934c2fc2d55825f66593f73a811c481ce43d4 --- /dev/null +++ b/include/dataset_creation/dataset_manipulation/bytes_to_file.h @@ -0,0 +1,27 @@ +#pragma once + +#include +#include +#include "../../../libraries/protobuf/build/midi.pb.h" +#include "../compression/lz4.h" + + +namespace dataset_manipulation { + class BytesToFile{ + private: + std::string filepath_; + std::string header_filepath_; + std::fstream file_stream_; + std::fstream header_file_stream_; + midi::Dataset dataset_split_protobuf_; + int flush_count_; + bool can_write; + + public: + BytesToFile(std::string external_filepath_); + void enableWrite(); + void appendBytesToFileStream(std::string& bytes_as_string, size_t split_id); + void writeFile(); + void close(); + }; +} \ No newline at end of file diff --git a/libraries/midifile b/libraries/midifile new file mode 160000 index 0000000000000000000000000000000000000000..838c62c4a13245ced8e13a84e6c2a1994664acd5 --- /dev/null +++ b/libraries/midifile @@ -0,0 +1 @@ +Subproject commit 838c62c4a13245ced8e13a84e6c2a1994664acd5 diff --git a/libraries/protobuf/CMakeLists.txt b/libraries/protobuf/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..3f3ce8685a274028f05df768ca4a4bd1849ff8e0 --- /dev/null +++ b/libraries/protobuf/CMakeLists.txt @@ -0,0 +1,28 @@ +cmake_minimum_required(VERSION 3.8) + +project(midigpt_proto) + +set(PROTO_DEF + src/enum.proto + src/midi.proto + src/midi_internal.proto + src/track_type.proto + src/feature_extraction.proto) + +find_package(Protobuf REQUIRED) + +protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS + ${PROTO_DEF} + PROTOC_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR} # it's the default but it does not hurt to be explicit here... +) + +add_library(midigpt_proto + ${PROTO_SRCS} + ${PROTO_HDRS}) + +target_include_directories(midigpt_proto +PUBLIC + ${Protobuf_INCLUDE_DIRS} + ${CMAKE_CURRENT_BINARY_DIR} # for generated protobuf files +) +target_link_libraries(midigpt_proto ${Protobuf_LIBRARIES}) \ No newline at end of file diff --git a/libraries/protobuf/CMakeSettings.json b/libraries/protobuf/CMakeSettings.json new file mode 100644 index 0000000000000000000000000000000000000000..2a7e537e09316df66b12f427deeae89a75068dcd --- /dev/null +++ b/libraries/protobuf/CMakeSettings.json @@ -0,0 +1,29 @@ +{ + "configurations": [ + { + "name": "x64-Debug", + "generator": "Ninja", + "configurationType": "Debug", + "inheritEnvironments": [ "msvc_x64_x64" ], + "buildRoot": "${projectDir}\\out\\build\\${name}", + "installRoot": "${projectDir}\\out\\install\\${name}", + "cmakeCommandArgs": "", + "buildCommandArgs": "", + "ctestCommandArgs": "" + }, + { + "name": "WSL-GCC-Debug", + "generator": "Ninja", + "configurationType": "Debug", + "buildRoot": "${projectDir}\\out\\build\\${name}", + "installRoot": "${projectDir}\\out\\install\\${name}", + "cmakeExecutable": "cmake", + "cmakeCommandArgs": "", + "buildCommandArgs": "", + "ctestCommandArgs": "", + "inheritEnvironments": [ "linux_x64" ], + "wslPath": "${defaultWSLPath}", + "variables": [] + } + ] +} \ No newline at end of file diff --git a/libraries/protobuf/include/proto_library.h b/libraries/protobuf/include/proto_library.h new file mode 100644 index 0000000000000000000000000000000000000000..6ccbbd194f9e5c310fa7a3012c7e657236d2a2c0 --- /dev/null +++ b/libraries/protobuf/include/proto_library.h @@ -0,0 +1,5 @@ +#pragma once + +namespace midigptProto { + void testMmmProto(); +} \ No newline at end of file diff --git a/libraries/protobuf/src/enum.proto b/libraries/protobuf/src/enum.proto new file mode 100644 index 0000000000000000000000000000000000000000..9100a86e51161e82a145c33d3c5182ba6fe3fb56 --- /dev/null +++ b/libraries/protobuf/src/enum.proto @@ -0,0 +1,389 @@ +syntax = "proto2"; + +package midi; + +enum TOKEN_TYPE { + TOKEN_PIECE_START = 0; + TOKEN_NOTE_ONSET = 1; + TOKEN_PITCH = 3; + TOKEN_VELOCITY = 5; + TOKEN_TIME_ABSOLUTE_POS = 7; + TOKEN_INSTRUMENT = 8; + TOKEN_BAR = 9; + TOKEN_BAR_END = 10; + TOKEN_TRACK = 11; + TOKEN_TRACK_END = 12; + TOKEN_DRUM_TRACK = 13; + TOKEN_FILL_IN = 14; + TOKEN_FILL_IN_PLACEHOLDER = 15; + TOKEN_FILL_IN_START = 16; + TOKEN_FILL_IN_END = 17; + TOKEN_VELOCITY_LEVEL = 19; + TOKEN_GENRE = 20; + TOKEN_DENSITY_LEVEL = 21; + TOKEN_TIME_SIGNATURE = 22; + TOKEN_NOTE_DURATION = 26; + TOKEN_AV_POLYPHONY = 27; + TOKEN_MIN_POLYPHONY = 28; + TOKEN_MAX_POLYPHONY = 29; + TOKEN_MIN_NOTE_DURATION = 30; + TOKEN_MAX_NOTE_DURATION = 31; + TOKEN_NUM_BARS = 32; + TOKEN_MIN_POLYPHONY_HARD = 33; + TOKEN_MAX_POLYPHONY_HARD = 34; + TOKEN_MIN_NOTE_DURATION_HARD = 35; + TOKEN_MAX_NOTE_DURATION_HARD = 36; + TOKEN_BAR_LEVEL_ONSET_DENSITY = 40; + TOKEN_BAR_LEVEL_ONSET_POLYPHONY_MIN = 41; + TOKEN_BAR_LEVEL_ONSET_POLYPHONY_MAX = 42; + TOKEN_TRACK_LEVEL_ONSET_DENSITY = 43; + TOKEN_TRACK_LEVEL_ONSET_POLYPHONY_MIN = 44; + TOKEN_TRACK_LEVEL_ONSET_POLYPHONY_MAX = 45; + TOKEN_TRACK_LEVEL_ONSET_DENSITY_MIN = 46; + TOKEN_TRACK_LEVEL_ONSET_DENSITY_MAX = 47; + TOKEN_TRACK_LEVEL_PITCH_RANGE_MIN = 48; + TOKEN_TRACK_LEVEL_PITCH_RANGE_MAX = 49; + TOKEN_CONTAINS_NOTE_DURATION_THIRTY_SECOND = 59; + TOKEN_CONTAINS_NOTE_DURATION_SIXTEENTH = 60; + TOKEN_CONTAINS_NOTE_DURATION_EIGHTH = 61; + TOKEN_CONTAINS_NOTE_DURATION_QUARTER = 62; + TOKEN_CONTAINS_NOTE_DURATION_HALF = 63; + TOKEN_CONTAINS_NOTE_DURATION_WHOLE = 64; + TOKEN_DELTA = 67; + TOKEN_DELTA_DIRECTION = 68; + TOKEN_NONE = 72; +} + +enum ATTRIBUTE_CONTROL_TYPE { + ATTRIBUTE_CONTROL_NOTE_DENSITY = 0; + ATTRIBUTE_CONTROL_TRACK_LEVEL_ONSET_POLYPHONY = 1; + ATTRIBUTE_CONTROL_TRACK_LEVEL_ONSET_DENSITY = 2; + ATTRIBUTE_CONTROL_PITCH_RANGE = 3; + ATTRIBUTE_CONTROL_GENRE = 4; + ATTRIBUTE_CONTROL_POLYPHONY_QUANTILE = 5; + ATTRIBUTE_CONTROL_NOTE_DURATION_QUANTILE = 6; + ATTRIBUTE_CONTROL_BAR_LEVEL_ONSET_DENSITY = 7; + ATTRIBUTE_CONTROL_BAR_LEVEL_ONSET_POLYPHONY = 8; + ATTRIBUTE_CONTROL_TRACK_LEVEL_NOTE_DURATION = 9; + ATTRIBUTE_CONTROL_END = 10; +} + +enum GenreMusicmap { + GENRE_MUSICMAP_ANY = 0; + GENRE_MUSICMAP_ALTERNATIVE_ROCK = 1; + GENRE_MUSICMAP_AMBIENT = 2; + GENRE_MUSICMAP_BLUES = 3; + GENRE_MUSICMAP_BREAKBEAT = 4; + GENRE_MUSICMAP_CLASSICAL = 5; + GENRE_MUSICMAP_CLASSIC_ROCK = 6; + GENRE_MUSICMAP_CONTEMPORARY_ROCK = 7; + GENRE_MUSICMAP_COUNTRY = 8; + GENRE_MUSICMAP_DRUM_N_BASS = 9; + GENRE_MUSICMAP_FOLK = 10; + GENRE_MUSICMAP_GOSPEL = 11; + GENRE_MUSICMAP_HARDCORE_PUNK = 12; + GENRE_MUSICMAP_HARDCORE_TECHNO = 13; + GENRE_MUSICMAP_HEAVY_METAL = 14; + GENRE_MUSICMAP_HIP_HOP = 15; + GENRE_MUSICMAP_HOUSE = 16; + GENRE_MUSICMAP_INDUSTRIAL = 17; + GENRE_MUSICMAP_JAZZ = 18; + GENRE_MUSICMAP_LATIN = 19; + GENRE_MUSICMAP_POP = 20; + GENRE_MUSICMAP_PUNK = 21; + GENRE_MUSICMAP_PUNK_ROCK = 22; + GENRE_MUSICMAP_RANDB = 23; + GENRE_MUSICMAP_REGGAE = 24; + GENRE_MUSICMAP_ROCK_N_ROLL = 25; + GENRE_MUSICMAP_TECHNO = 26; + GENRE_MUSICMAP_TRANCE = 27; + GENRE_MUSICMAP_UTILITY = 28; + GENRE_MUSICMAP_WORLD = 29; + GENRE_MUSICMAP_NONE = 30; +}; + +enum GM_CATEGORY { + GM_CATEGORY_MONO = 0; + GM_CATEGORY_POLY = 1; + GM_CATEGORY_SOUND_FX = 2; + GM_CATEGORY_PERC = 3; +}; + +enum GM_TYPE { + any = 0; + piano = 1; + chromatic_perc = 2; + organ = 3; + guitar = 4; + bass = 5; + strings = 6; + ensemble = 7; + brass = 8; + reed = 9; + pipe = 10; + synth_lead = 11; + synth_pad = 12; + synth_effects = 13; + ethnic = 14; + percussive = 15; + sound_fx = 16; + no_drums = 17; + drums = 18; + acoustic_grand_piano = 19; + bright_acoustic_piano = 20; + electric_grand_piano = 21; + honky_tonk_piano = 22; + electric_piano_1 = 23; + electric_piano_2 = 24; + harpsichord = 25; + clavi = 26; + celesta = 27; + glockenspiel = 28; + music_box = 29; + vibraphone = 30; + marimba = 31; + xylophone = 32; + tubular_bells = 33; + dulcimer = 34; + drawbar_organ = 35; + percussive_organ = 36; + rock_organ = 37; + church_organ = 38; + reed_organ = 39; + accordion = 40; + harmonica = 41; + tango_accordion = 42; + acoustic_guitar_nylon = 43; + acoustic_guitar_steel = 44; + electric_guitar_jazz = 45; + electric_guitar_clean = 46; + electric_guitar_muted = 47; + overdriven_guitar = 48; + distortion_guitar = 49; + guitar_harmonics = 50; + acoustic_bass = 51; + electric_bass_finger = 52; + electric_bass_pick = 53; + fretless_bass = 54; + slap_bass_1 = 55; + slap_bass_2 = 56; + synth_bass_1 = 57; + synth_bass_2 = 58; + violin = 59; + viola = 60; + cello = 61; + contrabass = 62; + tremolo_strings = 63; + pizzicato_strings = 64; + orchestral_harp = 65; + timpani = 66; + string_ensemble_1 = 67; + string_ensemble_2 = 68; + synth_strings_1 = 69; + synth_strings_2 = 70; + choir_aahs = 71; + voice_oohs = 72; + synth_voice = 73; + orchestra_hit = 74; + trumpet = 75; + trombone = 76; + tuba = 77; + muted_trumpet = 78; + french_horn = 79; + brass_section = 80; + synth_brass_1 = 81; + synth_brass_2 = 82; + soprano_sax = 83; + alto_sax = 84; + tenor_sax = 85; + baritone_sax = 86; + oboe = 87; + english_horn = 88; + bassoon = 89; + clarinet = 90; + piccolo = 91; + flute = 92; + recorder = 93; + pan_flute = 94; + blown_bottle = 95; + shakuhachi = 96; + whistle = 97; + ocarina = 98; + lead_1_square = 99; + lead_2_sawtooth = 100; + lead_3_calliope = 101; + lead_4_chiff = 102; + lead_5_charang = 103; + lead_6_voice = 104; + lead_7_fifths = 105; + lead_8_bass__lead = 106; + pad_1_new_age = 107; + pad_2_warm = 108; + pad_3_polysynth = 109; + pad_4_choir = 110; + pad_5_bowed = 111; + pad_6_metallic = 112; + pad_7_halo = 113; + pad_8_sweep = 114; + fx_1_rain = 115; + fx_2_soundtrack = 116; + fx_3_crystal = 117; + fx_4_atmosphere = 118; + fx_5_brightness = 119; + fx_6_goblins = 120; + fx_7_echoes = 121; + fx_8_sci_fi = 122; + sitar = 123; + banjo = 124; + shamisen = 125; + koto = 126; + kalimba = 127; + bag_pipe = 128; + fiddle = 129; + shanai = 130; + tinkle_bell = 131; + agogo = 132; + steel_drums = 133; + woodblock = 134; + taiko_drum = 135; + melodic_tom = 136; + synth_drum = 137; + reverse_cymbal = 138; + guitar_fret_noise = 139; + breath_noise = 140; + seashore = 141; + bird_tweet = 142; + telephone_ring = 143; + helicopter = 144; + applause = 145; + gunshot = 146; + drum_0 = 147; + drum_1 = 148; + drum_2 = 149; + drum_3 = 150; + drum_4 = 151; + drum_5 = 152; + drum_6 = 153; + drum_7 = 154; + drum_8 = 155; + drum_9 = 156; + drum_10 = 157; + drum_11 = 158; + drum_12 = 159; + drum_13 = 160; + drum_14 = 161; + drum_15 = 162; + drum_16 = 163; + drum_17 = 164; + drum_18 = 165; + drum_19 = 166; + drum_20 = 167; + drum_21 = 168; + drum_22 = 169; + drum_23 = 170; + drum_24 = 171; + drum_25 = 172; + drum_26 = 173; + drum_27 = 174; + drum_28 = 175; + drum_29 = 176; + drum_30 = 177; + drum_31 = 178; + drum_32 = 179; + drum_33 = 180; + drum_34 = 181; + drum_35 = 182; + drum_36 = 183; + drum_37 = 184; + drum_38 = 185; + drum_39 = 186; + drum_40 = 187; + drum_41 = 188; + drum_42 = 189; + drum_43 = 190; + drum_44 = 191; + drum_45 = 192; + drum_46 = 193; + drum_47 = 194; + drum_48 = 195; + drum_49 = 196; + drum_50 = 197; + drum_51 = 198; + drum_52 = 199; + drum_53 = 200; + drum_54 = 201; + drum_55 = 202; + drum_56 = 203; + drum_57 = 204; + drum_58 = 205; + drum_59 = 206; + drum_60 = 207; + drum_61 = 208; + drum_62 = 209; + drum_63 = 210; + drum_64 = 211; + drum_65 = 212; + drum_66 = 213; + drum_67 = 214; + drum_68 = 215; + drum_69 = 216; + drum_70 = 217; + drum_71 = 218; + drum_72 = 219; + drum_73 = 220; + drum_74 = 221; + drum_75 = 222; + drum_76 = 223; + drum_77 = 224; + drum_78 = 225; + drum_79 = 226; + drum_80 = 227; + drum_81 = 228; + drum_82 = 229; + drum_83 = 230; + drum_84 = 231; + drum_85 = 232; + drum_86 = 233; + drum_87 = 234; + drum_88 = 235; + drum_89 = 236; + drum_90 = 237; + drum_91 = 238; + drum_92 = 239; + drum_93 = 240; + drum_94 = 241; + drum_95 = 242; + drum_96 = 243; + drum_97 = 244; + drum_98 = 245; + drum_99 = 246; + drum_100 = 247; + drum_101 = 248; + drum_102 = 249; + drum_103 = 250; + drum_104 = 251; + drum_105 = 252; + drum_106 = 253; + drum_107 = 254; + drum_108 = 255; + drum_109 = 256; + drum_110 = 257; + drum_111 = 258; + drum_112 = 259; + drum_113 = 260; + drum_114 = 261; + drum_115 = 262; + drum_116 = 263; + drum_117 = 264; + drum_118 = 265; + drum_119 = 266; + drum_120 = 267; + drum_121 = 268; + drum_122 = 269; + drum_123 = 270; + drum_124 = 271; + drum_125 = 272; + drum_126 = 273; + drum_127 = 274; +}; + diff --git a/libraries/protobuf/src/feature_extraction.proto b/libraries/protobuf/src/feature_extraction.proto new file mode 100644 index 0000000000000000000000000000000000000000..5d5859c05fa6df929b25744d6830a333d12cb82c --- /dev/null +++ b/libraries/protobuf/src/feature_extraction.proto @@ -0,0 +1,92 @@ +syntax = "proto2"; + +package midi; + +message PitchRange { + optional int32 instrument = 1; + optional int32 min = 2; + optional int32 max = 3; +} + +message MetricDepth { + optional string filepath = 1; + optional int32 track_num = 2; + optional int32 instrument = 3; + optional bool is_drum = 4; + optional bool has_time_signatures = 5; + repeated int32 metric_depth = 6; + optional int32 tpq = 7; +} + +message MedianMetricDepth { + optional string filepath = 1; + optional int32 track_num = 2; + optional int32 instrument = 3; + optional bool is_drum = 4; + optional bool has_time_signatures = 5; + optional int32 median_metric_depth = 6; + optional int32 tpq = 7; +} + +message MostFrequentMetricDepth { + optional string filepath = 1; + optional int32 track_num = 2; + optional int32 instrument = 3; + optional bool is_drum = 4; + optional bool has_time_signatures = 5; + optional int32 most_frequent_metric_depth = 6; + optional int32 tpq = 7; +} + +message DownbeatProportion { + optional string filepath = 1; + optional int32 track_num = 2; + optional int32 instrument = 3; + optional bool is_drum = 4; + optional float downbeat_proportion = 5; +} + +message AlignedMetricDepth { + optional string filepath = 1; + optional int32 track_num = 2; + optional int32 instrument = 3; + optional bool is_drum = 4; + + optional int32 aligned_offset = 5; +} + +message SimultaneousOnset { + optional string filepath = 1; + optional int32 track_num = 2; + optional int32 instrument = 3; + optional bool is_drum = 4; + + optional int32 simultaneous_onset_count = 5; +} + +message BeatStability { + optional string filepath = 1; + //optional int32 track_num = 2; + //optional int32 instrument = 3; + //optional bool is_drum = 4; + + optional float beat_stability_stdev = 5; + optional float beat_stability_median = 6; +} + +message DrumPresence { + optional string filepath = 1; + optional float drum_presence = 2; +} + +message Features { + repeated PitchRange pitch_range = 1; + repeated MetricDepth metric_depth = 2; + repeated DownbeatProportion downbeat_proportion = 3; + repeated AlignedMetricDepth aligned_metric_depth = 4; + repeated SimultaneousOnset simultaneous_onset = 5; + repeated BeatStability beat_stability = 6; + repeated DrumPresence drum_presence = 7; + repeated MedianMetricDepth median_metric_depth = 8; + repeated MostFrequentMetricDepth most_frequent_metric_depth = 9; +} \ No newline at end of file diff --git a/libraries/protobuf/src/midi.proto b/libraries/protobuf/src/midi.proto new file mode 100644 index 0000000000000000000000000000000000000000..a80eedb68b8b7a76fef9d5a1e69eb0336dc18293 --- /dev/null +++ b/libraries/protobuf/src/midi.proto @@ -0,0 +1,429 @@ +/* + +# Introduction + +To generate material three protobuf messages must be supplied to the midigpt::sample function. The specification for each of these messages (Piece, Status and SampleParam) is outlined in this document. Any field which starts with internal_ should be ignored, as these are for internal use only. Or if you really know what you are doing ;). There are examples of midi::Piece, midi::Status and midi::SampleParam objects in the docs folder. For working examples, please consult the testing suite, which is found in MMM_API/src/midigpt_api/test/unit.cpp. + +# Functionality Overview + +| Functionality | Scope | Description | +| ----------- | ----------- | ----------- | +| Velocity | Always Enabled | 32 levels of loudness for individual notes. | +| Instrument | Track | The General MIDI instrument (i.e. Timbre). | +| Max Polyphony Hard-Limit | Track | A hard-limit on the number of simultaneously sounding notes. | +| Note Duration (upper/lower soft bounds) | Non-Drum Tracks | Tells the model what the 15th (lower) and 85th (upper) quantiles of the note duration (i.e. Quarter, Whole) distrbution should be. | +| Polyphony (upper/lower soft bounds) | Non-Drum Tracks | Tells the model what the 15th (lower) and 85th (upper) quantiles of the polyphony distrbution should be. | +| Density (10 levels) | Drum Tracks | Tells the model the number of notes per bar to produce | +| Auto-regressive Sampling Mode | Track | When enabled, bars are always sampled in chronological order. | +| Time Signature | Bar | A unique time-signature can be specified for each bar. | +| Temperature | per API call | A higher value increases entropy of generated output. Temperature=1 applies no modification to the probabilities produced by the model. | +| Context size (model_dim) | per API call | The number of bars that the model can process in one API call | + + +# Parameter Constraints and Considerations + +There are two sampling methods: autoregressive generation, where we progressively sample musical material forwards in time on each track; and conditional generation (bar-infilling), where generated material is conditioned on past and future material. + +Note that a single call the the model midigpt:sample() may involve both autoregressive and conditional generation, as these can be specified on a per-track basis. These constraints are + +## Sample Param Constraints + +1. tracks_per_step : + - must be on range [1,number of tracks in piece] + +2. bars_per_step : + - must be on the range [1,model_dim] + - for conditional generation it is ill-advised for the user to have bars_per_step == model_dim, as this means generation will not be conditioned on any bars + +3. shuffle : + - this only applies in cases where one or more tracks are conditionally generated (i.e. resample = False && 1+ selected_bars = True) + +4. percentage : + - this only applies in cases where one or more tracks are conditionally generated (i.e. resample = False && 1+ selected_bars = True) + +## Status Constraints + +1. density : + - this control only applies to drum tracks. This works will both infilling and autoregressive mode. + +2. note duration / polyphony : + - this control only applies to non-drum tracks. This works will both infilling and autoregressive mode. + +3. autoregressive : + - you can only enable autoregressive mode (resample = True) when all the bars are selected in a track. + - note you may have autoregressive disabled when all bars are selected in a track + +4. ignore : + - bars which have 1+ selected_bars = True may not be ignored, as they are needed to condition the generation + + +# Protobuf Specification + +*/ +syntax = "proto2"; + +import "track_type.proto"; +import "enum.proto"; +import "midi_internal.proto"; +import "feature_extraction.proto"; + +package midi; + +/* +Specify the minimum or maximum amount of polyphony using these values. Using POLYPHONY_ANY lets the model choose the level of polyphony. +*/ +enum PolyphonyLevel { + POLYPHONY_ANY = 0; + POLYPHONY_ONE = 1; + POLYPHONY_TWO = 2; + POLYPHONY_THREE = 3; + POLYPHONY_FOUR = 4; + POLYPHONY_FIVE = 5; + POLYPHONY_SIX = 6; +} + +/* +Specify the minimum or maximum bounds for note-duration using these values. Using DURATION_ANY lets the model choose the bounds for note duration. +*/ +enum NoteDurationLevel { + DURATION_ANY = 0; + DURATION_THIRTY_SECOND = 1; + DURATION_SIXTEENTH = 2; + DURATION_EIGHTH = 3; + DURATION_QUARTER = 4; + DURATION_HALF = 5; + DURATION_WHOLE = 6; +} + +/* +Specify the minimum or maximum amount of note density using these values. Using DENSITY ANY lets the model choose the level of density. +*/ +enum DensityLevel { + DENSITY_ANY = 0; + DENSITY_ONE = 1; + DENSITY_TWO = 2; + DENSITY_THREE = 3; + DENSITY_FOUR = 4; + DENSITY_FIVE = 5; + DENSITY_SIX = 6; + DENSITY_SEVEN = 7; + DENSITY_EIGHT = 8; + DENSITY_NINE = 9; + DENSITY_TEN = 10; +} + +// specify the levels for bar level onset density +enum BarLevelOnsetDensityLevel { + BAR_LEVEL_ONSET_DENSITY_ANY = 0; + BAR_LEVEL_ONSET_DENSITY_ZERO = 1; + BAR_LEVEL_ONSET_DENSITY_ONE = 2; + BAR_LEVEL_ONSET_DENSITY_TWO = 3; + BAR_LEVEL_ONSET_DENSITY_THREE = 4; + BAR_LEVEL_ONSET_DENSITY_FOUR = 5; + BAR_LEVEL_ONSET_DENSITY_FIVE = 6; + BAR_LEVEL_ONSET_DENSITY_SIX = 7; + BAR_LEVEL_ONSET_DENSITY_SEVEN = 8; + BAR_LEVEL_ONSET_DENSITY_EIGHT = 9; + BAR_LEVEL_ONSET_DENSITY_NINE = 10; + BAR_LEVEL_ONSET_DENSITY_TEN = 11; + BAR_LEVEL_ONSET_DENSITY_ELEVEN = 12; + BAR_LEVEL_ONSET_DENSITY_TWELVE = 13; + BAR_LEVEL_ONSET_DENSITY_THIRTEEN = 14; + BAR_LEVEL_ONSET_DENSITY_FOURTEEN = 15; + BAR_LEVEL_ONSET_DENSITY_FIFTEEN = 16; + BAR_LEVEL_ONSET_DENSITY_SIXTEEN = 17; +} + +// specify the levels for bar level onset polyphony +enum BarLevelOnsetPolyphonyLevel { + BAR_LEVEL_ONSET_POLYPHONY_ANY = 0; + BAR_LEVEL_ONSET_POLYPHONY_ONE = 1; + BAR_LEVEL_ONSET_POLYPHONY_TWO = 2; + BAR_LEVEL_ONSET_POLYPHONY_THREE = 3; + BAR_LEVEL_ONSET_POLYPHONY_FOUR = 4; + BAR_LEVEL_ONSET_POLYPHONY_FIVE = 5; + BAR_LEVEL_ONSET_POLYPHONY_SIX = 6; +} + +enum SilenceProportionLevel { + SILENCE_PROPORTION_LEVEL_ANY = 0; + SILENCE_PROPORTION_LEVEL_ONE = 1; + SILENCE_PROPORTION_LEVEL_TWO = 2; + SILENCE_PROPORTION_LEVEL_THREE = 3; + SILENCE_PROPORTION_LEVEL_FOUR = 4; + SILENCE_PROPORTION_LEVEL_FIVE = 5; + SILENCE_PROPORTION_LEVEL_SIX = 6; + SILENCE_PROPORTION_LEVEL_SEVEN = 7; + SILENCE_PROPORTION_LEVEL_EIGHT = 8; + SILENCE_PROPORTION_LEVEL_NINE = 9; + SILENCE_PROPORTION_LEVEL_TEN = 10; +} + +enum BooleanLevel { + BOOLEAN_ANY = 0; + BOOLEAN_FALSE = 1; + BOOLEAN_TRUE = 2; +} + +/* +The Event Message is used to represent a MIDI note onset or offset. +*/ +message Event { + /* + The time of the event (either a note onset or note offset) relative to the current bar in quantized steps. Currently, most model quantize each quarter note beat into 12 subdivisions. As a result, if the event happens an eighth note after the start of the bar, this value would be 6. If the event occurs three quarter notes after the start of the bar, this value would be 3 * 12 = 36. + */ + optional int32 time = 1 [(minval) = 0, (maxval) = 1000000]; + /* + The MIDI velocity. This value must be 0 for note off messages. + */ + optional int32 velocity = 2 [(minval) = 0, (maxval) = 127]; + /* + The MIDI pitch value of on the range [0,128). + */ + optional int32 pitch = 3 [(minval) = 0, (maxval) = 127]; + + optional int32 internal_instrument = 4; + optional int32 internal_track_type = 10; + optional int32 internal_duration = 11; + optional int32 delta = 12; +} + +/* +The Bar message specifies the events occuring in a bar. +*/ +message Bar { + /* + A list of integers, which are simply the indices of the messages found in the Piece.events repeated message. Note offsets which occur at the end of the bar (i.e. event.time = 48 with a time signature of 4/4 and piece.resolution of 12) should be included in the current bar rather than the next bar. In other words, no note offsets should even have an event.time = 0, as these note offset events would belong in the previous bar. + */ + repeated int32 events = 1 [(minval) = 0, (maxval) = 2147483647]; + /* + Numerator for the time-signature of the bar. Note that while time signatures can vary from bar to bar, they cannot vary from track to track. In other words if the second bar in track 0 has a time signature of 4/4, the second bar in track 1 must also have a time signature of 4/4. + */ + optional int32 ts_numerator = 8 [(minval) = 1, (maxval) = 1000000]; + /* + Denominator for the time-signature of the bar. + */ + optional int32 ts_denominator = 9 [(minval) = 1, (maxval) = 1000000]; + + optional float internal_beat_length = 5; + optional bool internal_has_notes = 3; + repeated ContinuousFeature internal_feature = 10; + repeated BarFeatures internal_features = 11; // why isn't this just called bar features +} + +/* +The piece message contains a list of bars, and specifies the instrument and track_type. +*/ +message Track { + /* + A list of bars. Note that each track must have the same number of bars. + */ + repeated Bar bars = 1; + /* + The MIDI instrument number for the track. + */ + optional int32 instrument = 3 [(minval) = 0, (maxval) = 139]; + /*127 original instruments with drum instrument seperated into 12 individual drum intruments (TR-808) with the rest of the drums in a single intrument*/ + /* + This must be a value in the TRACK_TYPE enum. In most cases, using STANDARAD_TRACK and STANDARD_DRUM_TRACK will suffice to denote a non-drum instrument track and a drum track respectively. + */ + optional TRACK_TYPE track_type = 5; + + repeated TRACK_TYPE internal_train_types = 6; + repeated TrackFeatures internal_features = 7; // why isn't this called track features +} + +/* +The Piece message specifies the actual musical material in a track-separated event-based format, specifying the note onsets and offsets for each bar in each track. +*/ +message Piece { + /* + Organizes MIDI events into tracks and bars. In short, each track contains a list of bars, which in turn contains a list of event indices (corresponding to the repeated events message in the Piece. + */ + repeated Track tracks = 1; + /* + A list of MIDI events which the tracks and bars reference + */ + repeated Event events = 2; + /* + The time resolution used to quantize / discretize musical material. Unless otherwise instructed, this should be set to 12. + */ + optional int32 resolution = 3 [(minval) = 1, (maxval) = 24]; + /* + Optionally the tempo can be specified. However this is not taken into consideration by the model. + */ + //optional int32 tempo = 4 [(minval) = 1, (maxval) = 1000000]; + optional int32 tempo = 4; + + repeated int32 internal_valid_segments = 7; + repeated uint32 internal_valid_tracks = 8; + optional int32 internal_segment_length = 12; + repeated ValidTrack internal_valid_tracks_v2 = 13; + repeated GenreData internal_genre_data = 14; + optional MetadataLabels internal_metadata_labels = 15; + repeated PieceFeatures internal_features = 16; + + optional int32 internal_ticks_per_quarter = 5; + optional bool internal_has_time_signatures = 6; +} + +/* +The StatusBar message specifies per-bar information for generation. +*/ +message StatusBar { + optional int32 ts_numerator = 1; + optional int32 ts_denominator = 2; + optional BarLevelOnsetDensityLevel onset_density = 3; + optional BarLevelOnsetPolyphonyLevel onset_polyphony_min = 4; + optional BarLevelOnsetPolyphonyLevel onset_polyphony_max = 5; +} + +/* +The StatusTrack message specifies per-track information for generation. +*/ +message StatusTrack { + /* + The index of a track in the Piece message. For a track to be seen by the model, it must be referenced by a StatusTrack message via the track_id field. Tracks that are not referenced by a StatusTrack message will not be considered by the model. + */ + optional int32 track_id = 1 [(minval) = 0, (maxval) = 1000000]; + /* + This must be a value in the TRACK_TYPE enum. This should be equivalent to the TRACK_TYPE specified in the corresponding Track, unless you are giving the model an option to choose either a drum or instrument track. In this case use the STANDARAD_BOTH value here, and the TRACK_TYPE in the piece will be ignored. + */ + optional TRACK_TYPE track_type = 2; + /* + This must be a value in the GM_TYPE enum. It specifies the set of possible instruments that the model may choose from. The mapping between GM_TYPE and instrument numbers can be found in src/midigpt_api/enum/gm.h. For example, using midi::GM_TYPE::piano will allow the model to use any piano instrument. + */ + optional GM_TYPE instrument = 3; + /* + A list of boolean values which specifies whether a bar is to be generated (true) or conditioned on (false). This must be the same length as the number of bars in the corresponding Track message. + */ + repeated bool selected_bars = 5; + /* + Indicates whether or not to use auto-regressive sampling. Note that you can only use auto-regressive sampling when each value in selected_bars is true (i.e. the entire track is being generated). Note that you do not have to use auto-regressive sampling when all selected bars is all true. + */ + optional bool autoregressive = 6; + /* + This indicates that the track should be ignored. The model will not be conditioned on this track, and it will in no way effect the generated outcome. + */ + optional bool ignore = 7; + + optional DensityLevel density = 4; + optional PolyphonyLevel min_polyphony_q = 10; + optional PolyphonyLevel max_polyphony_q = 11; + optional NoteDurationLevel min_note_duration_q = 12; + optional NoteDurationLevel max_note_duration_q = 13; + + optional BarLevelOnsetPolyphonyLevel onset_polyphony_min = 20; + optional BarLevelOnsetPolyphonyLevel onset_polyphony_max = 21; + optional BarLevelOnsetDensityLevel onset_density = 22; + optional BarLevelOnsetDensityLevel onset_density_min = 23; + optional BarLevelOnsetDensityLevel onset_density_max = 24; + optional int32 min_pitch = 25 [(minval) = 0, (maxval) = 127]; + optional int32 max_pitch = 26 [(minval) = 0, (maxval) = 127]; + optional GenreMusicmap genre = 28; + optional SilenceProportionLevel silence_proportion_min = 29; + optional SilenceProportionLevel silence_proportion_max = 30; + optional DensityLevel note_density_level = 31; + + optional BooleanLevel contains_note_duration_thirty_second = 36; + optional BooleanLevel contains_note_duration_sixteenth = 37; + optional BooleanLevel contains_note_duration_eighth = 38; + optional BooleanLevel contains_note_duration_quarter = 39; + optional BooleanLevel contains_note_duration_half = 40; + optional BooleanLevel contains_note_duration_whole = 41; + + /* + Sets a hard limit on the polyphony on this track. This is implemented by keeping a record of all the currently sounding notes, and preventing the model from generating note-onset tokens when the limit is reached. + */ + optional int32 polyphony_hard_limit = 16 [(minval) = 0, (maxval) = 100]; + /* + Allows for the entropy of generation to be adjusted. When temperature=1, the probability distributions output by the model are unaltered. When temperature<1 the probability distribution is increasingly biased towards the most probable tokens. With a very small temperature value this would be equivalent to argmax sampling. When temperature>1 the probability distribution moves towards a random uniform distribution. It is recommended to keep this value close to 1 in most cases. + */ + optional float temperature = 17 [(fminval) = 0.5, (fmaxval) = 2.0]; + + + repeated int32 internal_ts_numerators = 14; + repeated int32 internal_ts_denominators = 15; + optional string internal_genre = 9; + + repeated StatusBar bars = 19; +} + +/* +The Status message specifies which bars or tracks are to be generated/conditioned on, and provides extra information about conditioning such as instrument, density, polyphony and note-duration. +*/ +message Status { + repeated StatusTrack tracks = 1; + + /* + For microtiming generation, last sampling step must be decoded using the delta resolution + */ + optional bool decode_final = 2; + optional bool full_resolution = 3; +} + +/* +The SampleParam message specifies hyper-parameters for generation. +*/ +message HyperParam { + /* + For multi-step generation (typically employed when the entire piece is too large to be considered by the model simultaneously) this parameter specifies the number of tracks that are generated in each step. + */ + optional int32 tracks_per_step = 1 [(minval) = 1, (maxval) = 12]; + /* + For multi-step generation this parameter specifies the number of bars that are generated in each step. This value should be set in relation to model_dim. If bars_per_step = model_dim, then there will be no horizontal conditioning, which will typically produce inferior results. A good rule of thumb is to use bars_per_step == model_dim / 2. + */ + optional int32 bars_per_step = 2 [(minval) = 1, (maxval) = 8]; + /* + The size of the model. In most cases this will be 4. + */ + optional int32 model_dim = 3 [(minval) = 1, (maxval) = 8]; + /* + The percentage of the selected material (selected bars in the Status message) that will be generated. + */ + optional int32 percentage = 5 [(minval) = 1, (maxval) = 100]; + /* + The number of outputs to be generated. Currently we only support batch_size=1. + With multi-step sampling its is likely more efficient to simply make several calls in series. + */ + optional int32 batch_size = 7 [(minval) = 1, (maxval) = 1]; + /* + Allows for the entropy of generation to be adjusted. When temperature=1, the probability distributions output by the model are unaltered. When temperature<1 the probability distribution is increasingly biased towards the most probable tokens. With a very small temperature value this would be equivalent to argmax sampling. When temperature>1 the probability distribution moves towards a random uniform distribution. It is recommended to keep this value close to 1 in most cases. + */ + optional float temperature = 6 [(fminval) = 0.5, (fmaxval) = 2.0]; + /* + This parameter turns on and off per-track temperature control + */ + optional bool use_per_track_temperature = 17; + /* + The max number of tokens to generate before terminating generation. Can be used to avoid memory overload. When this value is set to zero it is ignored, and no limitations are set of the number of generated tokens. + */ + optional int32 max_steps = 13 [(minval) = 0, (maxval) = 2048]; + + optional int32 polyphony_hard_limit = 14 [(minval) = 0, (maxval) = 100]; + /* + When shuffle=true the generation steps are randomly ordered. For obvious reasons, auto-regressive sampling cannot be used with shuffle=true, as it would cease to be auto-regressive. + */ + optional bool shuffle = 4; + /* + Mainly for debugging purposes. + */ + optional bool verbose = 8; + /* + The path to the ckpt, which should either be an absolute path or relative to the executable. + */ + optional string ckpt = 9; + /* + Control over probability of masking top k + */ + optional float mask_top_k = 10 [(fminval) = 0, (fmaxval) = 1.]; + /* + Control stochastic seed for reproducability + */ + optional int32 sampling_seed = 11; + + optional bool internal_skip_preprocess = 12; + optional bool internal_disable_masking = 16; + +} + diff --git a/libraries/protobuf/src/midi_internal.proto b/libraries/protobuf/src/midi_internal.proto new file mode 100644 index 0000000000000000000000000000000000000000..62afb2cec54cc4ef179a90a05c42772f34d6713b --- /dev/null +++ b/libraries/protobuf/src/midi_internal.proto @@ -0,0 +1,125 @@ +syntax = "proto2"; + +import "enum.proto"; +import "google/protobuf/descriptor.proto"; + +package midi; + +extend google.protobuf.FieldOptions { + optional int32 maxval = 50001; + optional int32 minval = 50002; + optional float fmaxval = 50003; + optional float fminval = 50004; +} + +message ContinuousFeature { + optional float av_polyphony = 1; + optional float note_duration = 3; + optional float note_duration_norm = 4; +} + +message BarFeatures { + optional int32 onset_density = 1; + optional int32 onset_polyphony_min = 2; + optional int32 onset_polyphony_max = 3; +} + +message TrackLevelAttributeControlDistributions { + repeated int32 polyphony_quantile = 1; + repeated int32 note_duration_quantile = 2; + repeated int32 note_density = 3; + repeated int32 onset_polyphony = 4; + repeated int32 onset_density = 5; + repeated int32 note_duration = 8; +} + +message TrackFeatures { + optional int32 min_pitch = 1; + optional int32 max_pitch = 2; + optional float av_polyphony = 3; + optional int32 note_density = 4; + optional int32 note_density_v2 = 5; + optional int32 max_polyphony = 6; + optional bool should_prune = 7; + optional int32 order = 8; + optional float note_duration = 9; + optional string genre_str = 10; + optional int32 min_polyphony_q = 11; + optional int32 max_polyphony_q = 12; + optional int32 min_note_duration_q = 13; + optional int32 max_note_duration_q = 14; + repeated int32 polyphony_distribution = 15; + optional float note_density_value = 16; + + optional int32 min_polyphony_hard = 18; + optional int32 max_polyphony_hard = 19; + optional int32 min_note_duration_hard = 20; + optional int32 max_note_duration_hard = 21; + + optional int32 onset_polyphony_min = 24; + optional int32 onset_polyphony_max = 25; + optional int32 onset_density = 26; + optional int32 onset_density_min = 27; + optional int32 onset_density_max = 28; + repeated int32 duration_distribution = 29; + + optional int32 genre = 32; + optional int32 note_density_level = 35; + + optional int32 contains_note_duration_thirty_second = 40; + optional int32 contains_note_duration_sixteenth = 41; + optional int32 contains_note_duration_eighth = 42; + optional int32 contains_note_duration_quarter = 43; + optional int32 contains_note_duration_half = 44; + optional int32 contains_note_duration_whole = 45; + + optional TrackLevelAttributeControlDistributions attribute_control_distributions = 30; // store them all here + +} + +message PieceFeatures { + optional string genre = 1; +} + +message Note { + optional int32 start = 1; + optional int32 end = 2; + optional int32 pitch = 3; + optional int32 tick_delta = 4; +} + +message ValidTrack { + repeated int32 tracks = 1; +} + +message Item { + optional uint64 start = 1; + optional uint64 end = 2; + optional uint64 src_size = 3; +} + +message Dataset { + repeated Item train = 1; + repeated Item valid = 2; + repeated Item test = 3; +} + +message ModelMetadata { + optional string encoder = 1; + optional int32 num_layers = 2; + optional int32 num_heads = 3; + optional int32 num_hidden = 4; + optional int32 model_dim = 5; + optional bool new_state = 6; +} + +message GenreData { + optional string discogs = 1; + optional string lastfm = 2; + optional string tagtraum = 3; +} + +message MetadataLabels { + optional GenreMusicmap genre = 1; + optional int32 nomml = 6; +} \ No newline at end of file diff --git a/libraries/protobuf/src/track_type.proto b/libraries/protobuf/src/track_type.proto new file mode 100644 index 0000000000000000000000000000000000000000..4c09cd7d2916bfeb42d4812b6673f6e7f7fde455 --- /dev/null +++ b/libraries/protobuf/src/track_type.proto @@ -0,0 +1,12 @@ +syntax = "proto2"; + +package midi; + +enum TRACK_TYPE { + AUX_DRUM_TRACK = 8; + AUX_INST_TRACK = 9; + STANDARD_TRACK = 10; + STANDARD_DRUM_TRACK = 11; + STANDARD_BOTH = 12; + NUM_TRACK_TYPES = 16; +} \ No newline at end of file diff --git a/libraries/pybind11 b/libraries/pybind11 new file mode 160000 index 0000000000000000000000000000000000000000..5ccb9e412d8974e8eeac1b061d9077ac0bd365e1 --- /dev/null +++ b/libraries/pybind11 @@ -0,0 +1 @@ +Subproject commit 5ccb9e412d8974e8eeac1b061d9077ac0bd365e1 diff --git a/libraries/torch/CMakeLists.txt b/libraries/torch/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..56f7cad5df83b2fbb7d7f007909fc090fe6c5cf3 --- /dev/null +++ b/libraries/torch/CMakeLists.txt @@ -0,0 +1,26 @@ +cmake_minimum_required(VERSION 3.8) + +project(midigpt_torch) + +set(SRCS + src/torch_library.cpp + "include/torch_library.h") + +set(CMAKE_PREFIX_PATH "${CMAKE_CURRENT_SOURCE_DIR}/../libtorch/") +find_package(Torch REQUIRED) + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") + +add_library(midigpt_torch + ${SRCS}) + +target_link_libraries(midigpt_torch PRIVATE "${TORCH_LIBRARIES}") + +if (MSVC) + file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll") + add_custom_command(TARGET example-app + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different + ${TORCH_DLLS} + $) +endif (MSVC) \ No newline at end of file diff --git a/libraries/torch/CMakeSettings.json b/libraries/torch/CMakeSettings.json new file mode 100644 index 0000000000000000000000000000000000000000..2a7e537e09316df66b12f427deeae89a75068dcd --- /dev/null +++ b/libraries/torch/CMakeSettings.json @@ -0,0 +1,29 @@ +{ + "configurations": [ + { + "name": "x64-Debug", + "generator": "Ninja", + "configurationType": "Debug", + "inheritEnvironments": [ "msvc_x64_x64" ], + "buildRoot": "${projectDir}\\out\\build\\${name}", + "installRoot": "${projectDir}\\out\\install\\${name}", + "cmakeCommandArgs": "", + "buildCommandArgs": "", + "ctestCommandArgs": "" + }, + { + "name": "WSL-GCC-Debug", + "generator": "Ninja", + "configurationType": "Debug", + "buildRoot": "${projectDir}\\out\\build\\${name}", + "installRoot": "${projectDir}\\out\\install\\${name}", + "cmakeExecutable": "cmake", + "cmakeCommandArgs": "", + "buildCommandArgs": "", + "ctestCommandArgs": "", + "inheritEnvironments": [ "linux_x64" ], + "wslPath": "${defaultWSLPath}", + "variables": [] + } + ] +} \ No newline at end of file diff --git a/libraries/torch/include/torch_library.h b/libraries/torch/include/torch_library.h new file mode 100644 index 0000000000000000000000000000000000000000..28e6b1f866836315d839ae2f602559cd07181cd7 --- /dev/null +++ b/libraries/torch/include/torch_library.h @@ -0,0 +1,5 @@ +#pragma once + +namespace midigptTorch { + void testMmmTorch(); +} \ No newline at end of file diff --git a/libraries/torch/src/torch_library.cpp b/libraries/torch/src/torch_library.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4c9171935c57716fbb472bd64637d8b8a62cfcce --- /dev/null +++ b/libraries/torch/src/torch_library.cpp @@ -0,0 +1,10 @@ +#include +#include "../../libtorch/include/torch/csrc/api/include/torch/torch.h" + +namespace midigptTorch { + void testMmmTorch() { + std::cout << "midigptTorch test function called" << std::endl; + torch::Tensor tensor = torch::rand({ 2, 3 }); + std::cout << tensor << std::endl; + } +} \ No newline at end of file diff --git a/midigpt_setup_helper.sh b/midigpt_setup_helper.sh new file mode 100644 index 0000000000000000000000000000000000000000..bdb8c1bc22ea2f4fea3c896eb13ea8b90c6157d3 --- /dev/null +++ b/midigpt_setup_helper.sh @@ -0,0 +1,182 @@ +#!/bin/bash + +clone=no +replace=no +test_train=no +inference=no +mac=no +python_path=/Library/Frameworks/Python.framework/Versions/3.8/bin +og=no + +usage=" +$(basename "$0") [-h] [-n] [-c] [-i] [-m] [-r] [-d +directory] +-- Script for setting up and testing the MIDI-GPT repository + +where: + -h Show this help text + -n Test the training script imports + -c Clone the github MIDI-GPT repository + -i If you wish to setup repository for inference + -d Provide directory name where repo is/will be cloned + -r Replace directory if already exists + -m If on MacOS CPU + " + + +OPTSTRING=":cnhiomrd:k:" + +while getopts ${OPTSTRING} opt; do +case ${opt} in +h) +echo "${usage}" +exit 0 +;; +i) +inference=yes +;; +n) +test_train=yes +;; +m) +mac=yes +;; +r) +replace=yes +;; +c) +clone=yes +;; +d) +repo=${OPTARG} +;; +:) +echo "Option -${OPTARG} requires an argument" +exit 1 +;; +?) +echo "Invalid option: -${OPTARG}" +exit 1 +;; +esac +done + +if test "${clone}" = "yes" +then +echo "Cloning MIDI-GPT" +fi + +echo "In directory: ${repo}" + +if test "${replace}" = "yes" +then +if [[ -d ${repo} ]] +then +echo "Directory ${repo} already exists, removing it" +rm -rf ${repo} +fi +fi + +mkdir -p ${repo} +cd ${repo} + +echo "Loading modules" + +if test "${clone}" = "yes" +then +if [[ -d MIDI-GPT ]] || [[ -d ENV ]] +then + echo "MIDI-GPT or ENV directories already exist" + exit 1 +fi +if test "${og}" = "yes" +then +{ + git clone https://www.github.com/Metacreation-Lab/MIDI-GPT.git + +} || { + echo "Cloning failed" + exit 1 +} +else +{ + git clone https://www.github.com/Metacreation-Lab/MIDI-GPT.git + +} || { + echo "Cloning failed" + exit 1 +} +fi + +${python_path}/virtualenv --no-download ./ENV + +else +if ! [[ -d MIDI-GPT ]] +then + echo "MIDI-GPT doesn't exist, try cloning the repository with the -c option" + exit 1 +fi +fi + +{ + source ./ENV/bin/activate +} || { + echo "ENV virtual environment doesn't exist" + exit 1 +} + +echo "pip installs" + +pip install --no-index --upgrade pip +pip install torch==1.13.0 +pip install transformers==4.26.1 + +cd MIDI-GPT +if test "${og}" = "no" +then +git checkout main +fi + +echo "Starting python library build" + +{ if test "${inference}" = "yes" + then + echo "Building for inference" + if test "${mac}" = "yes" + then + echo "On MacOS CPU" + bash create_python_library.sh --mac_os + else + echo "On Compute Canada" + bash create_python_library.sh --compute_canada + fi + else + echo "Building for training only" + bash create_python_library.sh --test_build --compute_canada --no_torch + fi +} || { + echo "Build failed" + exit 1 +} + +if test "${test_train}" = "yes" +then + +cd ../ + +deactivate + +echo "Activating environment" + +source $PWD/venv/bin/activate +cd $PWD/MIDI-GPT/python_scripts + +echo "Testing training script" + +python3 -c "import train" + +echo "Import tests done" + +fi + +echo "Finished" diff --git a/models/model.zip b/models/model.zip new file mode 100644 index 0000000000000000000000000000000000000000..a5a7e0660d8e056d55b9d219c0196a189dffebfa --- /dev/null +++ b/models/model.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2370946e0b45c440972ef3099be17f27cd7157b7a40ce9fd24851f2c855e4d85 +size 77523617 diff --git a/pip_requirements/common_requirements.txt b/pip_requirements/common_requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..655e33c8bbf1b017fd70c03da8439985030094a7 --- /dev/null +++ b/pip_requirements/common_requirements.txt @@ -0,0 +1 @@ +pretty_midi \ No newline at end of file diff --git a/pip_requirements/create_dataset_requirements.txt b/pip_requirements/create_dataset_requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..27634909ed6b150d85bd8d8a88b0895f1b0f161d --- /dev/null +++ b/pip_requirements/create_dataset_requirements.txt @@ -0,0 +1,4 @@ +attrs==23.1.0+computecanada +jsonlines==3.1.0+computecanada +numpy==1.24.2+computecanada +tqdm==4.46.1 diff --git a/pip_requirements/inference_requirements.txt b/pip_requirements/inference_requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..29b4137119ea8578bf8f66bcec595019e2ea7bed --- /dev/null +++ b/pip_requirements/inference_requirements.txt @@ -0,0 +1 @@ +torch==2.0.0 \ No newline at end of file diff --git a/pip_requirements/train_requirements.txt b/pip_requirements/train_requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..c96b61dc78147b030461ce27f45df6619eea5d2c --- /dev/null +++ b/pip_requirements/train_requirements.txt @@ -0,0 +1,135 @@ +anyio==3.6.2+computecanada +arff==0.9+computecanada +argon2-cffi==21.3.0+computecanada +argon2-cffi-bindings==21.2.0+computecanada +asttokens==2.2.1+computecanada +async-generator==1.10+computecanada +attrs==23.1.0+computecanada +backcall==0.2.0+computecanada +backports-abc==0.5+computecanada +backports-shutil-get-terminal-size==1.0.0+computecanada +bcrypt==4.0.1+computecanada +beautifulsoup4==4.11.2+computecanada +bitstring==4.0.1+computecanada +bleach==6.0.0+computecanada +certifi==2022.12.7+computecanada +cffi==1.15.1+computecanada +chardet==5.1.0+computecanada +charset-normalizer==3.0.1+computecanada +comm==0.1.2+computecanada +contourpy==1.0.7+computecanada +cryptography==39.0.1+computecanada +cycler==0.11.0+computecanada +Cython==0.29.33+computecanada +deap==1.3.3+computecanada +debugpy==1.6.6+computecanada +decorator==5.1.1+computecanada +defusedxml==0.7.1+computecanada +dnspython==2.3.0+computecanada +ecdsa==0.18.0+computecanada +entrypoints==0.4+computecanada +executing==1.2.0+computecanada +fastjsonschema==2.16.2+computecanada +filelock==3.12.2 +fonttools==4.38.0+computecanada +fsspec==2023.6.0 +funcsigs==1.0.2+computecanada +huggingface-hub==0.15.1 +idna==3.4+computecanada +importlib-metadata==5.2.0+computecanada +importlib-resources==5.12.0+computecanada +ipykernel==6.21.2+computecanada +ipython==8.10.0+computecanada +ipython-genutils==0.2.0+computecanada +ipywidgets==8.0.4+computecanada +jedi==0.18.2+computecanada +Jinja2==3.1.2+computecanada +jsonlines==3.1.0+computecanada +jsonschema==4.17.3+computecanada +jupyter-client==8.0.3+computecanada +jupyter-core==5.2.0+computecanada +jupyter-events==0.6.3+computecanada +jupyter-server==2.3.0+computecanada +jupyter-server-terminals==0.4.4+computecanada +jupyterlab-pygments==0.2.2+computecanada +jupyterlab-widgets==3.0.5+computecanada +kiwisolver==1.4.4+computecanada +lockfile==0.12.2+computecanada +MarkupSafe==2.1.2+computecanada +matplotlib==3.7.0+computecanada +matplotlib-inline==0.1.6+computecanada +mistune==2.0.5+computecanada +mock==5.0.1+computecanada +mpmath==1.2.1+computecanada +nbclassic==0.5.2+computecanada +nbclient==0.7.2+computecanada +nbconvert==7.2.9+computecanada +nbformat==5.7.3+computecanada +nest-asyncio==1.5.6+computecanada +netaddr==0.8.0+computecanada +netifaces==0.11.0+computecanada +nose==1.3.7+computecanada +notebook==6.5.2+computecanada +notebook-shim==0.2.2+computecanada +numpy==1.24.2+computecanada +packaging==23.0+computecanada +pandas==1.5.3+computecanada +pandocfilters==1.5.0+computecanada +paramiko==3.0.0+computecanada +parso==0.8.3+computecanada +path==16.6.0+computecanada +path.py==12.5.0+computecanada +pathlib2==2.3.7+computecanada +paycheck==1.0.2+computecanada +pbr==5.11.1+computecanada +pexpect==4.8.0+computecanada +pickleshare==0.7.5+computecanada +Pillow==9.4.0+computecanada +pkgutil-resolve-name==1.3.10+computecanada +platformdirs==2.5.2+computecanada +prometheus-client==0.16.0+computecanada +prompt-toolkit==3.0.37+computecanada +protobuf==4.23.3 +psutil==5.9.4+computecanada +ptyprocess==0.7.0+computecanada +pure-eval==0.2.2+computecanada +pycparser==2.21+computecanada +Pygments==2.14.0+computecanada +PyNaCl==1.5.0+computecanada +pyparsing==3.0.9+computecanada +pyrsistent==0.19.3+computecanada +python-dateutil==2.8.2+computecanada +python-json-logger==2.0.7+computecanada +pytz==2022.7.1+computecanada +PyYAML==6.0+computecanada +pyzmq==25.0.0+computecanada +regex==2022.10.31+computecanada +requests==2.28.2+computecanada +rfc3339-validator==0.1.4+computecanada +rfc3986-validator==0.1.1+computecanada +scipy==1.10.1+computecanada +send2trash==1.8.0+computecanada +simplegeneric==0.8.1+computecanada +singledispatch==4.0.0+computecanada +six==1.16.0+computecanada +sniffio==1.3.0+computecanada +soupsieve==2.4+computecanada +stack-data==0.6.2+computecanada +sympy==1.11.1+computecanada +tensorboardX==2.2+computecanada +terminado==0.17.1+computecanada +testpath==0.6.0+computecanada +tinycss2==1.2.1+computecanada +tokenizers==0.13.2+computecanada +torch==1.13.1+computecanada +tornado==6.2+computecanada +tqdm==4.46.1 +traitlets==5.9.0+computecanada +transformers==4.26.1+computecanada +typing-extensions==4.5.0+computecanada +urllib3==1.26.14+computecanada +wcwidth==0.2.6+computecanada +webencodings==0.5.1+computecanada +websocket-client==1.5.1+computecanada +widgetsnbextension==4.0.5+computecanada +zipp==3.14.0+computecanada diff --git a/python_scripts/config/bert.json b/python_scripts/config/bert.json new file mode 100644 index 0000000000000000000000000000000000000000..9f1d814831f5822d821d9bdf54610f651493c175 --- /dev/null +++ b/python_scripts/config/bert.json @@ -0,0 +1,7 @@ +{ + "hidden_size" : 512, + "num_hidden_layers" : 6, + "nuim_attention_heads" : 8, + "intermediate_size" : 2048, + "max_position_embeddings" : 2048 +} \ No newline at end of file diff --git a/python_scripts/config/bert_tiny.json b/python_scripts/config/bert_tiny.json new file mode 100644 index 0000000000000000000000000000000000000000..be2821a8623ea523dc264cfba6611c3553ced8d6 --- /dev/null +++ b/python_scripts/config/bert_tiny.json @@ -0,0 +1,7 @@ +{ + "hidden_size" : 128, + "num_hidden_layers" : 2, + "num_attention_heads" : 2, + "intermediate_size" : 128, + "max_position_embeddings" : 2048 +} \ No newline at end of file diff --git a/python_scripts/config/gpt2.json b/python_scripts/config/gpt2.json new file mode 100644 index 0000000000000000000000000000000000000000..410b442fd4007e8230a3a4100817d7cc0927f509 --- /dev/null +++ b/python_scripts/config/gpt2.json @@ -0,0 +1,7 @@ +{ + "n_positions": 2048, + "n_ctx": 2048, + "n_layer": 6, + "n_head": 8, + "n_embd": 512 +} diff --git a/python_scripts/config/gpt2_tiny.json b/python_scripts/config/gpt2_tiny.json new file mode 100644 index 0000000000000000000000000000000000000000..a81aba5b59fe2862a68c5fea22239a97aba18a01 --- /dev/null +++ b/python_scripts/config/gpt2_tiny.json @@ -0,0 +1,7 @@ +{ + "n_positions": 2048, + "n_ctx": 2048, + "n_layer": 2, + "n_head": 2, + "n_embd": 128 +} \ No newline at end of file diff --git a/python_scripts/convert.py b/python_scripts/convert.py new file mode 100644 index 0000000000000000000000000000000000000000..ab5fbf892a96040b08f5421ed499ddfb0917c8e6 --- /dev/null +++ b/python_scripts/convert.py @@ -0,0 +1,204 @@ +# take a trained pytorch model and convert it +import os +import sys +sys.path.append(os.path.dirname(os.getcwd()) + "/python_lib") +print( os.path.dirname(os.getcwd()) + "/python_lib" ) +import midigpt +import time +import json +import numpy as np +import torch +import torch.quantization +from transformers import GPT2LMHeadModel, GPT2Config +from transformers.modeling_utils import Conv1D + +from custom_models import * + +from torch import nn + +class QuantWrapper(nn.Module): + def __init__(self, module): + super(QuantWrapper, self).__init__() + qconfig = module.qconfig if hasattr(module, 'qconfig') else None + self.add_module('quant', torch.quantization.QuantStub(qconfig)) + self.add_module('dequant', torch.quantization.DeQuantStub()) + self.add_module('module', module) + self.train(module.training) + + def forward(self, X, P): + X = self.quant(X) + P = self.quant(P) + O = self.module(X,P) + return self.dequant(O) + +def _conv1d_to_linear(module): + in_size, out_size = module.weight.shape + linear = torch.nn.Linear(in_size, out_size) + linear.weight.data = module.weight.data.T.contiguous() + linear.bias.data = module.bias.data + return linear + +def conv1d_to_linear(model): + for name in list(model._modules): + module = model._modules[name] + if isinstance(module, Conv1D): + linear = _conv1d_to_linear(module) + model._modules[name] = linear + else: + conv1d_to_linear(module) + +def score_model(model): + targets = np.load("target.npz")["data"] + + +def time_model(model): + start = time.time() + pkv = None + for _ in range(1000): + input_ids = torch.ones(1,1).type(torch.LongTensor) + outputs = model(input_ids, past_key_values=pkv) + pkv = outputs[1] + print("BATCH TIME : {}".format(time.time() - start)) + +def print_size_of_model(model): + import os + torch.save(model.state_dict(), "temp.p") + print('Size (MB):', os.path.getsize("temp.p")/1e6) + os.remove('temp.p') + +def quantize_model(model): + conv1d_to_linear(model) + model = torch.quantization.quantize_dynamic( + model, {torch.nn.Linear}, dtype=torch.qint8) + return model + +def static_quantize_model(model): + conv1d_to_linear(model) + model.qconfig = torch.quantization.default_qconfig + torch.quantization.prepare(model, inplace=True) + torch.quantization.convert(model, inplace=True) + + return model + +def prune_model(model): + import torch.nn.utils.prune as prune + + conv1d_to_linear(model) + parameters_to_prune = [] + for _,module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + prune.l1_unstructured(module, name="weight", amount=.8) + prune.remove(module, "weight") + + for _,submodule in module.named_modules(): + if isinstance(submodule, torch.nn.Linear): + prune.l1_unstructured(submodule, name="weight", amount=.8) + prune.remove(submodule, "weight") + + return model + +def inject_metadata(path, metadata_path, encoder, new_state): + model = torch.jit.load(path) + with open(metadata_path, "r") as f: + metadata = json.load(f) + metadata["encoder"] = encoder + metadata["new_state"] = new_state + extra_files = torch._C.ExtraFilesMap() + extra_files['metadata.json'] = json.dumps(metadata) + out_path = os.path.splitext(path)[0] + "_WMETA.pt" + torch.jit.save(model, out_path, _extra_files=extra_files) + +def convert(model, path, quantize=False, prune=False, force=False, control=False, ckpt_path=None, encoderX=None): + if not os.path.exists(path) or force: + model.eval() + if quantize: + model = quantize_model(model) + if prune: + model = prune_model(model) + print_size_of_model(model) + example_input = torch.zeros(1,300).type(torch.LongTensor) + example_control = torch.zeros(1,300,3).type(torch.FloatTensor) + if control: + outputs = model(input_ids=example_input, control_ids=example_control, past_key_values=None) + print(len(outputs[1])) + traced_script_module = torch.jit.trace(model, [example_input,example_control,outputs[1]]) + else: + outputs = model(input_ids=example_input) + traced_script_module = torch.jit.trace(model, [example_input, outputs[1]]) + + num_layers = len(outputs[1]) + _,num_heads,_,num_hidden = outputs[1][0][0].detach().numpy().shape + encoder = encoderX + + model_metadata = { + "encoder" : encoder, + "num_heads" : num_heads, + "num_hidden" : num_hidden, + "num_layers" : num_layers, + "model_dim" : -1, + "new_state" : True + } + + print(model_metadata) + + extra_files = {} + extra_files['metadata.json'] = json.dumps(model_metadata) + torch.jit.save( + traced_script_module, path, _extra_files=extra_files) + + +class GPT2LMHeadModelWMeta(GPT2LMHeadModel): + def extra_repr(self): + return "trent is the man" + +if __name__ == "__main__": + + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt_path", type=str, required=True) + parser.add_argument("--output", type=str, default="") + parser.add_argument("--metadata_path", type=str, default="") + parser.add_argument("--config", type=str, default="") + parser.add_argument("--encoder", type=str, default="NONE") + parser.add_argument("--init", action="store_true") + parser.add_argument("--inject", action="store_true") + parser.add_argument("--new_state", action="store_true") + parser.add_argument("--quantize", action="store_true") + parser.add_argument("--prune", action="store_true") + parser.add_argument("--control", action="store_true") + + args = parser.parse_args() + + + if args.inject: + assert len(args.metadata_path) + inject_metadata( + args.ckpt_path, args.metadata_path, args.encoder, True if args.new_state else False) + + else: + assert len(args.output) + if args.init: + encoder_mode = midigpt.getEncoderType(args.encoder) + assert encoder_mode is not midigpt.ENCODER_TYPE.NO_ENCODER + encoder = midigpt.getEncoder(encoder_mode) + vocab_size = encoder.vocab_size() + if args.control: + config = GPT2LMHeadModelContConfig().from_json_file(args.config) + # encoder knows the size of the embedding + config.n_control_dim = encoder.config.embed_dim + model_cls = GPT2LMHeadModelCont + + else: + config = GPT2Config().from_json_file(args.config) + config.vocab_size = vocab_size + model_cls = GPT2LMHeadModel + + model = model_cls(config) + else: + if args.control: + model = GPT2LMHeadModelCont.from_pretrained(args.ckpt_path, torchscript=True) + else: + model = GPT2LMHeadModel.from_pretrained(args.ckpt_path, torchscript=True) + + convert(model, args.output, quantize=args.quantize, prune=args.prune, control=args.control, ckpt_path=args.ckpt_path, encoderX=args.encoder) + diff --git a/python_scripts/create_dataset.py b/python_scripts/create_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b4177d0ba830c13d0e4f323dcd32c24377ef01fd --- /dev/null +++ b/python_scripts/create_dataset.py @@ -0,0 +1,226 @@ +import os +import glob +import json +import numpy as np +import csv +from tqdm import tqdm +from multiprocessing import Pool + +from utils import * + +import sys +import os +sys.path.append(os.path.dirname(os.getcwd()) + "/python_lib") +import midigpt + +def worker(args): + path,sid,labels,nomml,tcjson,encoding = args + tc = midigpt.TrainConfig() + tc.from_json(tcjson) + labels["nomml"] = nomml + + encoder_mode = midigpt.getEncoderType(encoding) + assert encoder_mode is not midigpt.ENCODER_TYPE.NO_ENCODER + encoder = midigpt.getEncoder(encoder_mode) + + try: + return sid, midigpt.midi_to_json_bytes(path,tc,json.dumps(labels)) + except Exception as e: + print(e) + return None,None + +def load_json(path): + if not os.path.exists(path): + return {} + with open(path, "r") as f: + return json.load(f) + +DEFAULT_LABELS = { + "genre": "GENRE_MUSICMAP_ANY", + "valence_spotify": -1, + "energy_spotify": -1, + "danceability_spotify": -1, + "tension": [] +} + +DATA_TYPES = [ + "Drum", + "Drum+Music", + "Music-No-Drum" +] + +def load_metadata_labels(genre_data_path, spotify_data_path, tension_data_path): + data = {} + genre_data = load_json(genre_data_path) + spotify_data = load_json(spotify_data_path) + tension_data = load_json(tension_data_path) + md5s = list(set(list(genre_data.keys()) + list(spotify_data.keys()) + list(tension_data.keys()))) + for md5 in md5s: + data[md5] = {} + if md5 in spotify_data: + data[md5]["valence_spotify"] = np.mean(spotify_data[md5]["valence"]) + data[md5]["energy_spotify"] = np.mean(spotify_data[md5]["energy"]) + data[md5]["danceability_spotify"] = np.mean(spotify_data[md5]["danceability"]) + else: + for k,v in DEFAULT_LABELS.items(): + data[md5][k] = v + data[md5]["genre"] = genre_data.get(md5, DEFAULT_LABELS["genre"]) + data[md5]["tension"] = tension_data.get(md5, DEFAULT_LABELS["tension"]) + return data + +if __name__ == "__main__": + + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--data_dir", type=str, required=True) + parser.add_argument("--output", type=str, required=True) + parser.add_argument("--num_bars", type=int, default=4) + parser.add_argument("--expressive", action="store_true") + parser.add_argument("--ignore_score", type=bool, default=0) + parser.add_argument("--nthreads", type=int, default=8) + parser.add_argument("--max_size", type=int, default=-1) + parser.add_argument("--genre_data", type=str, default="") + parser.add_argument("--spotify_data", type=str, default="") + parser.add_argument("--tension_data", type=str, default="") + parser.add_argument("--encoding", type=str, default="TRACK_ENCODER") + parser.add_argument("--resolution", type=int, default=12) + parser.add_argument("--delta_resolution", type=int, default=1920) + parser.add_argument("--metadata", type=str, required=True) + parser.add_argument("--type", type=str, default="Drum+Music") + parser.add_argument("--test", type=str, default="no") + args = parser.parse_args() + + args.ignore_score = bool(args.ignore_score) + if args.test != "no": + test_script = True + else: + test_script = False + + assert args.type in DATA_TYPES + args.type = "-" + args.type + "-" + + import os + os.system("taskset -p 0xffff %d" % os.getpid()) + + # multi thread approach takes about 2 minutes + pool = Pool(args.nthreads) + output = os.path.splitext(args.output)[0] + ss="" + if args.max_size > 0: + ss=f"_MAX_{args.max_size}" + if args.expressive: + output += "/{}_NUM_BARS={}_RESOLUTION_{}_DELTA_{}{}.arr".format(args.encoding,args.num_bars,args.resolution, args.delta_resolution,ss) + else: + output += "/{}_NUM_BARS={}_RESOLUTION_{}{}.arr".format(args.encoding,args.num_bars,args.resolution,ss) + print(output) + if not test_script: + jag = midigpt.BytesToFile(output) + + + paths = list(glob.glob(args.data_dir + "/**/*.mid", recursive=True)) + + import random + import time + random.seed(int(time.time())) + + tc = midigpt.TrainConfig() + tc.num_bars = args.num_bars + tc.use_microtiming = args.expressive + tc.resolution = args.resolution + tc.delta_resolution = args.delta_resolution + tc = tc.to_json() + print(tc) + + paths_exp = [] + sids_exp = [] + paths_non_exp = [] + sids_non_exp = [] + paths_all = [] + sids_all = [] + nomml_alls = [] + nomml_scores = [] + + try: + with open(args.metadata) as meta: + reader = csv.DictReader(meta, delimiter=',') + for row in tqdm(reader): + path = row["filepath"] + nomml = int(row["medianMetricDepth"]) + if (".mid" in path and args.type in path): + if "-Train-" in path: + group = 0 + elif "-Val-" in path: + group = 1 + elif "-Test-" in path: + group = 2 + else: + raise RuntimeError("data format incorrect") + if (nomml < 12): + paths_non_exp.append(os.path.join(args.data_dir,path)) + sids_non_exp.append(group) + nomml_scores.append(nomml) + else: + paths_exp.append(os.path.join(args.data_dir,path)) + sids_exp.append(group) + paths_all.append(os.path.join(args.data_dir,path)) + sids_all.append(group) + nomml_alls.append(nomml) + + except: + paths_all = list(glob.glob(args.data_dir + "/**/*.mid", recursive=True)) + for path in paths_all: + if "-train-" in path: + sids_all.append(0) + elif "-valid-" in path: + sids_all.append(1) + elif "-test-" in path: + sids_all.append(2) + else: + raise RuntimeError("data format incorrect") + + nomml_vals = [] + if args.expressive: + if args.ignore_score: + paths = paths_exp + sids = sids_exp + nomml_vals = [12 for _ in sids] + else: + paths = paths_all + sids = sids_all + nomml_vals = nomml_alls + else: + paths = paths_all + sids = sids_all + nomml_vals = nomml_alls + + metadata_label_data = load_metadata_labels(args.genre_data, args.spotify_data, args.tension_data) + metadata_labels = [metadata_label_data.get(os.path.splitext(os.path.basename(p))[0],DEFAULT_LABELS) for p in paths] + print("LOADED {} METADATA LABELS".format(len(metadata_labels))) + + tcs = [tc for _ in paths] + encoding = [args.encoding for _ in paths] + inputs = list(zip(paths,sids,metadata_labels,nomml_vals,tcs,encoding)) + random.shuffle(inputs) + + for k,v in DEFAULT_LABELS.items(): + print("{} FILES HAVE {} METADATA".format(sum([m[k] != v for m in metadata_labels]),k)) + + if args.max_size > 0: + inputs = inputs[:args.max_size] + + if not test_script: + total_count = 0 + success_count = 0 + pool = Pool(args.nthreads) + progress_bar = tqdm(pool.imap_unordered(worker, inputs), total=len(inputs)) + for sid,b in progress_bar: + if b is not None and len(b): + jag.append_bytes_to_file_stream(b,sid) + success_count += 1 + total_count += 1 + status_str = "{}/{}".format(success_count,total_count) + progress_bar.set_description(status_str) + jag.close() + else: + print("Test successful") + sys.exit(0) diff --git a/python_scripts/custom_models.py b/python_scripts/custom_models.py new file mode 100644 index 0000000000000000000000000000000000000000..1af1bd3c046937cff214d3bee1c4e3eaa080fd3d --- /dev/null +++ b/python_scripts/custom_models.py @@ -0,0 +1,121 @@ +from transformers import * +import torch.nn as nn +import torch + +class GPT2Encoder(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.transformer = GPT2Model(config) + self.score = nn.Linear(config.n_embd, 128, bias=False) + self.init_weights() + self.model_parallel = False + self.device_map = None + + negative_importance = torch.tensor(5.33).float() + negative_threshold = torch.tensor(4.).float() + entropy_importance = torch.tensor(0.05).float() + self.register_buffer('negative_importance', negative_importance) + self.register_buffer('negative_threshold', negative_threshold) + self.register_buffer('entropy_importance', entropy_importance) + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + sequence_lengths=None, + ): + assert sequence_lengths is not None + outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + logits = self.score(hidden_states) + return logits[range(input_ids.shape[0]),sequence_lengths-1] + +class GPT2LMHeadModelContConfig(GPT2Config): + def __init__( + self, + n_control_embd=64, + n_control_dim=3, + **kwargs + ): + super().__init__(**kwargs) + self.n_control_embd = n_control_embd + self.n_control_dim = n_control_dim + +class GPT2LMHeadModelCont(GPT2LMHeadModel): + def __init__(self, config): + super().__init__(config) + token_embd = config.n_embd - config.n_control_embd + self.wte = nn.Embedding(config.vocab_size, token_embd) + self.ctrle = nn.Linear(config.n_control_dim, config.n_control_embd) + + def forward( + self, + input_ids=None, + control_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + sequence_lengths=None + ): + shape = control_ids.shape + input_shape = (shape[0]*shape[1],shape[2]) + output_shape = (shape[0],shape[1],self.config.n_control_embd) + control_embd = self.ctrle(torch.reshape(control_ids, input_shape)) + control_embd = torch.reshape(control_embd, output_shape) + token_embd = self.wte(input_ids) + inputs_embeds = torch.cat([token_embd,control_embd], axis=-1) + return super().forward( + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + labels=labels + ) + +if __name__ == "__main__": + + batch_size = 3 + + config = GPT2Config().from_json_file("config/gpt2_tiny.json") + + #model = GPT2LMHeadModelCont(config) + model = GPT2Encoder(config) + + kwargs = { + "input_ids" : torch.randint(config.vocab_size, size=(batch_size,100)), + "labels" : torch.randint(config.vocab_size, size=(batch_size,100)), + #"control_ids" : torch.randint(CONTROL_VOCAB_SIZE, size=(1,100)) + "sequence_lengths" : [99,90,90] + } + + out = model.forward(**kwargs) + print(out.shape) diff --git a/python_scripts/data_split.py b/python_scripts/data_split.py new file mode 100644 index 0000000000000000000000000000000000000000..43b46ca8b27da0343277df9a51387099133054fa --- /dev/null +++ b/python_scripts/data_split.py @@ -0,0 +1,54 @@ +from os import listdir +from os.path import isfile, join +from shutil import move +import random +import sys +from tqdm import tqdm + +if __name__ == "__main__": + + root = sys.argv[1] + new_root = sys.argv[2] + + # Get all midi files + print("Getting MIDI files") + onlyfiles = [f for f in listdir(root)] + n = len(onlyfiles) + print("Num files: " + str(n)) + + # Generate random test/train/valid indices + + idx = [i for i in range(n)] + split_idx = random.shuffle(idx) + + train_len = int(0.8 * n) + test_len = int(0.1 * n) + valid_len = n - train_len - test_len + + train_idx = idx[:train_len] + test_idx = idx[train_len:test_len + train_len] + valid_idx = idx[test_len + train_len:] + + # Move files to respective folder + + o = 0 + + print('Spliting Train Set') + for i in tqdm(range(train_len)): + move(join(root, onlyfiles[train_idx[i]]), join(new_root, "train", onlyfiles[train_idx[i]])) + o += 1 + + print('Spliting Test Set') + for i in tqdm(range(test_len)): + move(join(root, onlyfiles[test_idx[i]]), join(new_root, "test", onlyfiles[test_idx[i]])) + o += 1 + + print('Spliting Validation Set') + for i in tqdm(range(valid_len)): + move(join(root, onlyfiles[valid_idx[i]]), join(new_root, "valid", onlyfiles[valid_idx[i]])) + o += 1 + + print("Succes Test: " + str(o == n)) + + print(join(new_root, "valid", onlyfiles[valid_idx[100]])) + print(isfile(join(new_root, "valid", onlyfiles[valid_idx[100]]))) \ No newline at end of file diff --git a/python_scripts/losses.py b/python_scripts/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..8120329023496dedc9a4fce88aec4213a086dd0d --- /dev/null +++ b/python_scripts/losses.py @@ -0,0 +1,37 @@ +import torch + +def standard_loss(self, model, inputs, return_outputs=False): + outputs = model(**inputs) + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index] + loss = outputs[0].mean() + print("loss : ", loss) + return (loss,outputs) if return_outputs else loss + +def hinge_cost(m, a, b): + dist = m - torch.sqrt(torch.sum((a - b)**2, axis=1)) + return torch.mean(torch.clamp(dist,0,float('inf'))**2) + +def sim_metric_loss(self, model, inputs, return_outputs=False): + + # single pass version + batch_size = len(inputs["input_ids"])//4 + outputs = model(**inputs) + x_p = outputs[:batch_size] + x_n = outputs[batch_size:2*batch_size] + y_p = outputs[2*batch_size:3*batch_size] + y_n = outputs[3*batch_size:] + + model_attr = model + if isinstance(model, torch.nn.DataParallel): + model_attr = model.module + + cost_p = torch.mean(torch.sum((x_p - y_p)**2, axis=1)) + cost_n = model_attr.negative_importance*hinge_cost( + model_attr.negative_threshold, x_n, y_n) + cost_e = model_attr.entropy_importance*torch.mean( + torch.sum(x_p**2, axis=1) + torch.sum(y_p**2, axis=1)) + loss = cost_p + cost_n + cost_e + + print(loss) + return (loss,None) if return_outputs else loss diff --git a/python_scripts/train.py b/python_scripts/train.py new file mode 100644 index 0000000000000000000000000000000000000000..0fe2acc671ef0429fd24d3dca15620fda896467a --- /dev/null +++ b/python_scripts/train.py @@ -0,0 +1,224 @@ +import sys +import os +sys.path.append(os.path.dirname(os.getcwd()) + "/python_lib") +import midigpt + +from transformers import * + +import os +import json +import time +import torch + +import datetime +import numpy as np +from tqdm import tqdm + +from subprocess import check_output + +from losses import sim_metric_loss, standard_loss +from custom_models import * +from train_dataset import * +from transformers import Trainer, TrainingArguments +from callbacks import MemoryUsageCallback, ProfilerCallback + +if __name__ == "__main__": + + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--arch", type=str, required=True) + parser.add_argument("--config", type=str, required=True) + parser.add_argument("--encoding", type=str, required=True) + parser.add_argument("--dataset", type=str, required=True) + parser.add_argument("--pad_value", type=int, default=-100) + + parser.add_argument("--expressive", action="store_true") + parser.add_argument("--num_bars", type=int, default=4) + parser.add_argument("--min_tracks", type=int, default=2) + parser.add_argument("--max_tracks", type=int, default=12) + parser.add_argument("--max_seq_len", type=int, default=2048) + parser.add_argument("--no_max_length", type=int, default=0) + parser.add_argument("--resolution", type=int, default=12) + parser.add_argument("--delta_resolution", type=int, default=1920) + parser.add_argument("--abs_pos_vocab_size", type=int, default=196) + parser.add_argument("--delta_vocab_size", type=int, default=96) + + parser.add_argument("--ngpu", type=int, default=4) + parser.add_argument("--accum_steps", type=int, default=1) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--batches_per_epoch", type=int, default=1000) + parser.add_argument("--lr", type=float, default=1e-4) + + parser.add_argument("--overwrite", type=int, default=1) + parser.add_argument("--save_steps", type=int, default=5000) + parser.add_argument("--log_steps", type=int, default=100) + parser.add_argument("--step", type=int, default=0) + parser.add_argument("--label", type=str, default="version3") + parser.add_argument("--profiler_steps", type=int, default=50) + + parser.add_argument("--dry", action="store_true") + parser.add_argument("--metric", action="store_true") + + parser.add_argument("--ckpt", type=str, default="") + parser.add_argument("--ckpt_num", type=int, default=5000) + parser.add_argument("--output", type=str, default="") + parser.add_argument("--log", type=str, default="") + + parser.add_argument("--test_only", action="store_true") + parser.add_argument("--memory_metrics", action="store_true") + + args = parser.parse_args() + args.expressive = (args.encoding == "EXPRESSIVE_ENCODER") and args.expressive + + dataset_cls = CustomDataset + loss_fn = standard_loss + + np.random.seed(int(time.time())) + + # determine vocab size + date_str = datetime.datetime.now().strftime('%b_%d_%H_%M') + encoder_mode = midigpt.getEncoderType(args.encoding) + assert encoder_mode is not midigpt.ENCODER_TYPE.NO_ENCODER + encoder = midigpt.getEncoder(encoder_mode) + if args.expressive: + encoder.set_scheme(args.resolution, args.delta_resolution, args.delta_vocab_size, args.abs_pos_vocab_size) + vocab_size = encoder.vocab_size() + + current_git_commit_hash = check_output(["git", "rev-parse", "HEAD"], text=True).strip() + + load_checkpoint = False + if args.ckpt == "": + name = "_".join([args.encoding, args.arch, args.label, date_str, "num_bars", str(args.num_bars), str(args.max_tracks), "GIT_HASH", current_git_commit_hash]) + else: + name = args.ckpt + load_checkpoint = True + + if args.dry: + while True: + dataset = dataset_cls(split_id=0, is_training=True, **vars(args)) + for batch in tqdm(dataset,smoothing=0): + np_inputs = batch["input_ids"].detach().numpy() + print( [encoder.pretty(t) for t in np_inputs[0][:100]] ) + print( {k:v.shape for k,v in batch.items()} ) + + if os.getenv("SLURM_TMPDIR") is not None: + # we are on compute canada and should attempt to copy + # dataset to tmpdir for faster access + from shutil import copyfile + tmpdir = os.getenv("SLURM_TMPDIR") + dataset_path = os.path.join(tmpdir, os.path.basename(args.dataset)) + if not os.path.exists(dataset_path): + copyfile(args.dataset, dataset_path) + copyfile(args.dataset + ".header", dataset_path + ".header") + args.dataset = dataset_path + + # setup datasets + train_dataset = dataset_cls(split_id=0, is_training=True, **vars(args)) + eval_dataset = dataset_cls(split_id=2, is_training=False, overload_batches_per_epoch=1, **vars(args)) + Trainer.get_train_dataloader = lambda *_args,**_kwargs: train_dataset + Trainer.get_eval_dataloader = lambda *_args,**_kwargs: eval_dataset + Trainer.compute_loss = loss_fn + + print("MODEL NAME : " + name) + print("VOCAB SIZE : " + str(vocab_size)) + print("ARGS : " + json.dumps(vars(args),indent=4)) + print("MODEL CONFIG : " + json.dumps(json.load(open(args.config,"r")),indent=4)) + print("ENCODER CONFIG : " + json.dumps(encoder.config.ToJson(),indent=4)) + + logging_dir = os.path.join(args.log, "{}".format(name)) + output_dir = os.path.join(args.output, "checkpoints/{}".format(name)) + + print("LOGGING PATH : " + logging_dir) + print("OUTPUT PATH : " + output_dir) + + os.makedirs(logging_dir, exist_ok=True) + os.makedirs(output_dir, exist_ok=True) + + # ================================================================= + # model selection + + if args.arch == "gpt2": + config = GPT2Config().from_json_file(args.config) + model_cls = GPT2LMHeadModel + elif args.arch == "xl": + config = TransfoXLConfig().from_json_file(args.config) + model_cls = TransfoXLLMHeadModel + elif args.arch == "metric": + config = GPT2Config().from_json_file(args.config) + model_cls = GPT2Encoder + elif args.arch == "control": + config = GPT2LMHeadModelContConfig().from_json_file(args.config) + # encoder knows the size of the embedding + config.n_control_dim = encoder.config.embed_dim + model_cls = GPT2LMHeadModelCont + elif args.arch == "bert": + config = BertConfig().from_json_file(args.config) + model_cls = BertForMaskedLM + else: + raise NotImplementedError + + config.vocab_size = vocab_size + print("MODEL CONFIG : " + str(config)) + + + if len(args.ckpt.strip()) == 0: + print('Model initialization') + ckpt_path = None + model = model_cls(config) + else: + try: + print('Trying to load checkpoint') + ckpt_path = os.path.join(output_dir, f"checkpoint-{args.ckpt_num}") + model = model_cls.from_pretrained(ckpt_path) + except Exception as e: + print(e) + print('Returning to default model initialization') + model = model_cls(config) + + + # Create Memory metrics callback + + # ================================================================= + # training + + training_args = TrainingArguments( + logging_dir=logging_dir, + report_to="tensorboard", + output_dir=output_dir, + overwrite_output_dir=bool(args.overwrite), + num_train_epochs=(500000/args.batches_per_epoch)*args.accum_steps, + logging_steps=args.log_steps, + save_steps=args.save_steps, + save_total_limit=None, + learning_rate=args.lr, + gradient_accumulation_steps=args.accum_steps, + per_device_train_batch_size=args.batch_size//args.ngpu//args.accum_steps, + per_device_eval_batch_size=args.batch_size//args.ngpu//args.accum_steps, + evaluation_strategy="epoch", + prediction_loss_only=True, + skip_memory_metrics=True + ) + + # For custom memory metrics, don't work and multiply by 100 training time!!! + if args.memory_metrics: + callbacks = [MemoryUsageCallback, ProfilerCallback] + else: + callbacks = [] + + trainer = Trainer( + model=model, + args=training_args, + data_collator=None, + train_dataset=None, + eval_dataset=None, + callbacks=callbacks + ) + + trainer.train_dataset = train_dataset + trainer.eval_dataset = eval_dataset + + if not args.test_only: + trainer.train(ckpt_path) + else: + model = trainer._wrap_model(trainer.model) + model.save_pretrained(output_dir) diff --git a/python_scripts/train_dataset.py b/python_scripts/train_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0b6a8fa1b906b0f158abf0e4b2f66fa4483cf96d --- /dev/null +++ b/python_scripts/train_dataset.py @@ -0,0 +1,96 @@ +#from transformers import Trainer, TrainingArguments + +import os +import json +import time +import torch +import tqdm +#from torch.utils.data import Dataset + +import datetime +import numpy as np + +import sys +sys.path.append(os.path.dirname(os.getcwd()) + "/python_lib") +import midigpt + +class CustomDataset: + def __init__(self, split_id=0, is_training=True, batch_size=32, dataset=None, num_bars=4, min_tracks=2, max_tracks=12, max_seq_len=2048, expressive=False, no_max_length=False, resolution=12, encoding=None, pad_value=-100, arch="gpt2", accum_steps=1, batches_per_epoch=1000, overload_batches_per_epoch=None, **kwargs): + # settings + self.is_training = is_training + self.batch_size = batch_size // accum_steps + self.split_id = split_id + self.max_seq_len = max_seq_len + self.batches_per_epoch = batches_per_epoch if overload_batches_per_epoch is None else overload_batches_per_epoch + self.dataset = list(range(self.batches_per_epoch)) # number of examples ?? + self.pad_value = pad_value + self.arch = arch + + # create dataloader + self.dataloader = midigpt.Jagged(dataset) + self.dataloader.set_num_bars(num_bars) + self.dataloader.set_min_tracks(min_tracks) + self.dataloader.set_max_tracks(max_tracks) + self.dataloader.set_max_seq_len(max_seq_len) + seed = np.random.randint(2**20) + self.dataloader.set_seed(seed) + self.encoder_mode = midigpt.getEncoderType(encoding) + + # create train_config + self.tc = midigpt.TrainConfig() + self.tc.num_bars = num_bars + self.tc.min_tracks = min_tracks + self.tc.max_tracks = max_tracks + self.tc.use_microtiming = expressive + self.tc.no_max_length = no_max_length + self.tc.resolution = resolution + + self.current = 0 + + def _get_batch(self): + input_ids, mask = self.dataloader.read_batch_v2( + self.batch_size, self.split_id, self.encoder_mode, self.tc) + input_ids = np.array(input_ids) + mask = np.array(mask) + labels = np.copy(input_ids) + labels += (1-mask) * self.pad_value # set masked tokens to pad_value + batch = { + "input_ids" : torch.from_numpy(input_ids), + "attention_mask" : torch.from_numpy(mask), + "labels" : torch.from_numpy(labels) + } + if self.arch == "xl": + batch.pop("attention_mask") + assert np.all(np.sum(mask,axis=1)==self.max_seq_len) + if self.arch == "bert": + batch.pop("labels") + return batch + + def _get_batch_test(self): + inputs = torch.ones((32,800), dtype=torch.int64) + return { + "input_ids" : inputs, + "labels" : inputs + } + + def __iter__(self): + self.current = 0 + return self + + def __next__(self): + self.current += 1 + if self.current <= self.batches_per_epoch: + while True: + try: + return self._get_batch() + except Exception as e: + print("ERROR IN BATCHER : ", e) + raise StopIteration + + def __len__(self): + return self.batches_per_epoch + +def pad(seqs, pad_value): + seqlens = np.array([len(seq) for seq in seqs]) + maxlen = np.max(seqlens) + return np.array([np.pad(seq, (0,maxlen-len(seq)), mode="constant", constant_values=pad_value) for seq in seqs]), seqlens \ No newline at end of file diff --git a/python_scripts/utils.py b/python_scripts/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..164dc12767232591c2d471b409aa0dfdc1c5ea25 --- /dev/null +++ b/python_scripts/utils.py @@ -0,0 +1,41 @@ +import json +import jsonlines +from multiprocessing import Pool +from tqdm import tqdm + +def load_json(path): + with open(path,"r") as f: + return json.load(f) + +def dump_json(x, path): + with open(path,"w") as f: + json.dump(x,f,indent=4) + +def apply_func(x, func): + pool = Pool(8) + for inval,outval in tqdm(pool.imap_unordered(func,x),total=len(x)): + yield inval,outval + +def load_jsonl(path,max_items=None): + with jsonlines.open(path) as reader: + for ii,item in enumerate(reader): + yield item + if max_items is not None and ii >= max_items: + break + +def dump_jsonl(data, path): + assert isinstance(data, list) + with jsonlines.open(path, mode="w") as wr: + for item in tqdm(data,leave=False): + wr.write(item) + +class dump_jsonl_multistage: + def __init__(self, path, mode="a"): + self.wr = jsonlines.open(path, mode=mode, flush=True) + def add(self, item): + self.wr.write(item) + def extend(self, items): + for item in items: + self.add(item) + def close(self): + self.wr.close() diff --git a/python_scripts_for_testing/midigpt_gen.mid b/python_scripts_for_testing/midigpt_gen.mid new file mode 100644 index 0000000000000000000000000000000000000000..b43d303cebea05fdb8a82342f1eb8801073a0e1d Binary files /dev/null and b/python_scripts_for_testing/midigpt_gen.mid differ diff --git a/python_scripts_for_testing/mtest.mid b/python_scripts_for_testing/mtest.mid new file mode 100644 index 0000000000000000000000000000000000000000..47622a54c66e981faf3e6cc949776ffc72e7df30 Binary files /dev/null and b/python_scripts_for_testing/mtest.mid differ diff --git a/python_scripts_for_testing/pythoninferencetest.py b/python_scripts_for_testing/pythoninferencetest.py new file mode 100644 index 0000000000000000000000000000000000000000..32df41f7784463aa55cf7a0c0922dcc888e44d50 --- /dev/null +++ b/python_scripts_for_testing/pythoninferencetest.py @@ -0,0 +1,67 @@ +import sys, os +sys.path.append(os.path.dirname(os.getcwd()) + "/python_lib") +import midigpt +import json +import random + +if __name__ == "__main__": + + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--midi", type=str, required=True) + parser.add_argument("--ckpt", type=str, required=True) + parser.add_argument("--out", type=str, default='') + args = parser.parse_args() + + ckpt = args.ckpt + midi_input = args.midi + if args.out != '': + midi_dest = args.out + else: + midi_dest = os.path.join(os.path.split(midi_input)[0], 'midigpt_gen.mid') + e = midigpt.ExpressiveEncoder() + midi_json_input = json.loads(e.midi_to_json(midi_input)) + valid_status={'tracks': + [ + { + 'track_id': 0, + 'temperature' : 0.5, + 'instrument': 'acoustic_grand_piano', + 'density': 10, + 'track_type': 10, + 'ignore': False, + 'selected_bars': [False, False, True, False ], + 'min_polyphony_q': 'POLYPHONY_ANY', + 'max_polyphony_q': 'POLYPHONY_ANY', + 'autoregressive': False, + 'polyphony_hard_limit': 9 + } + ] + } + parami={ + 'tracks_per_step': 1, + 'bars_per_step': 1, + 'model_dim': 4, + 'percentage': 100, + 'batch_size': 1, + 'temperature': 1.0, + 'max_steps': 200, + 'polyphony_hard_limit': 6, + 'shuffle': True, + 'verbose': True, + 'ckpt': ckpt, + 'sampling_seed': -1, + 'mask_top_k': 0 + } + + piece = json.dumps(midi_json_input) + status = json.dumps(valid_status) + param = json.dumps(parami) + callbacks = midigpt.CallbackManager() + max_attempts = 3 + midi_str = midigpt.sample_multi_step(piece, status, param, max_attempts, callbacks) + midi_str=midi_str[0] + midi_json = json.loads(midi_str) + + e = midigpt.ExpressiveEncoder() + e.json_to_midi(midi_str, midi_dest) diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..138ea23026a4fb2a77b44a1aab9e36bc1a6409ca --- /dev/null +++ b/setup.py @@ -0,0 +1,67 @@ +#CODE IN PROGRESS, NOT WORKING YET + +from setuptools import setup, Extension +from setuptools.command.build_ext import build_ext +import sys +import setuptools + +__version__ = '0.0.1' + +class get_pybind_include(object): + """Helper class to determine the pybind11 include path""" + + def __init__(self, user=False): + self.user = user + + def __str__(self): + import pybind11 + return pybind11.get_include(self.user) + +ext_modules = [ + Extension( + 'midigpt', + ['src/lib.cpp','src/common/data_structures/train_config.cpp','src/dataset_creation/compression/lz4.c', + 'src/dataset_creation/dataset_manipulation/bytes_to_file.cpp'], + include_dirs=[ + get_pybind_include(), + get_pybind_include(user=True) + ], + language='c++' + ), +] + +class BuildExt(build_ext): + """A custom build extension for adding compiler-specific options.""" + c_opts = { + 'msvc': ['/EHsc'], + 'unix': [], + } + + if sys.platform == 'darwin': + c_opts['unix'] += ['-stdlib=libc++', '-mmacosx-version-min=10.7'] + + def build_extensions(self): + ct = self.compiler.compiler_type + opts = self.c_opts.get(ct, []) + if ct == 'unix': + opts.append('-DVERSION_INFO="%s"' % self.distribution.get_version()) + opts.append('-std=c++20') + elif ct == 'msvc': + opts.append('/DVERSION_INFO=\\"%s\\"' % self.distribution.get_version()) + for ext in self.extensions: + ext.extra_compile_args = opts + build_ext.build_extensions(self) + +setup( + name='midigpt', + version=__version__, + author='Jeff Ens, Rafael Arias', + author_email='raa60@sfu.ca', + url='', + description='A Python wrapper for midigpt project', + long_description='', + ext_modules=ext_modules, + install_requires=['pybind11>=2.5.0'], + cmdclass={'build_ext': BuildExt}, + zip_safe=False, +) diff --git a/src/common/data_structures/encoder_config.h b/src/common/data_structures/encoder_config.h new file mode 100644 index 0000000000000000000000000000000000000000..913ae71d9bc4c6b74214a1ff2b146280b8eead71 --- /dev/null +++ b/src/common/data_structures/encoder_config.h @@ -0,0 +1,95 @@ +#pragma once + +#include +#include +#include +#include + +namespace data_structures { + class EncoderConfig { + public: + EncoderConfig() { + both_in_one = false; + unquantized = false; + do_multi_fill = false; + use_velocity_levels = false; + use_microtiming = false; + transpose = 0; + resolution = 12; + decode_resolution = resolution; + decode_final = false; + delta_resolution = 1920; + } + + std::map ToJson() { + std::map json_config; + + json_config["both_in_one"] = std::to_string((int)both_in_one); + json_config["unquantized"] = std::to_string((int)unquantized); + json_config["do_multi_fill"] = std::to_string((int)do_multi_fill); + json_config["use_velocity_levels"] = std::to_string((int)use_velocity_levels); + json_config["use_microtiming"] = std::to_string((int)use_microtiming); + json_config["transpose"] = std::to_string(transpose); + json_config["resolution"] = std::to_string(resolution); + json_config["decode_resolution"] = std::to_string(decode_resolution); + json_config["decode_final"] = std::to_string((int)decode_final); + json_config["delta_resolution"] = std::to_string(delta_resolution); + return json_config; + } + + void FromJson(const std::map& json_config) { + try { + both_in_one = (bool)std::stoi(json_config.at("both_in_one")); + unquantized = (bool)std::stoi(json_config.at("unquantized")); + do_multi_fill = (bool)std::stoi(json_config.at("do_multi_fill")); + use_velocity_levels = (bool)std::stoi(json_config.at("use_velocity_levels")); + use_microtiming = (bool)std::stoi(json_config.at("use_microtiming")); + transpose = std::stoi(json_config.at("transpose")); + resolution = std::stoi(json_config.at("resolution")); + decode_resolution = std::stoi(json_config.at("decode_resolution")); + decode_final = (bool)std::stoi(json_config.at("decode_final")); + delta_resolution = std::stoi(json_config.at("delta_resolution")); + } catch (const std::out_of_range& e) { + throw std::invalid_argument("Missing required key in JSON config: " + std::string(e.what())); + } catch (const std::invalid_argument& e) { + throw std::invalid_argument("Invalid value type in JSON config: " + std::string(e.what())); + } + } + + int delta_to_step(int delta, int res) { + if (!use_microtiming) { + return 0; + } else { + return (int)(delta * res / delta_resolution); + } + } + + int step_to_delta(float step, int res) { + if (!use_microtiming) { + return 0; + } else { + return round(delta_resolution * step / res); + } + } + + int step_to_delta(int step, int res) { + if (!use_microtiming) { + return 0; + } else { + return round(delta_resolution * step / res); + } + } + + bool both_in_one; + bool unquantized; + bool do_multi_fill; + bool use_velocity_levels; + bool use_microtiming; + int transpose; + int resolution; + int decode_resolution; + bool decode_final; + int delta_resolution; + std::set> multi_fill; + }; +} \ No newline at end of file diff --git a/src/common/data_structures/token_sequence.h b/src/common/data_structures/token_sequence.h new file mode 100644 index 0000000000000000000000000000000000000000..a45d05cf7e65a60b43a1958362aaf78a6306ab0e --- /dev/null +++ b/src/common/data_structures/token_sequence.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include "midi.pb.h" +#include "../encoder/representation.h" + +namespace data_structures { + + class TokenSequence { + public: + TokenSequence(const std::shared_ptr &rep) { + bar_num = 0; + track_num = 0; + } + void push_back( int token ) { + tokens.push_back( token ); + } + + void insert( std::vector &tokens) { + for (auto token : tokens) { + push_back(token); + } + } + + void on_track_start(midi::Piece *x, const std::shared_ptr &rep) { + track_num++; + bar_num = 0; + } + void on_bar_start(midi::Piece *x, const std::shared_ptr &rep) { + + bar_num++; + } + int bar_num; + int track_num; + + std::vector tokens; + }; +} diff --git a/src/common/data_structures/track_type.h b/src/common/data_structures/track_type.h new file mode 100644 index 0000000000000000000000000000000000000000..e65fc57bdbf5c8c5378783e800c8345ade904d80 --- /dev/null +++ b/src/common/data_structures/track_type.h @@ -0,0 +1,20 @@ +#pragma once + +#include "midi.pb.h" + +// START OF NAMESPACE +namespace data_structures { + +std::map TRACK_TYPE_IS_DRUM = { + {midi::AUX_DRUM_TRACK, true}, + {midi::AUX_INST_TRACK, false}, + {midi::STANDARD_TRACK, false}, + {midi::STANDARD_DRUM_TRACK, true}, +}; + +bool is_drum_track(int tt) { + return TRACK_TYPE_IS_DRUM[static_cast(tt)]; +} + +} +// END OF NAMESPACE diff --git a/src/common/data_structures/train_config.cpp b/src/common/data_structures/train_config.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8f62058a6be8deb0a0cc68bfdc5c087fab335fa7 --- /dev/null +++ b/src/common/data_structures/train_config.cpp @@ -0,0 +1,46 @@ +#include "train_config.h" + + +namespace data_structures { + + TrainConfig::TrainConfig() { + num_bars = 4; + min_tracks = 1; + max_tracks = 12; + max_mask_percentage = 0.75; + use_microtiming = false; + microtiming = 0.9; + no_max_length = false; + resolution = 12; + delta_resolution = 1920; + decode_resolution = delta_resolution; + } + + std::map TrainConfig::ToJson() { + std::map json_config; + json_config["num_bars"] = std::to_string(num_bars); + json_config["min_tracks"] = std::to_string(min_tracks); + json_config["max_tracks"] = std::to_string(max_tracks); + json_config["max_mask_percentage"] = std::to_string(max_mask_percentage); + json_config["use_microtiming"] = std::to_string((int)use_microtiming); + json_config["microtiming"] = std::to_string(microtiming); + json_config["no_max_length"] = std::to_string((int)no_max_length); + json_config["resolution"] = std::to_string(resolution); + json_config["decode_resolution"] = std::to_string(decode_resolution); + json_config["delta_resolution"] = std::to_string(delta_resolution); + return json_config; + } + + void TrainConfig::FromJson(std::map& json_config) { + num_bars = stoi(json_config["num_bars"]); + min_tracks = stoi(json_config["min_tracks"]); + max_tracks = stoi(json_config["max_tracks"]); + max_mask_percentage = stof(json_config["max_mask_percentage"]); + microtiming = stoi(json_config["microtiming"]); + use_microtiming = (bool)stoi(json_config["use_microtiming"]); + no_max_length = (bool)stoi(json_config["no_max_length"]); + resolution = stoi(json_config["resolution"]); + decode_resolution = stoi(json_config["decode_resolution"]); + delta_resolution = stoi(json_config["delta_resolution"]); + } +} \ No newline at end of file diff --git a/src/common/data_structures/train_config.h b/src/common/data_structures/train_config.h new file mode 100644 index 0000000000000000000000000000000000000000..c63be5b0a50cfcd9eba66b6389fc9d9bb3bb78b0 --- /dev/null +++ b/src/common/data_structures/train_config.h @@ -0,0 +1,29 @@ +#pragma once + +#include +#include + +// START OF NAMESPACE +namespace data_structures { + +class TrainConfig { +public: + int num_bars; + int min_tracks; + int max_tracks; + float max_mask_percentage; + bool use_microtiming; + float microtiming; + bool no_max_length; + int resolution; + int decode_resolution; + int delta_resolution; + + TrainConfig(); + + std::map ToJson(); + void FromJson(std::map& json_config); +}; + +} +// END OF NAMESPACE \ No newline at end of file diff --git a/src/common/data_structures/verbosity.h b/src/common/data_structures/verbosity.h new file mode 100644 index 0000000000000000000000000000000000000000..6c03249946f63928a059a4c62665c0cd159358e1 --- /dev/null +++ b/src/common/data_structures/verbosity.h @@ -0,0 +1,55 @@ +// control verbosity levels in the code to make things cleaner + +#pragma once + +#include + +namespace data_structures { + +enum VERBOSITY_LEVEL { + VERBOSITY_LEVEL_QUIET, + VERBOSITY_LEVEL_VERBOSE, + VERBOSITY_LEVEL_DEBUG, + VERBOSITY_LEVEL_TRACE +}; + +VERBOSITY_LEVEL GLOBAL_VERBOSITY_LEVEL = VERBOSITY_LEVEL_QUIET; + +inline void setGlobalVerbosityLevel(VERBOSITY_LEVEL vl) { + GLOBAL_VERBOSITY_LEVEL = vl; +} + +template +std::string to_str(const T& value){ + std::ostringstream tmp_str; + tmp_str << value; + return tmp_str.str(); +} + +template +std::string to_str(const T& value, const Args& ... args){ + return to_str(value) + to_str(args...); +} + +template +inline void LOGGER(T x) { + if (GLOBAL_VERBOSITY_LEVEL >= VERBOSITY_LEVEL_VERBOSE) { + std::cout << x << std::endl; + } +} + +template +inline void LOGGER(VERBOSITY_LEVEL vl, T x) { + if (vl <= GLOBAL_VERBOSITY_LEVEL) { + std::cout << x << std::endl; + } +} + +template +inline void LOGGER(T x, bool newline) { + if (GLOBAL_VERBOSITY_LEVEL >= VERBOSITY_LEVEL_VERBOSE) { + std::cout << x; + } +} + +} \ No newline at end of file diff --git a/src/common/encoder/attribute_control.h b/src/common/encoder/attribute_control.h new file mode 100644 index 0000000000000000000000000000000000000000..9581b3a27f770984b05267c104ccc5dd33247cff --- /dev/null +++ b/src/common/encoder/attribute_control.h @@ -0,0 +1,1370 @@ +#pragma once + +#include +#include +#include + +#include "representation.h" + +#include "../../common/data_structures/token_sequence.h" + +namespace encoder { + +enum ATTRIBUTE_CONTROL_LEVEL { + ATTRIBUTE_CONTROL_LEVEL_PIECE, + ATTRIBUTE_CONTROL_LEVEL_TRACK, + ATTRIBUTE_CONTROL_LEVEL_TRACK_PRE_INSTRUMENT, + ATTRIBUTE_CONTROL_LEVEL_BAR +}; + +enum ATTRIBUTE_CONTROL_TRACK_TYPE { + ATTRIBUTE_CONTROL_TRACK_TYPE_INSTRUMENT, + ATTRIBUTE_CONTROL_TRACK_TYPE_DRUM, + ATTRIBUTE_CONTROL_TRACK_TYPE_INSTRUMENT_AND_DRUM, + ATTRIBUTE_CONTROL_TRACK_TYPE_NONE +}; + +template +int protobuf_get_field_value(T *message, const std::string &feature_name) { + const google::protobuf::FieldDescriptor *fd = message->GetDescriptor()->FindFieldByName(feature_name); + if (fd == NULL) { + throw std::runtime_error("INVALID FIELD NAME"); + } + if (fd->type() == google::protobuf::FieldDescriptor::Type::TYPE_INT32) { + return message->GetReflection()->GetInt32(*message, fd); + } + if (fd->type() == google::protobuf::FieldDescriptor::Type::TYPE_ENUM) { + return message->GetReflection()->GetEnumValue(*message, fd); + } + std::cout << "field name: " << feature_name << std::endl; + throw std::runtime_error("INVALID FIELD TYPE"); +} + +template +U protobuf_get_field(const T *message, const std::string &feature_name) { + const google::protobuf::FieldDescriptor *fd = message->GetDescriptor()->FindFieldByName(feature_name); + if (fd == NULL) { + throw std::runtime_error("INVALID FIELD NAME"); + } + if (fd->type() == google::protobuf::FieldDescriptor::Type::TYPE_INT32) { + return message->GetReflection()->GetInt32(*message, fd); + } + else if (fd->type() == google::protobuf::FieldDescriptor::Type::TYPE_BOOL) { + return message->GetReflection()->GetBool(*message, fd); + } + else if (fd->type() == google::protobuf::FieldDescriptor::Type::TYPE_FLOAT) { + return message->GetReflection()->GetFloat(*message, fd); + } + else if (fd->type() == google::protobuf::FieldDescriptor::Type::TYPE_ENUM) { + return message->GetReflection()->GetEnumValue(*message, fd); + } + else { + std::cout << "field name: " << feature_name << std::endl; + throw std::runtime_error("INVALID FIELD TYPE"); + } +} + +template +void protobuf_set_field(T *message, const std::string &feature_name, U value) { + const google::protobuf::FieldDescriptor *fd = message->GetDescriptor()->FindFieldByName(feature_name); + if (fd == NULL) { + throw std::runtime_error("INVALID FIELD NAME"); + } + if (fd->type() == google::protobuf::FieldDescriptor::Type::TYPE_INT32) { + message->GetReflection()->SetInt32(message, fd, value); + } + else if (fd->type() == google::protobuf::FieldDescriptor::Type::TYPE_BOOL) { + message->GetReflection()->SetBool(message, fd, value); + } + else if (fd->type() == google::protobuf::FieldDescriptor::Type::TYPE_FLOAT) { + message->GetReflection()->SetFloat(message, fd, value); + } + else if (fd->type() == google::protobuf::FieldDescriptor::Type::TYPE_ENUM) { + message->GetReflection()->SetEnumValue(message, fd, value); + } + else { + std::cout << "field name: " << feature_name << std::endl; + throw std::runtime_error("INVALID FIELD TYPE"); + } +} + +class TOKEN_COUNTER { +public: + TOKEN_COUNTER(midi::TOKEN_TYPE tt) { + token_type = tt; + token_count = 0; + } + ~TOKEN_COUNTER() {} + std::tuple update(std::shared_ptr rep, int token) { + bool has_changed = (rep->get_token_type(token) == token_type); + if (has_changed) { + token_count++; + } + return std::make_tuple(token_count, has_changed); + } + void override(int count) { + token_count = count; + } + midi::TOKEN_TYPE token_type; + int token_count; +}; + +class TOKEN_LABELER { +public: + TOKEN_LABELER() { + bar_counter = std::make_unique(midi::TOKEN_BAR); + track_counter = std::make_unique(midi::TOKEN_TRACK); + } + ~TOKEN_LABELER() {} + std::tuple update(std::shared_ptr rep, int token) { + auto [track_count, track_count_changed] = track_counter->update(rep,token); + if (track_count_changed) { + bar_counter->override(0); + } + auto [bar_count, bar_count_changed] = bar_counter->update(rep,token); + return std::make_tuple(std::max(track_count-1,0),std::max(bar_count-1,0)); + } + std::unique_ptr bar_counter; + std::unique_ptr track_counter; +}; + +// basic implementation +std::vector> PitchProbabilityEmbedding(midi::Piece *x, std::shared_ptr rep, std::vector &tokens) { + + // first calculate per track pitch probabilities + std::vector> probs; + for (const auto &track : x->tracks()) { + double total = 0; + std::vector prob(128, 0.0); + for (const auto &bar : track.bars()) { + for (const auto &event_index : bar.events()) { + if (x->events(event_index).velocity() > 0) { + prob[x->events(event_index).pitch()]++; + total++; + } + } + } + if (total > 0) { + for (int i=0; i<128; i++) { + prob[i] /= total; + } + } + probs.push_back(prob); + } + + std::vector> embeds; + auto tl = TOKEN_LABELER(); + for (const auto &token : tokens) { + auto [track_index, bar_index] = tl.update(rep, token); + if (track_index >= (int)probs.size()) { + throw std::runtime_error("INVALID TRACK INDEX DURING PitchProbabilityEmbedding()"); + } + if (track_index < 0) { + throw std::runtime_error("INVALID TRACK INDEX < 0 DURING PitchProbabilityEmbedding()"); + } + embeds.push_back(probs[track_index]); + } + return embeds; +} + +double map(double x, double in_min, double in_max, double out_min, double out_max) { + return (x - in_min) * (out_max - out_min) / (in_max - in_min) + out_min; +} + +class ATTRIBUTE_CONTROL { +public: + + ATTRIBUTE_CONTROL_LEVEL control_level; + ATTRIBUTE_CONTROL_TRACK_TYPE track_type; + std::vector> token_types; + std::vector> token_types_v2; + std::vector> token_types_v3; + bool precompute_on_piece; + + virtual ~ATTRIBUTE_CONTROL () {} + + virtual void compute_piece_features(midi::Piece *x, midi::PieceFeatures *pf) { + // this function is responsible for computing the features that are needed for + // this form of attribute control + throw std::runtime_error("ATTRIBUTE CONTROL CLASS MUST DEFINE compute_piece_features()"); + } + + virtual void compute_track_features(midi::Piece *x, int track_num, midi::TrackFeatures *tf) { + // this function is responsible for computing the features that are needed for + // this form of attribute control + throw std::runtime_error("ATTRIBUTE CONTROL CLASS MUST DEFINE compute_track_features()"); + } + + virtual void compute_bar_features(midi::Piece *x, int track_num, int bar_num, midi::BarFeatures *bf) { + // this function is responsible for computing the features that are needed for + // this form of attribute control + throw std::runtime_error("ATTRIBUTE CONTROL CLASS MUST DEFINE compute_bar_features()"); + } + + virtual void append_piece_tokens(data_structures::TokenSequence *tokens, const std::shared_ptr &rep, midi::PieceFeatures *pf) { + if (token_types_v2.size() > 0) { + for (const auto &fn : token_types_v2) { + tokens->push_back( rep->encode(std::get<0>(fn), protobuf_get_field_value(pf, std::get<2>(fn))) ); + } + } + else { + throw std::runtime_error("ATTRIBUTE CONTROL MUST DEFINE append_piece_tokens()"); + } + } + + virtual void append_track_tokens(data_structures::TokenSequence *tokens, const std::shared_ptr &rep, midi::TrackFeatures *tf) { + if (token_types_v2.size() > 0) { + for (const auto &fn : token_types_v2) { + tokens->push_back( rep->encode(std::get<0>(fn), protobuf_get_field_value(tf, std::get<2>(fn))) ); + } + } + else { + throw std::runtime_error("ATTRIBUTE CONTROL MUST DEFINE append_track_tokens()"); + } + } + + virtual void append_bar_tokens(data_structures::TokenSequence *tokens, const std::shared_ptr &rep, midi::BarFeatures *bf) { + if (token_types_v2.size() > 0) { + for (const auto &fn : token_types_v2) { + tokens->push_back( rep->encode(std::get<0>(fn), protobuf_get_field_value(bf, std::get<2>(fn))) ); + } + } + else { + throw std::runtime_error("ATTRIBUTE CONTROL MUST DEFINE append_bar_tokens()"); + } + } + + virtual void set_piece_mask(data_structures::TokenSequence *tokens, const std::shared_ptr &rep, midi::Status *piece) { + // this function sets the appropriate token mask for sampling to control which attribute is selected + throw std::runtime_error("ATTRIBUTE CONTROL CLASS MUST DEFINE set_piece_mask"); + } + + virtual void set_track_mask(const std::shared_ptr &rep, std::vector &mask, midi::StatusTrack *track) { + if (token_types_v2.size() > 0) { + for (const auto &fn : token_types_v2) { + rep->set_mask(std::get<0>(fn), {protobuf_get_field_value(track, std::get<2>(fn))-1}, mask, 1); + } + } + else { + throw std::runtime_error("ATTRIBUTE CONTROL CLASS MUST DEFINE set_track_mask"); + } + } + + virtual void set_bar_mask(const std::shared_ptr &rep, std::vector &mask, midi::StatusBar *bar) { + if (token_types_v2.size() > 0) { + for (const auto &fn : token_types_v2) { + rep->set_mask(std::get<0>(fn), {protobuf_get_field_value(bar, std::get<2>(fn))-1}, mask, 1); + } + } + else { + throw std::runtime_error("ATTRIBUTE CONTROL CLASS MUST DEFINE set_bar_mask"); + } + } + + virtual void override_track_feature(midi::TrackFeatures *tf, midi::StatusTrack *track) { + if (token_types_v2.size() > 0) { + for (const auto &fn : token_types_v2) { + auto value = protobuf_get_field_value(track, std::get<2>(fn)); + if (value > 0) { + protobuf_set_field(tf, std::get<2>(fn), value - 1); // copy value from status to piece + } + } + } + else { + throw std::runtime_error("ATTRIBUTE CONTROL CLASS MUST DEFINE override_track_feature"); + } + } + + virtual void override_bar_feature(midi::BarFeatures *bf, midi::StatusBar *bar) { + if (token_types_v2.size() > 0) { + for (const auto &fn : token_types_v2) { + auto value = protobuf_get_field_value(bar, std::get<2>(fn)); + if (value > 0) { + protobuf_set_field(bf, std::get<2>(fn), value - 1); // copy value from status to piece + } + } + } + else { + throw std::runtime_error("ATTRIBUTE CONTROL CLASS MUST DEFINE override_bar_feature"); + } + } + + void override_track_level_features(midi::Piece *x, midi::Status *s) { + for (int track_num=0; track_numtracks_size(); track_num++) { + midi::TrackFeatures *tf = util_protobuf::GetTrackFeatures(x,track_num); + midi::StatusTrack st = s->tracks(track_num); + override_track_feature(tf, &st); + } + } + + void override_bar_level_features(midi::Piece *x, midi::Status *s) { + for (int track_num=0; track_numtracks_size(); track_num++) { + midi::Track *track = x->mutable_tracks(track_num); + midi::StatusTrack st = s->tracks(track_num); + for (int bar_num=0; bar_numbars_size(); bar_num++) { + midi::BarFeatures *bf = util_protobuf::GetBarFeatures(track, bar_num); + midi::StatusBar sb = st.bars(bar_num); + override_bar_feature(bf, &sb); + } + } + } + + void override_features(midi::Piece *x, midi::Status *s) { + switch(control_level) { + case ATTRIBUTE_CONTROL_LEVEL_PIECE: + throw std::runtime_error("CANNOT OVERRIDE PIECE LEVEL FEATURES"); + break; + case ATTRIBUTE_CONTROL_LEVEL_TRACK: + override_track_level_features(x,s); + break; + case ATTRIBUTE_CONTROL_LEVEL_TRACK_PRE_INSTRUMENT: + override_track_level_features(x,s); + break; + case ATTRIBUTE_CONTROL_LEVEL_BAR: + override_bar_level_features(x,s); + break; + default: + throw std::runtime_error("INVALID ATTRIBUTE CONTROL LEVEL"); + } + } + + void compute_piece_level_features(midi::Piece *x) { + midi::PieceFeatures *pf = util_protobuf::GetPieceFeatures(x); + compute_piece_features(x, pf); + } + + void compute_track_level_features(midi::Piece *x) { + for (int track_num=0; track_numtracks_size(); track_num++) { + midi::TrackFeatures *tf = util_protobuf::GetTrackFeatures(x,track_num); + compute_track_features(x, track_num, tf); + } + } + + void compute_bar_level_features(midi::Piece *x) { + for (int track_num=0; track_numtracks_size(); track_num++) { + midi::Track *track = x->mutable_tracks(track_num); + for (int bar_num=0; bar_numbars_size(); bar_num++) { + midi::BarFeatures *bf = util_protobuf::GetBarFeatures(track, bar_num); + compute_bar_features(x, track_num, bar_num, bf); + } + } + } + + void compute_features(midi::Piece *x) { + switch(control_level) { + case ATTRIBUTE_CONTROL_LEVEL_PIECE: + compute_piece_level_features(x); + break; + case ATTRIBUTE_CONTROL_LEVEL_TRACK: + compute_track_level_features(x); + break; + case ATTRIBUTE_CONTROL_LEVEL_TRACK_PRE_INSTRUMENT: + compute_track_level_features(x); + break; + case ATTRIBUTE_CONTROL_LEVEL_BAR: + compute_bar_level_features(x); + break; + default: + throw std::runtime_error("INVALID ATTRIBUTE CONTROL LEVEL"); + } + } + + virtual double evaluate_track_feature(midi::Piece *x, int track_num, midi::TrackFeatures *tf, midi::StatusTrack *st) { + throw std::runtime_error("ATTRIBUTE CONTROL CLASS MUST DEFINE evaluate_track_feature()"); + } + + std::vector evaluate_track_feature_py(std::string &piece_json, std::string &status_json) { + midi::Piece x; + midi::Status s; + util_protobuf::string_to_protobuf(piece_json, &x); + util_protobuf::string_to_protobuf(status_json, &s); + std::vector output; + for (int i=0; i get_token_types() { + std::vector token_types_list; + for (const auto &ttd : token_types) { + token_types_list.push_back(std::get<0>(ttd)); + } + return token_types_list; + } + + int get_token_domain_size(midi::TOKEN_TYPE tt) { + for (const auto &ttd : token_types) { + if (std::get<0>(ttd) == tt) { + return std::get<1>(ttd); + } + } + throw std::runtime_error("ATTRIBUTE_CONTROL::get_token_domain_size() : TOKEN TYPE NOT FOUND"); + } + + virtual TOKEN_DOMAIN get_token_domain(midi::TOKEN_TYPE tt) { + return TOKEN_DOMAIN(get_token_domain_size(tt)); + } + + bool is_track_control() { + return (control_level == ATTRIBUTE_CONTROL_LEVEL_TRACK) || (control_level == ATTRIBUTE_CONTROL_LEVEL_TRACK_PRE_INSTRUMENT); + } + + bool is_bar_control() { + return (control_level == ATTRIBUTE_CONTROL_LEVEL_BAR); + } + + + // get the enum domain for the attribute control in status track + std::map> get_status_track_enum_domain() { + if (token_types_v2.size() == 0) { + throw std::runtime_error("STATUS TRACK FIELD NAME NOT SPECIFIED"); + } + midi::StatusTrack st; + midi::StatusBar sb; + const google::protobuf::Descriptor *descriptor = is_bar_control() ? sb.GetDescriptor() : st.GetDescriptor(); + if (descriptor == NULL) { + throw std::runtime_error("INVALID DESCRIPTOR"); + } + std::map> output; + for (const auto &fn : token_types_v2) { + //std::cout << "FIELD NAME: " << std::get<2>(fn) << std::endl; + auto field_name = std::get<2>(fn); + const google::protobuf::FieldDescriptor *field = descriptor->FindFieldByName(field_name); + if (field == NULL) { + throw std::runtime_error("INVALID FIELD NAME"); + } + auto enum_descriptor = field->enum_type(); + if (enum_descriptor == NULL) { + throw std::runtime_error("INVALID ENUM TYPE"); + } + for (int i=0; ivalue_count(); i++) { + output[field_name].push_back(enum_descriptor->value(i)->name()); + } + } + return output; + } + + std::map> get_status_enum_mapping() { + if (token_types_v2.size() == 0) { + throw std::runtime_error("STATUS BAR FIELD NAME NOT SPECIFIED"); + } + midi::StatusTrack st; + midi::StatusBar sb; + const google::protobuf::Descriptor *descriptor = is_bar_control() ? sb.GetDescriptor() : st.GetDescriptor(); + if (descriptor == NULL) { + throw std::runtime_error("INVALID DESCRIPTOR"); + } + std::map> output; + for (const auto &fn : token_types_v2) { + auto field_name = std::get<2>(fn); + const google::protobuf::FieldDescriptor *field = descriptor->FindFieldByName(field_name); + if (field == NULL) { + throw std::runtime_error("INVALID FIELD NAME"); + } + auto enum_descriptor = field->enum_type(); + for (int i=0; ivalue_count(); i++) { + output[field_name][enum_descriptor->value(i)->name()] = i; + } + } + return output; + } + +}; + +// ================================================ +// ================================================ +// ATTRIBUTE CONTROLS +// ================================================ +// ================================================ + +class TrackLevelOnsetPolyphony : public ATTRIBUTE_CONTROL { +public: + + TrackLevelOnsetPolyphony() { + precompute_on_piece = false; + control_level = ATTRIBUTE_CONTROL_LEVEL_TRACK; + track_type = ATTRIBUTE_CONTROL_TRACK_TYPE_INSTRUMENT_AND_DRUM; + token_types = { + {midi::TOKEN_TRACK_LEVEL_ONSET_POLYPHONY_MIN, 6}, + {midi::TOKEN_TRACK_LEVEL_ONSET_POLYPHONY_MAX, 6} + }; + token_types_v2 = { + {midi::TOKEN_TRACK_LEVEL_ONSET_POLYPHONY_MIN, 6, "onset_polyphony_min"}, + {midi::TOKEN_TRACK_LEVEL_ONSET_POLYPHONY_MAX, 6, "onset_polyphony_max"} + }; + } + ~TrackLevelOnsetPolyphony() {} + + void compute_track_features(midi::Piece *x, int track_num, midi::TrackFeatures *tf) { + const auto track = x->tracks(track_num); + tf->mutable_attribute_control_distributions()->clear_onset_polyphony(); + + int bar_start = 0; + std::map concurrent_onsets; + for (const auto &bar : track.bars()) { + for (const auto &event_index : bar.events()) { + if (x->events(event_index).velocity()) { + concurrent_onsets[bar_start + x->events(event_index).time()] += 1; + } + } + bar_start += x->resolution() * bar.internal_beat_length(); + } + + int polyphony_min = INT_MAX; + int polyphony_max = INT_MIN; + for (const auto &kv : concurrent_onsets) { + if (kv.second < polyphony_min) { + polyphony_min = kv.second; + } + if (kv.second > polyphony_max) { + polyphony_max = kv.second; + } + tf->mutable_attribute_control_distributions()->add_onset_polyphony(kv.second); // for evaluation + } + + tf->set_onset_polyphony_min( util_protobuf::clip(polyphony_min, 1, get_token_domain_size(midi::TOKEN_TRACK_LEVEL_ONSET_POLYPHONY_MIN)) - 1 ); + tf->set_onset_polyphony_max( util_protobuf::clip(polyphony_max, 1, get_token_domain_size(midi::TOKEN_TRACK_LEVEL_ONSET_POLYPHONY_MAX)) - 1 ); + } + + + double evaluate_track_feature(midi::Piece *x, int track_num, midi::TrackFeatures *tf, midi::StatusTrack *st) { + compute_track_features(x, track_num, tf); + auto mapping = get_status_enum_mapping(); + auto domain = get_status_track_enum_domain(); + double range_min = mapping["onset_polyphony_min"][domain["onset_polyphony_min"][protobuf_get_field_value(st, "onset_polyphony_min")]]; + double range_max = mapping["onset_polyphony_max"][domain["onset_polyphony_max"][protobuf_get_field_value(st, "onset_polyphony_max")]]; + double score = 0.0; + double total = 0.0; + for (const auto value : tf->attribute_control_distributions().onset_polyphony()) { + score += (range_min <= value) && (value <= range_max); + total += 1; + } + return score / total; + } +}; + + +class TrackLevelNoteDuration : public ATTRIBUTE_CONTROL { +public: + + TrackLevelNoteDuration() { + precompute_on_piece = false; + control_level = ATTRIBUTE_CONTROL_LEVEL_TRACK; + track_type = ATTRIBUTE_CONTROL_TRACK_TYPE_INSTRUMENT; + token_types = { + {midi::TOKEN_CONTAINS_NOTE_DURATION_THIRTY_SECOND, 2}, + {midi::TOKEN_CONTAINS_NOTE_DURATION_SIXTEENTH, 2}, + {midi::TOKEN_CONTAINS_NOTE_DURATION_EIGHTH, 2}, + {midi::TOKEN_CONTAINS_NOTE_DURATION_QUARTER, 2}, + {midi::TOKEN_CONTAINS_NOTE_DURATION_HALF, 2}, + {midi::TOKEN_CONTAINS_NOTE_DURATION_WHOLE, 2} + }; + token_types_v2 = { + {midi::TOKEN_CONTAINS_NOTE_DURATION_THIRTY_SECOND, 2, "contains_note_duration_thirty_second"}, + {midi::TOKEN_CONTAINS_NOTE_DURATION_SIXTEENTH, 2, "contains_note_duration_sixteenth"}, + {midi::TOKEN_CONTAINS_NOTE_DURATION_EIGHTH, 2, "contains_note_duration_eighth"}, + {midi::TOKEN_CONTAINS_NOTE_DURATION_QUARTER, 2, "contains_note_duration_quarter"}, + {midi::TOKEN_CONTAINS_NOTE_DURATION_HALF, 2, "contains_note_duration_half"}, + {midi::TOKEN_CONTAINS_NOTE_DURATION_WHOLE, 2, "contains_note_duration_whole"} + }; + } + ~TrackLevelNoteDuration() {} + + void compute_track_features(midi::Piece *x, int track_num, midi::TrackFeatures *tf) { + // add in the note duration distribution for testing at some point ... + const auto track = x->tracks(track_num); + tf->mutable_attribute_control_distributions()->note_duration(); + + int max_tick = 0; + std::vector notes = util_protobuf::TrackEventsToNotes(x, track_num, &max_tick); + + // get note durations + std::vector durations; + for (const auto ¬e : notes) { + double d = note.end() - note.start(); + int duration_level = (int)util_protobuf::clip(util_protobuf::midigpt_log2(std::max(d / 3., 1e-6)), 0., 5.); // assume resolution==24 + durations.push_back(duration_level); + tf->mutable_attribute_control_distributions()->add_note_duration(duration_level); // for evaluation + } + + // see which categories are used + std::vector used_categories(6, 0); + for (const auto &d : durations) { + used_categories[d] = 1; + } + + // add features + tf->set_contains_note_duration_thirty_second(used_categories[0]); + tf->set_contains_note_duration_sixteenth(used_categories[1]); + tf->set_contains_note_duration_eighth(used_categories[2]); + tf->set_contains_note_duration_quarter(used_categories[3]); + tf->set_contains_note_duration_half(used_categories[4]); + tf->set_contains_note_duration_whole(used_categories[5]); + } + + double evaluate_track_feature(midi::Piece *x, int track_num, midi::TrackFeatures *tf, midi::StatusTrack *st) { + compute_track_features(x, track_num, tf); + std::map mapping = { + {0,"contains_note_duration_thirty_second"}, + {1,"contains_note_duration_sixteenth"}, + {2,"contains_note_duration_eighth"}, + {3,"contains_note_duration_quarter"}, + {4,"contains_note_duration_half"}, + {5,"contains_note_duration_whole"} + }; + double score = 0.0; + double total = 0.0; + const google::protobuf::Reflection *reflection = st->GetReflection(); + const google::protobuf::Descriptor *descriptor = st->GetDescriptor(); + for (const auto value : tf->attribute_control_distributions().note_duration()) { + const google::protobuf::FieldDescriptor *fd = descriptor->FindFieldByName(mapping[value]); + score += (reflection->GetEnumValue(*st, fd) == midi::BOOLEAN_TRUE); + total += 1; + } + return score / total; + } +}; + +class TrackLevelOnsetDensity : public ATTRIBUTE_CONTROL { +public: + + TrackLevelOnsetDensity() { + precompute_on_piece = false; + control_level = ATTRIBUTE_CONTROL_LEVEL_TRACK; + track_type = ATTRIBUTE_CONTROL_TRACK_TYPE_INSTRUMENT_AND_DRUM; + token_types = { + {midi::TOKEN_TRACK_LEVEL_ONSET_DENSITY_MIN, 18}, + {midi::TOKEN_TRACK_LEVEL_ONSET_DENSITY_MAX, 18} + }; + token_types_v2 = { + {midi::TOKEN_TRACK_LEVEL_ONSET_DENSITY_MIN, 18, "onset_density_min"}, + {midi::TOKEN_TRACK_LEVEL_ONSET_DENSITY_MAX, 18, "onset_density_max"} + }; + } + ~TrackLevelOnsetDensity() {} + + void compute_track_features(midi::Piece *x, int track_num, midi::TrackFeatures *tf) { + const auto track = x->tracks(track_num); + tf->mutable_attribute_control_distributions()->clear_onset_density(); + + std::vector unique_onsets_per_bar; + for (const auto &bar : track.bars()) { + std::set unique_onsets; + for (const auto &event_index : bar.events()) { + if (x->events(event_index).velocity()) { + unique_onsets.insert(x->events(event_index).time()); + } + } + unique_onsets_per_bar.push_back( util_protobuf::clip((int)unique_onsets.size(), 0, get_token_domain_size(midi::TOKEN_TRACK_LEVEL_ONSET_DENSITY_MIN)-1) ); // 18 classes + } + + int onsets_min = INT_MAX; + int onsets_max = INT_MIN; + for (const auto &x : unique_onsets_per_bar) { + if (x < onsets_min) { + onsets_min = x; + } + if (x > onsets_max) { + onsets_max = x; + } + tf->mutable_attribute_control_distributions()->add_onset_density(x); // for evaluation + } + + tf->set_onset_density_min( onsets_min ); + tf->set_onset_density_max( onsets_max ); + } + + double evaluate_track_feature(midi::Piece *x, int track_num, midi::TrackFeatures *tf, midi::StatusTrack *st) { + compute_track_features(x, track_num, tf); + auto mapping = get_status_enum_mapping(); + auto domain = get_status_track_enum_domain(); + double range_min = mapping["onset_density_min"][domain["onset_density_min"][protobuf_get_field_value(st, "onset_density_min")]]; + double score = 0.0; + double total = 0.0; + for (const auto value : tf->attribute_control_distributions().onset_density()) { + score += abs(value - range_min); + total += 1; + } + return score / total; + } +}; + +class BarLevelOnsetPolyphony : public ATTRIBUTE_CONTROL { +public: + + BarLevelOnsetPolyphony() { + precompute_on_piece = false; + control_level = ATTRIBUTE_CONTROL_LEVEL_BAR; + track_type = ATTRIBUTE_CONTROL_TRACK_TYPE_INSTRUMENT_AND_DRUM; + token_types = { + {midi::TOKEN_BAR_LEVEL_ONSET_POLYPHONY_MIN, 6}, + {midi::TOKEN_BAR_LEVEL_ONSET_POLYPHONY_MAX, 6} + }; + token_types_v2 = { + {midi::TOKEN_BAR_LEVEL_ONSET_POLYPHONY_MIN, 6, "onset_polyphony_min"}, + {midi::TOKEN_BAR_LEVEL_ONSET_POLYPHONY_MAX, 6, "onset_polyphony_max"} + }; + } + ~BarLevelOnsetPolyphony() {} + + void compute_bar_features(midi::Piece *x, int track_num, int bar_num, midi::BarFeatures *bf) { + const auto track = x->tracks(track_num); + const auto bar = track.bars(bar_num); + + std::map concurrent_onsets; + for (const auto &event_index : bar.events()) { + if (x->events(event_index).velocity()) { + concurrent_onsets[x->events(event_index).time()] += 1; + } + } + + // get the min and max of concurrent onsets + int polyphony_min = INT_MAX; + int polyphony_max = INT_MIN; + for (const auto &kv : concurrent_onsets) { + if (kv.second < polyphony_min) { + polyphony_min = kv.second; + } + if (kv.second > polyphony_max) { + polyphony_max = kv.second; + } + } + + bf->set_onset_polyphony_min( util_protobuf::clip( + polyphony_min, 1, get_token_domain_size(midi::TOKEN_BAR_LEVEL_ONSET_POLYPHONY_MIN)) - 1 ); + bf->set_onset_polyphony_max( util_protobuf::clip( + polyphony_max, 1, get_token_domain_size(midi::TOKEN_BAR_LEVEL_ONSET_POLYPHONY_MAX)) - 1 ); + } +}; + +class BarLevelOnsetDensity : public ATTRIBUTE_CONTROL { +public: + + BarLevelOnsetDensity() { + precompute_on_piece = false; + control_level = ATTRIBUTE_CONTROL_LEVEL_BAR; + track_type = ATTRIBUTE_CONTROL_TRACK_TYPE_INSTRUMENT_AND_DRUM; + token_types = { + {midi::TOKEN_BAR_LEVEL_ONSET_DENSITY, 18} + }; + token_types_v2 = { + {midi::TOKEN_BAR_LEVEL_ONSET_DENSITY, 18, "onset_density"} + }; + } + ~BarLevelOnsetDensity() {} + + void compute_bar_features(midi::Piece *x, int track_num, int bar_num, midi::BarFeatures *bf) { + const auto track = x->tracks(track_num); + const auto bar = track.bars(bar_num); + + std::set unique_onsets; + for (const auto &event_index : bar.events()) { + if (x->events(event_index).velocity()) { + unique_onsets.insert(x->events(event_index).time()); + } + } + + bf->set_onset_density(util_protobuf::clip( + (int)unique_onsets.size(), 0, get_token_domain_size(midi::TOKEN_BAR_LEVEL_ONSET_DENSITY)-1)); + } +}; + +class PolyphonyQuantile : public ATTRIBUTE_CONTROL { +public: + + PolyphonyQuantile() { + precompute_on_piece = false; + control_level = ATTRIBUTE_CONTROL_LEVEL_TRACK; + track_type = ATTRIBUTE_CONTROL_TRACK_TYPE_INSTRUMENT; + token_types = { + {midi::TOKEN_MIN_POLYPHONY, 10}, + {midi::TOKEN_MAX_POLYPHONY, 10} + }; + token_types_v2 = { + {midi::TOKEN_MIN_POLYPHONY, 10, "min_polyphony_q"}, + {midi::TOKEN_MAX_POLYPHONY, 10, "max_polyphony_q"} + }; + } + ~PolyphonyQuantile() {} + + void compute_track_features(midi::Piece *x, int track_num, midi::TrackFeatures *tf) { + const auto track = x->tracks(track_num); + tf->mutable_attribute_control_distributions()->clear_polyphony_quantile(); + + int max_tick = 0; + std::vector notes = util_protobuf::TrackEventsToNotes(x, track_num, &max_tick); + int nonzero_count = 0; + double count = 0; + std::vector flat_roll(max_tick, 0); + for (const auto ¬e : notes) { + for (int t = note.start(); t < std::min(note.end(), max_tick - 1); t++) { + if (flat_roll[t] == 0) { + nonzero_count += 1; + } + flat_roll[t]++; + count++; + } + } + + std::vector nz; + for (const auto &x : flat_roll) { + if (x > 0) { + nz.push_back(x); + tf->mutable_attribute_control_distributions()->add_polyphony_quantile(x); // for evaluation + } + } + + // get quantiles and add to track features + std::vector polyphony_qs = util_protobuf::quantile(nz, { .15,.85 }); + tf->set_min_polyphony_q( util_protobuf::clip(polyphony_qs[0], 1, get_token_domain_size(midi::TOKEN_MIN_POLYPHONY)) - 1 ); + tf->set_max_polyphony_q( util_protobuf::clip(polyphony_qs[1], 1, get_token_domain_size(midi::TOKEN_MAX_POLYPHONY)) - 1 ); + } +}; + +class NoteDurationQuantile : public ATTRIBUTE_CONTROL { +public: + + NoteDurationQuantile() { + precompute_on_piece = false; + control_level = ATTRIBUTE_CONTROL_LEVEL_TRACK; + track_type = ATTRIBUTE_CONTROL_TRACK_TYPE_INSTRUMENT; + token_types = { + {midi::TOKEN_MIN_NOTE_DURATION, 6}, + {midi::TOKEN_MAX_NOTE_DURATION, 6} + }; + token_types_v2 = { + {midi::TOKEN_MIN_NOTE_DURATION, 6, "min_note_duration_q"}, + {midi::TOKEN_MAX_NOTE_DURATION, 6, "max_note_duration_q"} + }; + } + ~NoteDurationQuantile() {} + + void compute_track_features(midi::Piece *x, int track_num, midi::TrackFeatures *tf) { + const auto track = x->tracks(track_num); + tf->mutable_attribute_control_distributions()->clear_note_duration_quantile(); + + int max_tick = 0; + std::vector notes = util_protobuf::TrackEventsToNotes(x, track_num, &max_tick); + + // get note durations + std::vector durations; + for (const auto ¬e : notes) { + double d = note.end() - note.start(); + int duration_level = (int)util_protobuf::clip(util_protobuf::midigpt_log2(std::max(d / 3., 1e-6)) + 1, 0., (double)get_token_domain_size(midi::TOKEN_MIN_NOTE_DURATION)-1.); + durations.push_back(duration_level); + tf->mutable_attribute_control_distributions()->add_note_duration_quantile(duration_level); // for evaluation + } + + // get quantiles and add to track features + std::vector dur_qs = util_protobuf::quantile(durations, { .15,.85 }); + tf->set_min_note_duration_q(dur_qs[0]); + tf->set_max_note_duration_q(dur_qs[1]); + } +}; + +class NoteDensity : public ATTRIBUTE_CONTROL { +public: + + NoteDensity() { + precompute_on_piece = false; + control_level = ATTRIBUTE_CONTROL_LEVEL_TRACK; + track_type = ATTRIBUTE_CONTROL_TRACK_TYPE_DRUM; + token_types = { + {midi::TOKEN_DENSITY_LEVEL, 10} + }; + token_types_v2 = { + {midi::TOKEN_DENSITY_LEVEL, 10, "note_density_level"} + }; + } + ~NoteDensity() {} + + void compute_track_features(midi::Piece *x, int track_num, midi::TrackFeatures *tf) { + const auto track = x->tracks(track_num); + + // calculate average notes per bar + int num_notes = 0; + int bar_num = 0; + std::set valid_bars; + for (const auto &bar : track.bars()) { + for (const auto &event_index : bar.events()) { + if (x->events(event_index).velocity()) { + valid_bars.insert(bar_num); + num_notes++; + } + } + bar_num++; + } + int num_bars = std::max((int)valid_bars.size(), 1); + double av_notes_fp = (double)num_notes / num_bars; + int av_notes = round(av_notes_fp); + + // calculate the density bin + int qindex = track.instrument(); + int bin = 0; + + if (data_structures::is_drum_track(track.track_type())) { + qindex = 128; + } + while (av_notes > enums::DENSITY_QUANTILES[qindex][bin]) { + bin++; + } + + tf->set_note_density_level(bin); + tf->set_note_density_value(av_notes_fp); + } +}; + +template +T median(std::vector &xs) { + std::sort(xs.begin(), xs.end()); + return xs[xs.size() / 2]; +} + +class PitchRange : public ATTRIBUTE_CONTROL { +public: + + PitchRange() { + precompute_on_piece = false; + control_level = ATTRIBUTE_CONTROL_LEVEL_TRACK; + track_type = ATTRIBUTE_CONTROL_TRACK_TYPE_INSTRUMENT; + token_types = { + {midi::TOKEN_TRACK_LEVEL_PITCH_RANGE_MIN, 128}, + {midi::TOKEN_TRACK_LEVEL_PITCH_RANGE_MAX, 128} + }; + } + ~PitchRange() {} + + void compute_track_features(midi::Piece *x, int track_num, midi::TrackFeatures *tf) { + const auto track = x->tracks(track_num); + int min_pitch = 127; + int max_pitch = 0; + for (const auto &bar : track.bars()) { + for (const auto &event_index : bar.events()) { + if (x->events(event_index).velocity()) { + int pitch = x->events(event_index).pitch(); + if (pitch < min_pitch) { + min_pitch = pitch; + } + if (pitch > max_pitch) { + max_pitch = pitch; + } + } + } + } + tf->set_min_pitch(min_pitch); + tf->set_max_pitch(max_pitch); + } + + void append_track_tokens(data_structures::TokenSequence *tokens, const std::shared_ptr &rep, midi::TrackFeatures *tf) { + tokens->push_back( rep->encode(midi::TOKEN_TRACK_LEVEL_PITCH_RANGE_MIN, tf->min_pitch()) ); + tokens->push_back( rep->encode(midi::TOKEN_TRACK_LEVEL_PITCH_RANGE_MAX, tf->max_pitch()) ); + } + + void set_track_mask(const std::shared_ptr &rep, std::vector &mask, midi::StatusTrack *track) { + rep->set_mask(midi::TOKEN_TRACK_LEVEL_PITCH_RANGE_MIN, {track->min_pitch()}, mask, 1); + rep->set_mask(midi::TOKEN_TRACK_LEVEL_PITCH_RANGE_MAX, {track->max_pitch()}, mask, 1); + } +}; + +class Genre : public ATTRIBUTE_CONTROL { +public: + + Genre() { + precompute_on_piece = false; + control_level = ATTRIBUTE_CONTROL_LEVEL_TRACK_PRE_INSTRUMENT; + track_type = ATTRIBUTE_CONTROL_TRACK_TYPE_INSTRUMENT_AND_DRUM; + token_types = { + {midi::TOKEN_GENRE, static_cast(midi::GENRE_MUSICMAP_NONE)} + }; + token_types_v2 = { + {midi::TOKEN_GENRE, static_cast(midi::GENRE_MUSICMAP_NONE), "genre"} + }; + } + ~Genre() {} + + void compute_track_features(midi::Piece *x, int track_num, midi::TrackFeatures *tf) { + auto metadata_label = x->internal_metadata_labels().genre(); + if (metadata_label == midi::GENRE_MUSICMAP_ANY) { + metadata_label = midi::GENRE_MUSICMAP_NONE; + } + tf->set_genre(static_cast(metadata_label)-1); + } + + // override get token domain to get the different strings + TOKEN_DOMAIN get_token_domain(midi::TOKEN_TYPE tt) { + if (tt != midi::TOKEN_GENRE) { + throw std::runtime_error("Genre::get_token_domain: invalid token type"); + } + std::vector domain; + for (int i=0; iFindValueByNumber(static_cast(i+1))->name(); + domain.push_back(name); + } + return TOKEN_DOMAIN(domain, STRING_VALUES_DOMAIN); + } +}; + + +// ================================================ +// ================================================ +// ATTRIBUTE CONTROL HELPERS +// ================================================ +// ================================================ + +std::unique_ptr getAttributeControl(midi::ATTRIBUTE_CONTROL_TYPE ac_type) { + switch(ac_type) { + case midi::ATTRIBUTE_CONTROL_NOTE_DENSITY: return std::make_unique(); + case midi::ATTRIBUTE_CONTROL_TRACK_LEVEL_ONSET_POLYPHONY: return std::make_unique(); + case midi::ATTRIBUTE_CONTROL_TRACK_LEVEL_ONSET_DENSITY: return std::make_unique(); + case midi::ATTRIBUTE_CONTROL_PITCH_RANGE: return std::make_unique(); + case midi::ATTRIBUTE_CONTROL_GENRE: return std::make_unique(); + + case midi::ATTRIBUTE_CONTROL_TRACK_LEVEL_NOTE_DURATION: return std::make_unique(); + + case midi::ATTRIBUTE_CONTROL_POLYPHONY_QUANTILE: return std::make_unique(); + case midi::ATTRIBUTE_CONTROL_NOTE_DURATION_QUANTILE: return std::make_unique(); + + case midi::ATTRIBUTE_CONTROL_BAR_LEVEL_ONSET_DENSITY: return std::make_unique(); + case midi::ATTRIBUTE_CONTROL_BAR_LEVEL_ONSET_POLYPHONY: return std::make_unique(); + case midi::ATTRIBUTE_CONTROL_END: + throw std::runtime_error("encoder::getAttributeControl() midi::ATTRIBUTE_CONTROL_END is an invalid argument."); + } + throw std::runtime_error("encoder::getAttributeControl() switch statement missing case."); +} + +std::unique_ptr getAttributeControlStr(std::string &ac_type) { + auto descriptor = google::protobuf::GetEnumDescriptor(); + auto value_descriptor = descriptor->FindValueByName(ac_type); + if (value_descriptor == NULL) { + throw std::runtime_error("encoder::getAttributeControlStr() invalid attribute control type."); + } + return getAttributeControl(static_cast(value_descriptor->index())); +} + +std::vector> getAttributeControls() { + std::vector> acs; + for(int i=0; i(i))); + } + return acs; +} + +std::vector getAttributeControlTokenTypes() { + std::vector token_types; + for (const auto &ac : getAttributeControls()) { + token_types.push_back(ac->get_token_types()[0]); + } + return token_types; +} + +std::map getTokenToAttributeControlTypeMap() { + std::map token_to_ac_type; + for(int i=0; i(i); + auto ac = getAttributeControl(ac_type); + token_to_ac_type[ac->get_token_types()[0]] = ac_type; + } + return token_to_ac_type; +} + +std::multimap getTokenToAttributeControlTypeMultimap() { + std::multimap token_to_ac_type; + for(int i=0; i(i); + auto ac = getAttributeControl(ac_type); + for (const auto &tt : ac->get_token_types()) { + token_to_ac_type.insert({tt, ac_type}); + } + } + return token_to_ac_type; +} + +std::map TOKEN_TO_ATTRIBUTE_CONTROL_TYPE = getTokenToAttributeControlTypeMap(); +std::multimap TOKEN_TO_ATTRIBUTE_CONTROL_TYPE_MULTIMAP = getTokenToAttributeControlTypeMultimap(); + +midi::ATTRIBUTE_CONTROL_TYPE getAttributeControlTypeFromToken(midi::TOKEN_TYPE tt) { + auto result = TOKEN_TO_ATTRIBUTE_CONTROL_TYPE.find(tt); + if (result != TOKEN_TO_ATTRIBUTE_CONTROL_TYPE.end()) { + return result->second; + } + return midi::ATTRIBUTE_CONTROL_END; +} + +midi::ATTRIBUTE_CONTROL_TYPE getAttributeControlTypeFromTokenMultimap(midi::TOKEN_TYPE tt) { + auto result = TOKEN_TO_ATTRIBUTE_CONTROL_TYPE_MULTIMAP.find(tt); + if (result != TOKEN_TO_ATTRIBUTE_CONTROL_TYPE_MULTIMAP.end()) { + return result->second; + } + return midi::ATTRIBUTE_CONTROL_END; +} + +// deprecated +int get_token_domain_size(midi::TOKEN_TYPE tt) { + auto ac_type = getAttributeControlTypeFromTokenMultimap(tt); + if (ac_type != midi::ATTRIBUTE_CONTROL_END) { + return getAttributeControl(ac_type)->get_token_domain_size(tt); + } + std::cout << "encoder::get_token_domain_size() token type = " << util_protobuf::enum_to_string(tt) << " not found." << std::endl; + throw std::runtime_error("encoder::get_token_domain_size() token type not found."); +} + +// deprecated +std::pair add_attribute_control_to_representation(midi::TOKEN_TYPE tt) { + return std::make_pair(tt, TOKEN_DOMAIN(get_token_domain_size(tt))); +} + + +std::vector> add_attribute_control_to_representation_v2(midi::ATTRIBUTE_CONTROL_TYPE ac_type) { + std::vector> token_domains; + auto ac = getAttributeControl(ac_type); + for (const auto &tt :ac->get_token_types()) { + token_domains.push_back(std::make_pair(tt, ac->get_token_domain(tt))); + } + return token_domains; +} + +std::vector> get_instrument_exclusive_token_types() { + std::vector> token_types; + for (const auto &ac : getAttributeControls()) { + if (ac->track_type == ATTRIBUTE_CONTROL_TRACK_TYPE_INSTRUMENT) { + if (ac->token_types_v3.size()) { + for (const auto &tt : ac->token_types_v3) { + token_types.push_back(std::make_tuple(std::get<0>(tt),std::get<1>(tt))); + } + } + else { + for (const auto &tt : ac->get_token_types()) { + token_types.push_back(std::make_tuple(tt,0)); + } + } + } + } + return token_types; +} + +std::vector> get_drum_exclusive_token_types() { + std::vector> token_types; + for (const auto &ac : getAttributeControls()) { + if (ac->track_type == ATTRIBUTE_CONTROL_TRACK_TYPE_DRUM) { + if (ac->token_types_v3.size()) { + for (const auto &tt : ac->token_types_v3) { + token_types.push_back(std::make_tuple(std::get<0>(tt),std::get<1>(tt))); + } + } + else { + for (const auto &tt : ac->get_token_types()) { + token_types.push_back(std::make_tuple(tt,0)); + } + } + } + } + return token_types; +} + +// refactoring attribute control graph functions +std::vector get_attribute_control_graph(ATTRIBUTE_CONTROL_LEVEL acl, midi::TOKEN_TYPE start, midi::TOKEN_TYPE end) { + std::vector token_order; + if (start != midi::TOKEN_NONE) { + token_order.push_back(start); + } + for (const auto &ac : getAttributeControls()) { + if (ac->control_level == acl) { + for (const auto &tt : ac->get_token_types()) { + token_order.push_back(tt); + } + } + } + if (end != midi::TOKEN_NONE) { + token_order.push_back(end); + } + return token_order; +} + +std::vector get_track_pre_instrument_attribute_control_graph() { + return get_attribute_control_graph(ATTRIBUTE_CONTROL_LEVEL_TRACK_PRE_INSTRUMENT, midi::TOKEN_TRACK, midi::TOKEN_INSTRUMENT); +} + +std::vector get_track_attribute_control_graph() { + return get_attribute_control_graph(ATTRIBUTE_CONTROL_LEVEL_TRACK, midi::TOKEN_INSTRUMENT, midi::TOKEN_BAR); +} + +std::vector get_bar_attribute_control_graph() { + return get_attribute_control_graph(ATTRIBUTE_CONTROL_LEVEL_BAR, midi::TOKEN_BAR, midi::TOKEN_TIME_SIGNATURE); +} + + +std::vector> get_attribute_control_graph_v2(ATTRIBUTE_CONTROL_LEVEL acl, std::tuple start, std::tuple end) { + std::vector> token_order; + if (std::get<0>(start) != midi::TOKEN_NONE) { + token_order.push_back(start); + } + for (const auto &ac : getAttributeControls()) { + if (ac->control_level == acl) { + if (ac->token_types_v3.size()) { + for (const auto &x : ac->token_types_v3) { + token_order.push_back(std::make_tuple(std::get<0>(x), std::get<1>(x))); + } + } + else { + for (const auto &tt : ac->get_token_types()) { + token_order.push_back(std::make_tuple(tt, 0)); + } + } + } + } + if (std::get<0>(end) != midi::TOKEN_NONE) { + token_order.push_back(end); + } + return token_order; +} + +std::vector> get_track_pre_instrument_attribute_control_graph_v2() { + return get_attribute_control_graph_v2(ATTRIBUTE_CONTROL_LEVEL_TRACK_PRE_INSTRUMENT, std::make_tuple(midi::TOKEN_TRACK, 0), std::make_tuple(midi::TOKEN_INSTRUMENT, 0)); +} + +std::vector> get_track_attribute_control_graph_v2() { + return get_attribute_control_graph_v2(ATTRIBUTE_CONTROL_LEVEL_TRACK, std::make_tuple(midi::TOKEN_INSTRUMENT, 0), std::make_tuple(midi::TOKEN_BAR, 0)); +} + +std::vector> get_bar_attribute_control_graph_v2() { + return get_attribute_control_graph_v2(ATTRIBUTE_CONTROL_LEVEL_BAR, std::make_tuple(midi::TOKEN_BAR, 0), std::make_tuple(midi::TOKEN_TIME_SIGNATURE, 0)); +} + +void override_attribute_controls(const std::shared_ptr &rep, midi::Piece *x, midi::Status *s) { + for (const auto &kv : rep->token_domains) { + auto ac_type = getAttributeControlTypeFromToken(kv.first); + if (ac_type != midi::ATTRIBUTE_CONTROL_END) { + getAttributeControl(ac_type)->override_features(x, s); + } + } +} + +void compute_attribute_controls(const std::shared_ptr &rep, midi::Piece *x) { + for (const auto &kv : rep->token_domains) { + auto ac_type = getAttributeControlTypeFromToken(kv.first); + if (ac_type != midi::ATTRIBUTE_CONTROL_END) { + getAttributeControl(ac_type)->compute_features(x); + } + } +} + +void compute_piece_level_attribute_controls(const std::shared_ptr &rep, midi::Piece *x) { + for (const auto &kv : rep->token_domains) { + auto ac_type = getAttributeControlTypeFromToken(kv.first); + if (ac_type != midi::ATTRIBUTE_CONTROL_END) { + auto ac = getAttributeControl(ac_type); + if ((ac->control_level == ATTRIBUTE_CONTROL_LEVEL_PIECE) || (ac->precompute_on_piece)) { + ac->compute_piece_level_features(x); + } + } + } +} + +std::string compute_all_attribute_controls_py(std::string &piece_json) { + midi::Piece piece; + util_protobuf::string_to_protobuf(piece_json, &piece); + for (const auto &ac : getAttributeControls()) { + ac->compute_features(&piece); + } + return util_protobuf::protobuf_to_string(&piece); +} + +void append_track_pre_instrument_tokens(data_structures::TokenSequence *tokens, const std::shared_ptr &rep, midi::TrackFeatures *tf, bool is_drum) { + // order of tokens is important here + for (const auto &tt : getAttributeControlTokenTypes()) { + if (rep->token_domains.find(tt) != rep->token_domains.end()) { + auto ac_type = getAttributeControlTypeFromToken(tt); + if (ac_type != midi::ATTRIBUTE_CONTROL_END) { + auto ac = getAttributeControl(ac_type); + if ((ac->control_level == ATTRIBUTE_CONTROL_LEVEL_TRACK_PRE_INSTRUMENT) && (ac->check_valid_track(is_drum))) { + ac->append_track_tokens(tokens, rep, tf); + } + } + } + } +} + +void append_track_tokens(data_structures::TokenSequence *tokens, const std::shared_ptr &rep, midi::TrackFeatures *tf, bool is_drum) { + // order of tokens is important here + for (const auto &tt : getAttributeControlTokenTypes()) { + if (rep->token_domains.find(tt) != rep->token_domains.end()) { + auto ac_type = getAttributeControlTypeFromToken(tt); + if (ac_type != midi::ATTRIBUTE_CONTROL_END) { + auto ac = getAttributeControl(ac_type); + if ((ac->control_level == ATTRIBUTE_CONTROL_LEVEL_TRACK) && (ac->check_valid_track(is_drum))) { + ac->append_track_tokens(tokens, rep, tf); + } + } + } + } +} + +void append_bar_tokens(data_structures::TokenSequence *tokens, const std::shared_ptr &rep, midi::BarFeatures *bf, bool is_drum) { + // order of tokens is important here + for (const auto &tt : getAttributeControlTokenTypes()) { + if (rep->token_domains.find(tt) != rep->token_domains.end()) { + auto ac_type = getAttributeControlTypeFromToken(tt); + if (ac_type != midi::ATTRIBUTE_CONTROL_END) { + auto ac = getAttributeControl(ac_type); + if ((ac->control_level == ATTRIBUTE_CONTROL_LEVEL_BAR) && (ac->check_valid_track(is_drum))) { + ac->append_bar_tokens(tokens, rep, bf); + } + } + } + } +} + +void set_track_masks(const std::shared_ptr &rep, std::vector &mask, midi::StatusTrack *track) { + for (const auto &kv : rep->token_domains) { + auto ac_type = getAttributeControlTypeFromToken(kv.first); + if (ac_type != midi::ATTRIBUTE_CONTROL_END) { + auto ac = getAttributeControl(ac_type); + if (ac->is_track_control()) { + ac->set_track_mask(rep, mask, track); + } + } + } +} + +void set_bar_masks(const std::shared_ptr &rep, std::vector &mask, midi::StatusBar *bar) { + for (const auto &kv : rep->token_domains) { + auto ac_type = getAttributeControlTypeFromToken(kv.first); + if (ac_type != midi::ATTRIBUTE_CONTROL_END) { + auto ac = getAttributeControl(ac_type); + if (ac->is_bar_control()) { + ac->set_bar_mask(rep, mask, bar); + } + } + } +} + +} diff --git a/src/common/encoder/encoder_all.h b/src/common/encoder/encoder_all.h new file mode 100644 index 0000000000000000000000000000000000000000..91ca1d9710298cabffa6794a5457132e4bf069a7 --- /dev/null +++ b/src/common/encoder/encoder_all.h @@ -0,0 +1,113 @@ +#pragma once + +#include "encoder_base.h" +#include "util.h" +#include "attribute_control.h" +#include "../data_structures/track_type.h" +#include "../../inference/enum/velocity.h" +#include "../../inference/enum/timesigs.h" +#include "../../inference/enum/pretrain_group.h" +#include "../midi_parsing/util_protobuf.h" +#include "../../inference/protobuf/validate.h" + +// START OF NAMESPACE +namespace encoder { + +template +std::vector operator+(std::vector const &x, std::vector const &y) { + std::vector vec; + vec.reserve(x.size() + y.size()); + vec.insert(vec.end(), x.begin(), x.end()); + vec.insert(vec.end(), y.begin(), y.end()); + return vec; +} + +class ExpressiveEncoder : public ENCODER { +public: + ExpressiveEncoder() { + config = std::make_shared(); + config->both_in_one = true; + config->use_velocity_levels = true; + config->use_microtiming = true; + config->resolution = 12; + config->delta_resolution = 1920; + config->decode_resolution = config->delta_resolution; + + rep = std::make_shared(REPRESENTATION({ + {midi::TOKEN_PIECE_START, TOKEN_DOMAIN(2)}, + {midi::TOKEN_NUM_BARS, TOKEN_DOMAIN({4,8}, INT_VALUES_DOMAIN)}, + {midi::TOKEN_BAR, TOKEN_DOMAIN(1)}, + {midi::TOKEN_BAR_END, TOKEN_DOMAIN(1)}, + {midi::TOKEN_TIME_SIGNATURE, TOKEN_DOMAIN( + enums::YELLOW_TS_MAP,TIMESIG_MAP_DOMAIN)}, + {midi::TOKEN_TRACK, TOKEN_DOMAIN({ + midi::STANDARD_TRACK, + midi::STANDARD_DRUM_TRACK + },INT_VALUES_DOMAIN)}, + {midi::TOKEN_TRACK_END, TOKEN_DOMAIN(1)}, + {midi::TOKEN_INSTRUMENT, TOKEN_DOMAIN(enums::PRETRAIN_GROUPING,INT_MAP_DOMAIN)}, + {midi::TOKEN_NOTE_ONSET, TOKEN_DOMAIN(128)}, + {midi::TOKEN_NOTE_DURATION, TOKEN_DOMAIN(96)}, + {midi::TOKEN_TIME_ABSOLUTE_POS, TOKEN_DOMAIN(192)}, + {midi::TOKEN_FILL_IN_PLACEHOLDER, TOKEN_DOMAIN(1)}, + {midi::TOKEN_FILL_IN_START, TOKEN_DOMAIN(1)}, + {midi::TOKEN_FILL_IN_END, TOKEN_DOMAIN(1)}, + {midi::TOKEN_DELTA, TOKEN_DOMAIN(96)}, + {midi::TOKEN_DELTA_DIRECTION, TOKEN_DOMAIN(1)}, + {midi::TOKEN_VELOCITY_LEVEL, TOKEN_DOMAIN(128)}, + + add_attribute_control_to_representation(midi::TOKEN_MIN_NOTE_DURATION), + add_attribute_control_to_representation(midi::TOKEN_MAX_NOTE_DURATION), + add_attribute_control_to_representation(midi::TOKEN_MIN_POLYPHONY), + add_attribute_control_to_representation(midi::TOKEN_MAX_POLYPHONY), + add_attribute_control_to_representation(midi::TOKEN_DENSITY_LEVEL), + })); + + } + ~ExpressiveEncoder() {} + + void preprocess_piece(midi::Piece *p) { + util_protobuf::calculate_note_durations(p); + util_protobuf::update_av_polyphony_and_note_duration(p); + util_protobuf::update_note_density(p); + } + + void set_scheme(int res, int delta_res, int delta_vocab_size, int abs_pos_vocab_size) { + config->resolution = res; + config->delta_resolution = delta_res; + + rep = std::make_shared(REPRESENTATION({ + {midi::TOKEN_PIECE_START, TOKEN_DOMAIN(2)}, + {midi::TOKEN_NUM_BARS, TOKEN_DOMAIN({4,8}, INT_VALUES_DOMAIN)}, + {midi::TOKEN_BAR, TOKEN_DOMAIN(1)}, + {midi::TOKEN_BAR_END, TOKEN_DOMAIN(1)}, + {midi::TOKEN_TIME_SIGNATURE, TOKEN_DOMAIN( + enums::YELLOW_TS_MAP,TIMESIG_MAP_DOMAIN)}, + {midi::TOKEN_TRACK, TOKEN_DOMAIN({ + midi::STANDARD_TRACK, + midi::STANDARD_DRUM_TRACK + },INT_VALUES_DOMAIN)}, + {midi::TOKEN_TRACK_END, TOKEN_DOMAIN(1)}, + {midi::TOKEN_INSTRUMENT, TOKEN_DOMAIN(enums::PRETRAIN_GROUPING,INT_MAP_DOMAIN)}, + {midi::TOKEN_NOTE_ONSET, TOKEN_DOMAIN(128)}, + {midi::TOKEN_NOTE_DURATION, TOKEN_DOMAIN(96)}, + {midi::TOKEN_TIME_ABSOLUTE_POS, TOKEN_DOMAIN(abs_pos_vocab_size)}, + {midi::TOKEN_FILL_IN_PLACEHOLDER, TOKEN_DOMAIN(1)}, + {midi::TOKEN_FILL_IN_START, TOKEN_DOMAIN(1)}, + {midi::TOKEN_FILL_IN_END, TOKEN_DOMAIN(1)}, + {midi::TOKEN_DELTA, TOKEN_DOMAIN(delta_vocab_size)}, + {midi::TOKEN_DELTA_DIRECTION, TOKEN_DOMAIN(1)}, + + add_attribute_control_to_representation(midi::TOKEN_MIN_NOTE_DURATION), + add_attribute_control_to_representation(midi::TOKEN_MAX_NOTE_DURATION), + add_attribute_control_to_representation(midi::TOKEN_MIN_POLYPHONY), + add_attribute_control_to_representation(midi::TOKEN_MAX_POLYPHONY), + add_attribute_control_to_representation(midi::TOKEN_DENSITY_LEVEL), + + {midi::TOKEN_VELOCITY_LEVEL, TOKEN_DOMAIN(128)} + })); + } +}; + +} +// END OF NAMESPACE \ No newline at end of file diff --git a/src/common/encoder/encoder_base.h b/src/common/encoder/encoder_base.h new file mode 100644 index 0000000000000000000000000000000000000000..d7a09af3870594bffbbefd242467ef14cf3bca99 --- /dev/null +++ b/src/common/encoder/encoder_base.h @@ -0,0 +1,436 @@ +#pragma once + +#include + +#include "representation.h" +#include "util.h" + +#include "../data_structures/encoder_config.h" +#include "../data_structures/train_config.h" +#include "../data_structures/token_sequence.h" +#include "../midi_parsing/midi_io.h" + +// START OF NAMESPACE +namespace encoder { + +template +using matrix = std::vector>; + +std::vector resolve_bar_infill_tokens(std::vector &raw_tokens, const std::shared_ptr &rep) { + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "resolving bar infill" ); + int fill_pholder = rep->encode(midi::TOKEN_FILL_IN_PLACEHOLDER, 0); + int fill_start = rep->encode(midi::TOKEN_FILL_IN_START, 0); + int fill_end = rep->encode(midi::TOKEN_FILL_IN_END, 0); + + std::vector tokens; + + auto start_pholder = raw_tokens.begin(); + auto start_fill = raw_tokens.begin(); + auto end_fill = raw_tokens.begin(); + + while (start_pholder != raw_tokens.end()) { + start_pholder = next(start_pholder); // FIRST TOKEN IS PIECE_START ANYWAYS + auto last_start_pholder = start_pholder; + start_pholder = find(start_pholder, raw_tokens.end(), fill_pholder); + if (start_pholder != raw_tokens.end()) { + start_fill = find(next(start_fill), raw_tokens.end(), fill_start); + end_fill = find(next(end_fill), raw_tokens.end(), fill_end); + + // insert from last_start_pholder --> start_pholder + tokens.insert(tokens.end(), last_start_pholder, start_pholder); + tokens.insert(tokens.end(), next(start_fill), end_fill); + } + else { + // insert from last_start_pholder --> end of sequence (excluding fill) + start_fill = find(raw_tokens.begin(), raw_tokens.end(), fill_start); + tokens.insert(tokens.end(), last_start_pholder, start_fill); + } + } + return tokens; +} + +class ENCODER { +public: + + virtual ~ENCODER() {} + + // helper for simplicity + // also used to keep track of attribute controls used .... + + std::vector get_attribute_control_types() { + std::vector types; + auto enum_descriptor = google::protobuf::GetEnumDescriptor(); + for (auto c : attribute_control_types) { + types.push_back(enum_descriptor->FindValueByNumber(c)->name()); + } + return types; + } + + virtual void preprocess_piece(midi::Piece *p) { + // default is to do nothing + } + + std::vector encode(midi::Piece *p) { + preprocess_piece(p); + data_structures::TokenSequence ts = encode_piece(p); + return ts.tokens; + } + + std::vector encode_wo_preprocess(midi::Piece *p) { + data_structures::TokenSequence ts = encode_piece(p); + return ts.tokens; + } + + virtual void decode(std::vector &tokens, midi::Piece *p) { + if (config->do_multi_fill == true) { + tokens = resolve_bar_infill_tokens(tokens, rep); + } + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "AFTER BAR INFILL RESOLVED :: "); + for (int tok : tokens) { + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, pretty(tok)); + } + decode_track(tokens, p, rep, config); + } + + std::string midi_to_json(const std::string &filepath) { + midi::Piece p; + midi_io::ParseSong(filepath, &p, config); + preprocess_piece(&p); // add features that the encoder may need + std::string json_string; + google::protobuf::util::MessageToJsonString(p, &json_string); + return json_string; + } + + void midi_to_piece(const std::string& filepath, midi::Piece* p) { + midi_io::ParseSong(filepath, p, config); + preprocess_piece(p); + } + + std::vector midi_to_tokens(std::string &filepath) { + midi::Piece p; + midi_io::ParseSong(filepath, &p, config); + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, data_structures::to_str("Parsed File :: ",util_protobuf::protobuf_to_string(&p))); + return encode(&p); + } + + void json_to_midi(std::string &json_string, std::string &filepath) { + midi::Piece p; + google::protobuf::util::JsonStringToMessage(json_string.c_str(), &p); + midi_io::write_midi(&p, filepath, -1); + } + + std::string json_to_json(std::string &json_string_in) { + midi::Piece p; + google::protobuf::util::JsonStringToMessage(json_string_in.c_str(), &p); + std::string json_string; + google::protobuf::util::MessageToJsonString(p, &json_string); + return json_string; + } + + void json_track_to_midi(std::string &json_string, std::string &filepath, int single_track) { + midi::Piece p; + google::protobuf::util::JsonStringToMessage(json_string.c_str(), &p); + midi_io::write_midi(&p, filepath, single_track); + } + + std::vector json_to_tokens(std::string &json_string) { + midi::Piece p; + google::protobuf::util::JsonStringToMessage(json_string.c_str(), &p); + return encode(&p); + } + + std::string tokens_to_json(std::vector &tokens) { + midi::Piece p; + decode(tokens, &p); + std::string json_string; + google::protobuf::util::MessageToJsonString(p, &json_string); + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, data_structures::to_str("Decoded File :: ",json_string)); + return json_string; + } + + void resample_delta(midi::Piece *p) { + // This function rewrites the piece events time values to take in account their delta values + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_VERBOSE, "Resampling Piece with Delta values"); + + //We have to deal with overlapping notes by applying next notes onset delta to previous notes offset + std::map delta_to_apply; + int track_num = 0; + for (const auto &track : p->tracks()) { + int bar_num = 0; + for (const auto &bar : track.bars()) { + std::map>> pitch_to_events; + for (int i=0; ievents(event_idx); + pitch_to_events[event.pitch()].push_back(std::make_tuple(event_idx, event.time(), event.velocity(), event.delta())); + } + for (auto line : pitch_to_events) { + std::sort(line.second.begin(), line.second.end(), [](std::tuple a, std::tuple b) { + if (std::get<1>(a) < std::get<1>(b)) return true; + if (std::get<1>(b) < std::get<1>(a)) return false; + return (std::get<2>(a) < std::get<2>(b)); + }); + std::tuple last_event; + int last_offset_idx = -1; + for (auto const& e : line.second) { + // if onset, check last offset + if ((std::get<2>(e) > 0) && (last_offset_idx != -1)) { + if ((std::get<3>(e) != 0) && (std::get<1>(e) == p->events(last_offset_idx).time())) { + delta_to_apply[last_offset_idx] = std::get<3>(e); + } + } else if (std::get<2>(e) == 0) { + last_offset_idx = std::get<0>(e); + } + } + } + bar_num++; + } + track_num++; + } + + int current_res = config->resolution; + int target_res = config->decode_resolution; + p->set_resolution(target_res); + p->set_internal_ticks_per_quarter(target_res); + int old_time, new_time, delta; + std::vector> events_cache; + // Get all events and store in cache vector + + int num_events = p->events_size(); + for (int event_index=0; event_indexevents(event_index); + old_time = e.time(); + delta = e.delta(); + if (delta_to_apply.count(event_index) > 0) { + assert(delta_to_apply.count(event_index) == 1); + delta = delta_to_apply[event_index]; + } + // We round down to be safe + new_time = (int)(target_res * old_time / current_res); + //exclude negative times + new_time = std::max(new_time + delta, 0); + // Set new resampled time + e.set_time(new_time); + events_cache.push_back(std::make_tuple(event_index, e)); + } + // Sort events to replace in the correct order + sort(events_cache.begin(), events_cache.end(), [](std::tuple a, std::tuple b) { + return std::get<0>(a) < std::get<0>(b); + }); + // Clear all events now that they're cached + p->clear_events(); + // Reinject resampled events + for (const std::tuple &oe : events_cache) { + midi::Event *ne = p->add_events(); + ne->CopyFrom( std::get<1>(oe) ); + } + assert(num_events == p->events_size()); + } + + std::string resample_delta_json(std::string &json_string) { + std::string res_json_string; + midi::Piece p; + google::protobuf::util::JsonStringToMessage(json_string.c_str(), &p); + if (config->use_microtiming) { + resample_delta(&p); + } + google::protobuf::util::MessageToJsonString(p, &res_json_string); + return res_json_string; + } + + void tokens_to_json_array(std::vector> &seqs, std::vector &output) { + for (int i=0; i<(int)seqs.size(); i++) { + decode(seqs[i], &(output[i])); + } + } + + void tokens_to_midi(std::vector &tokens, std::string &filepath) { + midi::Piece p; + decode(tokens, &p); + midi_io::write_midi(&p, filepath, -1); + } + + // ==================== + // expose methods of rep that we need + + std::string pretty(int token) { + return rep->pretty(token); + } + + int vocab_size() { + return rep->vocab_size; + } + + // ==================== + + // below is a simplified refactor of the encoding process + // broken into clear functions to + // - encode notes within a bar + // - encode a bar + // - encode a track + // - encode a piece + + // ==================== + + void encode_notes(int bar_num, int track_num, midi::Piece *p, data_structures::TokenSequence *ts) { + const auto track = p->tracks(track_num); + const auto bar = track.bars(bar_num); + const auto is_drum = data_structures::is_drum_track(track.track_type()); + const int N_DURATION_TOKENS = rep->get_domain_size(midi::TOKEN_NOTE_DURATION); + int N_TIME_TOKENS = rep->get_domain_size(midi::TOKEN_DELTA); + + // group notes by onset time + std::vector onsets; + std::vector onsets_idx; + std::map> notes_by_onset; + std::map delta_onsets; + int idx = 0; + for (const auto &i : bar.events()) { + midi::Event event = p->events(i); + if ((event.internal_duration() > 0) && (event.velocity() > 0)) { + if (notes_by_onset.find(event.time()) == notes_by_onset.end()) { + onsets.push_back(event.time()); + onsets_idx.push_back(idx); + idx += 1; + } + notes_by_onset[event.time()].push_back(i); + delta_onsets[i] = event.delta(); + } + } + + int last_velocity = -1; + int onset; + int d_onset; + for (const auto &idx : onsets_idx) { + onset = onsets[idx]; + // checking for onset > 0 is to make things backwards compatible with the old representation + // however for randomly ordering onset times we need to include onset == 0 + if ((onset > 0)) { + ts->push_back( rep->encode(midi::TOKEN_TIME_ABSOLUTE_POS, onset) ); + } + + for (const auto &i : notes_by_onset[onset]) { + midi::Event event = p->events(i); + d_onset = delta_onsets[i]; + if (rep->has_token_type(midi::TOKEN_VELOCITY_LEVEL)) { + int current_velocity = rep->encode_partial(midi::TOKEN_VELOCITY_LEVEL, event.velocity()); + if ((current_velocity > 0) && (current_velocity != last_velocity)) { + ts->push_back( rep->encode(midi::TOKEN_VELOCITY_LEVEL, event.velocity()) ); + last_velocity = current_velocity; + } + } + if (config->use_microtiming) { + if (d_onset < 0) { + ts->push_back( rep->encode(midi::TOKEN_DELTA_DIRECTION, 0) ); + d_onset *= -1; + } + d_onset = std::min(N_TIME_TOKENS - 1, d_onset); + if (d_onset > 0) { + ts->push_back( rep->encode(midi::TOKEN_DELTA, d_onset) ); + } + } + ts->push_back( rep->encode(midi::TOKEN_NOTE_ONSET, event.pitch()) ); + if (!is_drum) { + ts->push_back( rep->encode(midi::TOKEN_NOTE_DURATION, std::min(event.internal_duration(), N_DURATION_TOKENS)-1) ); + } + } + } + } + + void encode_bar(int bar_num, int track_num, midi::Piece *p, data_structures::TokenSequence *ts, bool infill) { + auto track = p->tracks(track_num); + const auto bar = track.bars(bar_num); + const auto is_drum = data_structures::is_drum_track(track.track_type()); + + ts->on_bar_start(p, rep); + + if (infill) { + ts->push_back( rep->encode(midi::TOKEN_FILL_IN_START, 0) ); + encode_notes(bar_num, track_num, p, ts); + ts->push_back( rep->encode(midi::TOKEN_FILL_IN_END, 0) ); + } + else { + ts->push_back( rep->encode(midi::TOKEN_BAR, 0) ); + + midi::BarFeatures *bf = util_protobuf::GetBarFeatures(&track, bar_num); + append_bar_tokens(ts, rep, bf, is_drum); + + if (rep->has_token_type(midi::TOKEN_TIME_SIGNATURE)) { + ts->push_back( rep->encode(midi::TOKEN_TIME_SIGNATURE, std::make_tuple(bar.ts_numerator(), bar.ts_denominator())) ); + } + + if ((config->do_multi_fill) && (config->multi_fill.find(std::make_pair(track_num,bar_num)) != config->multi_fill.end())) { + ts->push_back( rep->encode(midi::TOKEN_FILL_IN_PLACEHOLDER, 0) ); + } + else { + encode_notes(bar_num, track_num, p, ts); + } + ts->push_back( rep->encode(midi::TOKEN_BAR_END, 0) ); + } + } + + void encode_track(int track_num, midi::Piece *p, data_structures::TokenSequence *ts) { + const auto track = p->tracks(track_num); + const auto is_drum = data_structures::is_drum_track(track.track_type()); + const auto f = util_protobuf::GetTrackFeatures(p, track_num); + + ts->on_track_start(p, rep); + + ts->push_back( rep->encode(midi::TOKEN_TRACK, track.track_type()) ); + + append_track_pre_instrument_tokens(ts, rep, f, is_drum); + + if (rep->has_token_type(midi::TOKEN_INSTRUMENT)) { + int inst = track.instrument(); + ts->push_back( rep->encode(midi::TOKEN_INSTRUMENT, inst) ); + } + + append_track_tokens(ts, rep, f, is_drum); + + for (int i=0; ipush_back( rep->encode(midi::TOKEN_TRACK_END, 0) ); + } + + data_structures::TokenSequence encode_piece(midi::Piece *p) { + + // make sure that rep does not try use deprecated note encodings + if ((!rep->has_token_type(midi::TOKEN_NOTE_DURATION)) || (!rep->has_token_type(midi::TOKEN_TIME_ABSOLUTE_POS))) { + throw std::runtime_error("ERROR: ENCODING PIECE WITH DEPRECATED NOTE ENCODINGS"); + } + + data_structures::TokenSequence ts(rep); + + ts.push_back( rep->encode( + midi::TOKEN_PIECE_START, std::min((int)config->do_multi_fill,rep->get_domain_size(midi::TOKEN_PIECE_START)-1))); + + if (rep->has_token_type(midi::TOKEN_NUM_BARS)) { + ts.push_back( rep->encode(midi::TOKEN_NUM_BARS, util_protobuf::GetNumBars(p)) ); + } + + for (int i=0; itracks_size(); i++) { + encode_track(i, p, &ts); + } + + if (config->do_multi_fill) { + for (const auto &track_bar : config->multi_fill) { + encode_bar(std::get<1>(track_bar), std::get<0>(track_bar), p, &ts, true); + } + } + + return ts; + } + + std::shared_ptr get_rep() { + return rep; + } + + std::shared_ptr config; + std::shared_ptr rep; + std::vector attribute_control_types; +}; + +} +// END OF NAMESPACE diff --git a/src/common/encoder/representation.h b/src/common/encoder/representation.h new file mode 100644 index 0000000000000000000000000000000000000000..3c517c0890684bed4843ab15e200f5373dbf4a69 --- /dev/null +++ b/src/common/encoder/representation.h @@ -0,0 +1,336 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "token_domain.h" +#include "../data_structures/verbosity.h" +#include "../midi_parsing/util_protobuf.h" + +// START OF NAMESPACE +namespace encoder { + +class REPRESENTATION { + +/* +This class describes the token representation, ie. the vocabulary. It tracks the index and domain size (how many values per token type) +and is used for one hot encoding and for sampling. +Therefore, it encodes a token from its base state (TOKEN_TYPE, value) to its vectorial one-hot encoded state. It also allows the reverse process. +*/ + +public: + REPRESENTATION(std::vector> spec) { + + /* + params: spec - vector tht holds each token type and the allocated domain size in the vocabulary + */ + + // intialize vocabulary size + vocab_size = 0; + for (const auto &token_domain : spec) { + // loop through each token type + midi::TOKEN_TYPE tt = std::get<0>(token_domain); // token type + TOKEN_DOMAIN domain = std::get<1>(token_domain); // token domain size + + int index = 0; + for (const auto &value : domain.map_items) { + // loop through each allocated token in the token domain + int token = vocab_size + std::get<1>(value); + TOKEN_TUPLE toktup = std::make_tuple(tt,std::get<0>(value)); + + if (domain.repeat_tt.size() == 1) { + token = forward[std::make_tuple(domain.repeat_tt[0],std::get<0>(value))]; + } + + forward[toktup] = token; + if (domain.input_types[index] != TI_INT) { + forward[std::make_tuple(tt,std::get<1>(value))] = token; + } + + if (domain.repeat_tt.size() == 0) { + backward[token] = toktup; + backward_types[token] = domain.input_types[index]; + } + index++; + } + vocab_size += domain.output_domain.size(); // add the current domain size to the vocabulary size + domains.insert( std::make_pair(tt,domain.output_domain.size()) ); + token_domains.insert( std::make_pair(tt,domain) ); + } + } + int encode(midi::TOKEN_TYPE tt, TOKEN_VARIANT value) { + std::tuple key = std::make_tuple(tt,value); + auto it = forward.find(key); + if (it == forward.end()) { + std::ostringstream buffer; + auto tdit = token_domains.find(tt); + if (tdit == token_domains.end()) { + buffer << "ENCODER ERROR : TOKEN TYPE " << util_protobuf::enum_to_string(tt) << " IS NOT IN REPRESENTATION"; + } + else { + TOKEN_INPUT_TYPE ti = tdit->second.input_types[0]; + buffer << "ENCODER ERROR : VALUE (" << token_variant_to_string(ti, value) << ") NOT IN DOMAIN FOR TOKEN TYPE " << util_protobuf::enum_to_string(tt); + } + throw std::runtime_error(buffer.str()); + } + return it->second; + } + int encode_partial(midi::TOKEN_TYPE tt, TOKEN_VARIANT value) { + auto it = token_domains.find(tt); + if (it == token_domains.end()) { + throw std::runtime_error("midi::TOKEN_TYPE NOT PART OF THIS REPRESENTATION"); + } + return it->second.encode(value); + } + int encode_partial_py_int(midi::TOKEN_TYPE tt, int value) { + auto it = token_domains.find(tt); + if (it == token_domains.end()) { + throw std::runtime_error("midi::TOKEN_TYPE NOT PART OF THIS REPRESENTATION"); + } + return it->second.encode(value); + } + void token_in_range(int token) { + if (token >= vocab_size) { + throw std::runtime_error("TOKEN IS LARGER THAN VOCAB SIZE!"); + } + if (token < 0) { + throw std::runtime_error("TOKEN IS NEGATIVE!"); + } + } + int decode(int token) { + token_in_range(token); + if (backward_types[token] != TI_INT) { + throw std::runtime_error("TOKEN CAN NOT BE DECODED AS INT"); + } + return std::get(std::get<1>(backward[token])); + } + std::string decode_string(int token) { + token_in_range(token); + if (backward_types[token] != TI_STRING) { + throw std::runtime_error("TOKEN CAN NOT BE DECODED AS STRING"); + } + return std::get(std::get<1>(backward[token])); + } + std::tuple decode_timesig(int token) { + token_in_range(token); + if (backward_types[token] != TI_TIMESIG) { + throw std::runtime_error("TOKEN CAN NOT BE DECODED AS TIMESIG"); + } + return std::get>(std::get<1>(backward[token])); + } + int max_token() { + return vocab_size; + } + int get_domain_size(midi::TOKEN_TYPE tt) { + auto it = domains.find(tt); + if (it == domains.end()) { + return 0; + } + return it->second; + } + bool in_domain(midi::TOKEN_TYPE tt, int value) { + auto it = token_domains.find(tt); + if (it != token_domains.end()) { + return it->second.output_domain.find(value) != it->second.output_domain.end(); + } + return false; + } + + std::vector get_num_bars_domain() { + std::vector model_dims; + auto itt = token_domains.find(midi::TOKEN_NUM_BARS); + if (itt != token_domains.end()) { + for (const auto &value : itt->second.input_domain) { + model_dims.push_back( std::get(value) ); + } + } + return model_dims; + } + std::vector> get_time_signature_domain() { + std::vector> timesigs; + auto itt = token_domains.find(midi::TOKEN_TIME_SIGNATURE); + if (itt != token_domains.end()) { + for (const auto &ts : itt->second.input_domain) { + timesigs.push_back( std::get>(ts) ); + } + } + else { + // the standard models without time signatures only trained on 4/4 + timesigs.push_back( std::make_tuple(4,4) ); + } + return timesigs; + } + + void check_token(int token) { + auto it = backward.find(token); + if (it == backward.end()) { + std::ostringstream buffer; + buffer << "ENCODER ERROR : TOKEN " << token << "IS NOT IN REPRESENTATION"; + throw std::runtime_error(buffer.str()); + } + } + bool is_token_type(int token, midi::TOKEN_TYPE tt) { + check_token(token); + return std::get<0>(backward[token]) == tt; + } + midi::TOKEN_TYPE get_token_type(int token) { + check_token(token); + return std::get<0>(backward[token]); + } + bool has_token_type(midi::TOKEN_TYPE tt) { + return token_domains.find(tt) != token_domains.end(); + } + bool has_token_types(std::vector tts) { + for (const auto &tt : tts) { + if (!has_token_type(tt)) { + return false; + } + } + return true; + } + + template + std::vector get_mask(T value) { + return std::vector(vocab_size, value); + } + + template + std::set get_mask_token_types(std::vector &mask) { + std::set tts; + for (int i=0; i<(int)mask.size(); i++) { + if (mask[i] > 0) { + tts.insert( get_token_type(i) ); + } + } + return tts; + } + + template + void show_mask_token_types(std::vector &mask) { + std::set tts = get_mask_token_types(mask); + data_structures::LOGGER("MASK TOKEN TYPES :: "); + for (const auto &tt : tts) { + data_structures::LOGGER(data_structures::to_str(util_protobuf::enum_to_string(tt), ", "), false); + } + data_structures::LOGGER(""); + } + + template + void set_mask(midi::TOKEN_TYPE tt, std::vector values, std::vector &mask, T mask_value) { + auto it = token_domains.find(tt); + if (it != token_domains.end()) { + for (const auto &value : values) { + if (value == -1) { + for (const auto &v : it->second.input_domain) { + mask[encode(tt, v)] = mask_value; + } + } + else { + mask[encode(tt, value)] = mask_value; + } + } + } + } + + template + void set_mask(midi::TOKEN_TYPE tt, std::vector values, std::vector &mask, T mask_value, STRING_VECTOR_FLAG x) { + auto it = token_domains.find(tt); + if (it != token_domains.end()) { + for (const auto &value : values) { + mask[encode(tt, value)] = mask_value; + } + } + } + + std::vector encode_to_one_hot(midi::TOKEN_TYPE tt, std::vector values) { + std::vector x(vocab_size,0); + set_mask(tt, values, x, 1); + return x; + } + + std::vector get_type_mask(std::vector tts) { + std::vector mask(vocab_size,0); + for (int i=0; i(v)); + } + else if (ti == TI_STRING) { + value_str = std::get(v); + } + else if (ti == TI_TIMESIG) { + auto ts = std::get>(v); + value_str = std::to_string(std::get<0>(ts)) + "/" + std::to_string(std::get<1>(ts)); + } + else { + throw std::runtime_error("THE TOKEN HAS NO INVALID TOKEN_INPUT_TYPE"); + } + return value_str; + } + + std::string pretty(int token) { + auto token_value = backward[token]; + TOKEN_INPUT_TYPE ti = backward_types[token]; + return util_protobuf::enum_to_string(std::get<0>(token_value)) + std::string(" = ") + token_variant_to_string(ti, std::get<1>(token_value)); + } + + std::string pretty_type(int token) { + auto token_value = backward[token]; + return util_protobuf::enum_to_string(std::get<0>(token_value)); + } + + void show(std::vector &tokens) { + for (const auto &token : tokens) { + data_structures::LOGGER(pretty(token)); + } + } + + void show_token_types() { + for (const auto &token : domains) { + data_structures::LOGGER(data_structures::to_str("REP TOKENS :: ", util_protobuf::enum_to_string(token.first))); + } + } + + void show_mapping() { + for (int i=0; isecond.output_domain.size() < 128); + } + return false; + } + + int vocab_size; + std::map forward; + std::map backward; + std::map backward_types; + + std::map domains; // maps each token type to its domain output size + std::map token_domains; // maps each token type to its token domain +}; + +} +// END OF NAMESPACE diff --git a/src/common/encoder/token_domain.h b/src/common/encoder/token_domain.h new file mode 100644 index 0000000000000000000000000000000000000000..d601ac5651d8e1481bd3f3046fa2a94881e44fe8 --- /dev/null +++ b/src/common/encoder/token_domain.h @@ -0,0 +1,148 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include "midi.pb.h" + +namespace encoder { + + using TOKEN_VARIANT = std::variant>; // The different value types that a token value can have + using TOKEN_TUPLE = std::tuple; // A complete token in represented by its token type and token value + + struct INT_RANGE_FLAG {}; + struct INT_VALUES_FLAG {}; + struct INT_MAP_FLAG {}; + struct STRING_VALUES_FLAG {}; + struct STRING_MAP_FLAG {}; + struct TIMESIG_VALUES_FLAG {}; + struct TIMESIG_MAP_FLAG {}; + struct CONTINUOUS_FLAG {}; + struct STRING_VECTOR_FLAG {}; + + INT_RANGE_FLAG RANGE_DOMAIN; + INT_VALUES_FLAG INT_VALUES_DOMAIN; + INT_MAP_FLAG INT_MAP_DOMAIN; + STRING_VALUES_FLAG STRING_VALUES_DOMAIN; + STRING_MAP_FLAG STRING_MAP_DOMAIN; + TIMESIG_VALUES_FLAG TIMESIG_VALUES_DOMAIN; + TIMESIG_MAP_FLAG TIMESIG_MAP_DOMAIN; + CONTINUOUS_FLAG CONTINUOUS_DOMAIN; + STRING_VECTOR_FLAG STRING_VECTOR; + + // different representation for an input token + enum TOKEN_INPUT_TYPE { + TI_INT, + TI_STRING, + TI_TIMESIG + }; + + class TOKEN_DOMAIN { + + /* + Represents the domain of a token type. It encodes a token by mapping it input type (TOKEN_INPUT_TYPE) to a unique integer. + Either it automatically increments this unique integer when a new token value is added, or you can provide a custom mapping (in this case contiguous_ouput is used to ensure a contiguous output domain, ie. [1,m] with some m) + This integer is unique within the given domain. The representation.h class then creates unique ids from multiple token domains + */ + + public: + + TOKEN_DOMAIN(size_t n) { // for token types that take values in [0,n] + if (n > 512) { + throw std::invalid_argument("TOKEN DOMAIN SIZE IS TOO LARGE!"); + } + for (int value=0; value<(int)n; value++) { + add(value); + input_types.push_back( TI_INT ); + } + } + TOKEN_DOMAIN(int min, int max, INT_RANGE_FLAG x) { // for token types that take values in [min,max[ + for (int value=min; value values, INT_VALUES_FLAG x) { // for a custom set of int values + for (const auto &value : values) { + add(value); + input_types.push_back( TI_INT ); + } + } + TOKEN_DOMAIN(std::map values, INT_MAP_FLAG x) { // for a custom set of input int to output mappings + for (const auto &kv : values) { + add(kv.first, kv.second); + input_types.push_back( TI_INT ); + } + } + TOKEN_DOMAIN(std::vector values, STRING_VALUES_FLAG x) { // for a custom set of string values + for (const auto &value : values) { + add(value); + input_types.push_back( TI_STRING ); + } + } + TOKEN_DOMAIN(std::map values, STRING_MAP_FLAG x) { // for a custom set of input string to output mappings + for (const auto &kv : values) { + add(kv.first, kv.second); + input_types.push_back( TI_STRING ); + } + } + TOKEN_DOMAIN(std::vector> values, TIMESIG_VALUES_FLAG x) { // for a custom set of int pairs (time signatures) + for (const auto &value : values) { + add(value); + input_types.push_back( TI_TIMESIG ); + } + } + TOKEN_DOMAIN(std::map,int> values, TIMESIG_MAP_FLAG x) { // for a custom set of input int pairs (time signatures) to output mappings + for (const auto &kv : values) { + add(kv.first, kv.second); + input_types.push_back( TI_TIMESIG ); + } + } + TOKEN_DOMAIN(int n, midi::TOKEN_TYPE rtt) { + for (int value=0; value values, CONTINUOUS_FLAG x) { + throw std::runtime_error("NOT CURRENTLY IMPLEMENTED!"); + } + void add_internal(TOKEN_VARIANT x, int y) { + // map token input x to token output y and respectively add them to their domains + map_items.push_back( std::make_tuple(x,y) ); + mapping.insert( std::make_pair(x,y) ); + output_domain.insert( y ); + input_domain.insert( x ); + } + void add(TOKEN_VARIANT x) { // add integer value to domain + add_internal(x, (int)input_domain.size()); + } + void add(TOKEN_VARIANT x, int y) { + // ensure contiguous domain + if (contiguous_output.find(y) == contiguous_output.end()) { + int current_size = contiguous_output.size(); + contiguous_output.insert( std::make_pair(y,current_size) ); + } + add_internal(x, contiguous_output[y]); + } + int encode(TOKEN_VARIANT x) { + auto it = mapping.find(x); + if (it == mapping.end()) { + throw std::runtime_error("TOKEN VALUE IS OUT OF RANGE!"); + } + return it->second; + } + int token_count; // number of individual tokens in the domain + std::vector repeat_tt; // repeat token types, idk what a repeat token type is + std::vector> map_items; + std::map mapping; // same thing as map_items + std::set input_domain; + std::set output_domain; + std::map contiguous_output; // ensure contiguous output domain + std::vector input_types; // keep track of data types + }; +} diff --git a/src/common/encoder/util.h b/src/common/encoder/util.h new file mode 100644 index 0000000000000000000000000000000000000000..7c75917f868df30403a4bbbd108a76511d9fb716 --- /dev/null +++ b/src/common/encoder/util.h @@ -0,0 +1,216 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +#include "midi.pb.h" +#include "../midi_parsing/util_protobuf.h" +#include "attribute_control.h" +#include "../../inference/enum/constants.h" +#include "../../common/data_structures/encoder_config.h" +#include "../../common/data_structures/token_sequence.h" +#include "../../inference/enum/pretrain_group.h" + +// START OF NAMESPACE +namespace encoder { + + +// below is a simplified refactor of the encoding process +// broken into clear functions to +// - encode notes within a bar +// - encode a bar +// - encode a track +// - encode a piece + +void decode_track(std::vector &tokens, midi::Piece *p, const std::shared_ptr &rep, const std::shared_ptr &ec) { + p->set_resolution(ec->resolution); + + std::map inst_to_track; + midi::Event *e = NULL; + midi::Track *t = NULL; + midi::Bar *b = NULL; + int current_time, current_note_time, current_instrument, delta_direction, delta_total; + int beat_length = 0; + int track_count = 0; + int bar_count = 0; + int last_token = -1; + int last_abs_token = -1; + int current_velocity = 100; + + std::set offset_remain; + + for (const auto &token : tokens) { + if (rep->is_token_type(token, midi::TOKEN_TRACK)) { + current_time = 0; // restart the time + current_note_time = 0; + current_instrument = 0; // reset instrument + delta_direction = 1; + delta_total = 0; + offset_remain.clear(); + if (track_count >= p->tracks_size()) { + t = p->add_tracks(); + } + else { + t = p->mutable_tracks(track_count); + } + t->set_track_type( (midi::TRACK_TYPE)rep->decode(token) ); + util_protobuf::GetTrackFeatures(p, track_count); + } + else if (rep->is_token_type(token, midi::TOKEN_TRACK_END)) { + track_count++; + t = NULL; + } + else if (rep->is_token_type(token, midi::TOKEN_BAR)) { + // when we start new bar we need to decrement time of remaining offsets + for (const auto &index : offset_remain) { + midi::Event *e = p->mutable_events(index); + e->set_time( (int)(((e->time() - beat_length * ec->resolution)*(p->resolution())/ec->resolution))); + } + current_time = 0; // restart the time + current_note_time = 0; + delta_direction = 1; + delta_total = 0; + beat_length = 4; // default value optionally overidden with TIME_SIGNATURE + if (t) { + b = t->add_bars(); + } + bar_count++; + } + else if (rep->is_token_type(token, midi::TOKEN_TIME_SIGNATURE)) { + std::tuple ts = rep->decode_timesig(token); + beat_length = 4 * std::get<0>(ts) / std::get<1>(ts); + b->set_ts_numerator( std::get<0>(ts) ); + b->set_ts_denominator( std::get<1>(ts) ); + } + else if (rep->is_token_type(token, midi::TOKEN_BAR_END)) { + if (b) { + b->set_internal_beat_length(beat_length); + } + current_time = beat_length * p->resolution(); + current_note_time = current_time; + } + else if (rep->is_token_type(token, midi::TOKEN_TIME_ABSOLUTE_POS)) { + current_time = rep->decode(token); // simply update instead of increment + current_note_time = current_time; + delta_direction = 1; + delta_total = 0; + } + else if (rep->is_token_type(token, midi::TOKEN_DELTA_DIRECTION)) { + delta_direction = -1; + delta_total = 0; + } + else if (rep->is_token_type(token, midi::TOKEN_DELTA)) { + last_abs_token = last_token; + int delta_val = rep->decode(token); + delta_total += delta_direction * delta_val; + + } + else if (rep->is_token_type(token, midi::TOKEN_INSTRUMENT)) { + if (t) { + current_instrument = rep->decode(token); + t->set_instrument( current_instrument ); + } + } + else if (rep->is_token_type(token, midi::TOKEN_VELOCITY_LEVEL)) { + current_velocity = rep->decode(token); + } + else if (rep->is_token_type(token, midi::TOKEN_NOTE_ONSET)) { + if (b && t) { + + if (data_structures::is_drum_track(t->track_type())) { + + int current_note_index = p->events_size(); + current_note_time = current_time; + e = p->add_events(); + e->set_pitch( rep->decode(token) ); + e->set_velocity( current_velocity ); + e->set_time( current_note_time ); + + e->set_delta( delta_total ); + delta_total = 0; + delta_direction = 1; + b->add_events( current_note_index ); + b->set_internal_has_notes( true ); + + current_note_index = p->events_size(); + e = p->add_events(); + e->set_pitch( rep->decode(token) ); + e->set_velocity( 0 ); + e->set_time( current_note_time + 1 ); + b->add_events( current_note_index ); + b->set_internal_has_notes( true ); + + } + } + } + else if (rep->is_token_type(token, midi::TOKEN_NOTE_DURATION)) { + if (b && t && (last_token >= 0) && (rep->is_token_type(last_token, midi::TOKEN_NOTE_ONSET))) { + + // add onset + int current_note_index = p->events_size(); + current_note_time = current_time; + e = p->add_events(); + e->set_pitch( rep->decode(last_token) ); + e->set_velocity( current_velocity ); + e->set_time( current_note_time ); + e->set_delta( delta_total ); + delta_total = 0; + delta_direction = 1; + b->add_events( current_note_index ); + + // add offset + current_note_index = p->events_size(); + e = p->add_events(); + e->set_pitch( rep->decode(last_token) ); + e->set_velocity( 0 ); + e->set_time( current_note_time + rep->decode(token) + 1 ); + e->set_delta( 0 ); + + if (e->time() <= beat_length * p->resolution()) { + b->add_events( current_note_index ); + } + else { + // we need to add this to a later bar + offset_remain.insert( current_note_index ); + } + + b->set_internal_has_notes( true ); + } + } + else if (rep->is_token_type(token, midi::TOKEN_GENRE)) { + midi::TrackFeatures *f; + if (!t->internal_features_size()) { + f = t->add_internal_features(); + } + else { + f = t->mutable_internal_features(0); + } + f->set_genre_str( rep->decode_string(token) ); + } + + // insert offsets from note_duration tokens when possible + std::vector to_remove; + for (const auto &index : offset_remain) { + if ((int)p->events(index).time() <= current_time) { + b->add_events( index ); + to_remove.push_back( index ); + } + } + for (const auto &index : to_remove) { + offset_remain.erase(index); + } + + last_token = token; + } + p->add_internal_valid_segments(0); + p->add_internal_valid_tracks((1<tracks_size())-1); +} + +} +// END OF NAMESPACE diff --git a/src/common/midi_parsing/adjacent_range.h b/src/common/midi_parsing/adjacent_range.h new file mode 100644 index 0000000000000000000000000000000000000000..e0d973f25046334025211b6303177ab064bc398d --- /dev/null +++ b/src/common/midi_parsing/adjacent_range.h @@ -0,0 +1,55 @@ +#pragma once +#include +#include + +namespace midi_parsing { + + template class adjacent_iterator { + public: + adjacent_iterator(FwdIt first, FwdIt last) + : m_first(first), m_next(first == last ? first : std::next(first)) { } + + bool operator!=(const adjacent_iterator& other) const { + return m_next != other.m_next; + } + + adjacent_iterator& operator++() { + ++m_first; + ++m_next; + return *this; + } + + typedef typename std::iterator_traits::reference Ref; + typedef std::pair Pair; + + Pair operator*() const { + return Pair(*m_first, *m_next); + } + + private: + FwdIt m_first; + FwdIt m_next; + }; + + template class adjacent_range { + public: + adjacent_range(FwdIt first, FwdIt last) + : m_first(first), m_last(last) { } + + adjacent_iterator begin() const { + return adjacent_iterator(m_first, m_last); + } + + adjacent_iterator end() const { + return adjacent_iterator(m_last, m_last); + } + + private: + FwdIt m_first; + FwdIt m_last; + }; + + template auto make_adjacent_range(C& c) -> adjacent_range { + return adjacent_range(c.begin(), c.end()); + } +} \ No newline at end of file diff --git a/src/common/midi_parsing/feature_extraction.h b/src/common/midi_parsing/feature_extraction.h new file mode 100644 index 0000000000000000000000000000000000000000..60d38c6383d1ed118e0156fd64f063d05db27693 --- /dev/null +++ b/src/common/midi_parsing/feature_extraction.h @@ -0,0 +1,554 @@ +#pragma once + +#include +#include + +#include "../../common/midi_parsing/util_protobuf.h" +#include "../../../libraries/protobuf/build/midi.pb.h" + +namespace feature_extraction { + +class FEATURE_EXTRACTOR { +public: + + virtual ~FEATURE_EXTRACTOR() {} + + virtual void compute_track_feature(midi::Piece *x, int track_num, midi::Features *f, std::string &filepath) { + throw std::runtime_error("FEATURE_EXTRACTOR::compute_track_feature() not implemented"); + } + + virtual void compute_piece_feature(midi::Piece *x, midi::Features *f, std::string &filepath) { + throw std::runtime_error("FEATURE_EXTRACTOR::compute_piece_feature() not implemented"); + } + + void compute_feature(midi::Piece *x, midi::Features *f, std::string &filepath) { + if (track_level) { + for (int track_num=0; track_numtracks_size(); track_num++) { + compute_track_feature(x, track_num, f, filepath); + } + } + else { + compute_piece_feature(x, f, filepath); + } + } + + bool quantized; + bool track_level; + +}; + +class PitchRangeFeature : public FEATURE_EXTRACTOR { +public: + + PitchRangeFeature() { + quantized = false; + track_level = true; + } + + void compute_track_feature(midi::Piece *x, int track_num, midi::Features *f, std::string &filepath) { + const auto track = x->tracks(track_num); + + // ignore drum tracks + if (data_structures::is_drum_track(track.track_type())) { + return; + } + + int min_pitch = INT_MAX; + int max_pitch = INT_MIN; + for (const auto &bar : track.bars()) { + for (const auto &event_index : bar.events()) { + if (x->events(event_index).velocity()) { + if (x->events(event_index).pitch() < min_pitch) { + min_pitch = x->events(event_index).pitch(); + } + if (x->events(event_index).pitch() > max_pitch) { + max_pitch = x->events(event_index).pitch(); + } + } + } + } + + auto fm = f->add_pitch_range(); + fm->set_instrument(track.instrument()); + fm->set_min(min_pitch); + fm->set_max(max_pitch); + } +}; + +class DownbeatProportionFeature : public FEATURE_EXTRACTOR { +public: + + DownbeatProportionFeature() { + quantized = true; + track_level = true; + } + + void compute_track_feature(midi::Piece *x, int track_num, midi::Features *f, std::string &filepath) { + const auto track = x->tracks(track_num); + + if (!x->internal_has_time_signatures()) { + return; + } + + int downbeat_count = 0; + int onset_count = 0; + + std::map onset_counts; + for (const auto &bar : track.bars()) { + for (const auto &event_index : bar.events()) { + if (x->events(event_index).velocity()) { + + onset_counts[x->events(event_index).time()] += 1; + + if (x->events(event_index).time() == 0) { + downbeat_count++; + } + onset_count++; + + } + } + } + + int max_non_downbeat = 1; + for (const auto &kv : onset_counts) { + if (kv.first != 0) { + max_non_downbeat = std::max(max_non_downbeat, kv.second); + } + } + + auto fm = f->add_downbeat_proportion(); + fm->set_instrument(track.instrument()); + fm->set_is_drum(data_structures::is_drum_track(track.track_type())); + fm->set_filepath(filepath); + fm->set_track_num(track_num); + + fm->set_downbeat_proportion(float(onset_counts[0]) / float(max_non_downbeat)); + + } +}; + +std::map compute_metric_depth_counts(midi::Piece *x, int track_num, int max_depth, int offset) { + const auto track = x->tracks(track_num); + + int max_duple_depth = 0; + int max_triplet_depth = 0; + int total_depth = max_depth * 2; + + while ((x->internal_ticks_per_quarter() % int(pow(2,max_duple_depth))) == 0) { + max_duple_depth += 1; + } + while ((x->internal_ticks_per_quarter() * 2) % (int(pow(2,max_triplet_depth)) * 3) == 0) { + max_triplet_depth += 1; + } + + max_duple_depth = std::min(max_duple_depth, max_depth); + max_triplet_depth = std::min(max_triplet_depth, max_depth); + + std::map metric_depth_counts; + for (const auto &bar : track.bars()) { + for (const auto &event_index : bar.events()) { + if (x->events(event_index).velocity()) { + bool found_depth = false; + for (int i=0; iinternal_ticks_per_quarter()) / int(pow(2,i)); + if (((x->events(event_index).time()) + offset) % period == 0) { + metric_depth_counts[2*i] += 1; + found_depth = true; + break; + } + } + if (!found_depth) { + for (int i=0; iinternal_ticks_per_quarter() * 2) / (int(pow(2,i)) * 3); + if ((x->events(event_index).time() + offset) % period == 0) { + metric_depth_counts[2*i + 1] += 1; + found_depth = true; + break; + } + } + } + if (!found_depth) { + metric_depth_counts[total_depth] += 1; + } + } + } + } + return metric_depth_counts; +} + +class MetricDepthFeature : public FEATURE_EXTRACTOR { +public: + + MetricDepthFeature() { + quantized = false; + track_level = true; + } + + void compute_track_feature(midi::Piece *x, int track_num, midi::Features *f, std::string &filepath) { + + int max_depth = 6; + auto metric_depth_counts = compute_metric_depth_counts(x, track_num, max_depth, 0); + const auto track = x->tracks(track_num); + + auto fm = f->add_metric_depth(); + fm->set_filepath(filepath); + fm->set_track_num(track_num); + fm->set_instrument(track.instrument()); + fm->set_is_drum(data_structures::is_drum_track(track.track_type())); + fm->set_has_time_signatures(x->internal_has_time_signatures()); + fm->set_tpq(x->internal_ticks_per_quarter()); + for (int i=0; i<=max_depth*2; i++) { + if (metric_depth_counts.find(i) == metric_depth_counts.end()) { + fm->add_metric_depth(0); + } else { + fm->add_metric_depth(metric_depth_counts[i]); + } + } + } +}; + +class MostFrequentMetricDepthFeature : public FEATURE_EXTRACTOR { +public: + + MostFrequentMetricDepthFeature() { + quantized = false; + track_level = true; + } + + void compute_track_feature(midi::Piece *x, int track_num, midi::Features *f, std::string &filepath) { + + int max_depth = 6; + auto metric_depth_counts = compute_metric_depth_counts(x, track_num, max_depth, 0); + const auto track = x->tracks(track_num); + + + int max_count = 0; + int max_index = 0; + for (const auto &kv : metric_depth_counts) { + if (kv.second > max_count) { + max_count = kv.second; + max_index = kv.first; + } + } + + auto fm = f->add_most_frequent_metric_depth(); + fm->set_filepath(filepath); + fm->set_track_num(track_num); + fm->set_instrument(track.instrument()); + fm->set_is_drum(data_structures::is_drum_track(track.track_type())); + fm->set_has_time_signatures(x->internal_has_time_signatures()); + fm->set_tpq(x->internal_ticks_per_quarter()); + fm->set_most_frequent_metric_depth(max_index); + } + +}; + +class MedianMetricDepthFeature : public FEATURE_EXTRACTOR { +public: + + MedianMetricDepthFeature() { + quantized = false; + track_level = true; + } + + void compute_track_feature(midi::Piece *x, int track_num, midi::Features *f, std::string &filepath) { + + int max_depth = 6; + auto metric_depth_counts = compute_metric_depth_counts(x, track_num, max_depth, 0); + const auto track = x->tracks(track_num); + + + int total = 0; + for (const auto &kv : metric_depth_counts) { + total += kv.second; + } + + int median_count = total / 2; + int median_depth = 0; + int cumulative = 0; + for (const auto &kv : metric_depth_counts) { + if ((median_count >= cumulative) && (median_count < cumulative + kv.second)) { + median_depth = kv.first; + } + cumulative += kv.second; + } + + + auto fm = f->add_median_metric_depth(); + fm->set_filepath(filepath); + fm->set_track_num(track_num); + fm->set_instrument(track.instrument()); + fm->set_is_drum(data_structures::is_drum_track(track.track_type())); + fm->set_has_time_signatures(x->internal_has_time_signatures()); + fm->set_tpq(x->internal_ticks_per_quarter()); + fm->set_median_metric_depth(median_depth); + } +}; + +class AlignedMetricDepthFeature : public FEATURE_EXTRACTOR { +public: + + AlignedMetricDepthFeature() { + quantized = true; + track_level = true; + } + + void compute_track_feature(midi::Piece *x, int track_num, midi::Features *f, std::string &filepath) { + + x->set_internal_ticks_per_quarter(12); + + int max_depth = 6; + int tpq = x->internal_ticks_per_quarter(); + const auto track = x->tracks(track_num); + + if (tpq > 10000) { + throw std::runtime_error("AlignedMetricDepthFeature::compute_track_feature() must have tpq <= 10000."); + } + + int best_score_offset = 0; + int best_score = max_depth * 100; + for (int offset=0; offset= cumulative) && (total < cumulative + kv.second)) { + median_score = kv.first; + } + cumulative += kv.second; + } + + if (median_score < best_score) { + best_score = median_score; + best_score_offset = offset; + } + + } + + if (best_score_offset != 0) { + std::cout << "FOUND INVALID MIDI FILE :: " << filepath << std::endl; + } + + auto fm = f->add_aligned_metric_depth(); + fm->set_filepath(filepath); + fm->set_track_num(track_num); + fm->set_instrument(track.instrument()); + fm->set_is_drum(data_structures::is_drum_track(track.track_type())); + + fm->set_aligned_offset(best_score_offset); + } + +}; + +class SimultaneousOnsetFeature : public FEATURE_EXTRACTOR { +public: + + SimultaneousOnsetFeature() { + quantized = false; + track_level = true; + } + + void compute_track_feature(midi::Piece *x, int track_num, midi::Features *f, std::string &filepath) { + + const auto track = x->tracks(track_num); + + int simultaneous_onset_count = 0; + for (const auto &bar : track.bars()) { + std::map onsets; + for (const auto &event_index : bar.events()) { + if (x->events(event_index).velocity()) { + onsets[x->events(event_index).time()] += 1; + } + } + for (const auto &kv : onsets) { + simultaneous_onset_count += (int)(kv.second > 1); + } + } + + auto fm = f->add_simultaneous_onset(); + fm->set_filepath(filepath); + fm->set_track_num(track_num); + fm->set_instrument(track.instrument()); + fm->set_is_drum(data_structures::is_drum_track(track.track_type())); + + fm->set_simultaneous_onset_count(simultaneous_onset_count); + } + +}; + +template +double standardDeviation(std::vector &xs) { + double total = 0; + double stdev = 0; + for (const auto &x : xs) { + total += x; + } + double mean = total / xs.size(); + for (const auto &x : xs) { + stdev += pow(x - mean, 2); + } + return sqrt(stdev / xs.size()); +} + +template +T median(std::vector &xs) { + std::sort(xs.begin(), xs.end()); + return xs[xs.size() / 2]; +} + +// basic feature to measure drum presence +class DrumPresenceFeature : public FEATURE_EXTRACTOR { +public: + + DrumPresenceFeature() { + quantized = true; + track_level = false; + } + + void compute_piece_feature(midi::Piece *x, midi::Features *f, std::string &filepath) { + std::map drum_counts_per_bar; + std::map inst_counts_per_bar; + for (const auto &track : x->tracks()) { + int bar_num = 0; + bool is_drum = data_structures::is_drum_track(track.track_type()); + for (const auto &bar : track.bars()) { + for (const auto &event_index : bar.events()) { + auto event = x->events(event_index); + if (event.velocity()) { + if (is_drum) { + drum_counts_per_bar[bar_num]++; + } else { + inst_counts_per_bar[bar_num]++; + } + } + } + bar_num++; + } + } + + int count = 0; + for (const auto &kv : inst_counts_per_bar) { + count += (int)(drum_counts_per_bar.find(kv.first) != drum_counts_per_bar.end()); + } + + if (inst_counts_per_bar.size() == 0) { + return; // invalid midi file or only drums + } + + auto fm = f->add_drum_presence(); + fm->set_filepath(filepath); + fm->set_drum_presence((double)count / (double)inst_counts_per_bar.size()); + + } +}; + +class BeatStabilityFeature : public FEATURE_EXTRACTOR { +public: + + BeatStabilityFeature() { + quantized = true; + track_level = false; + } + + void compute_piece_feature(midi::Piece *x, midi::Features *f, std::string &filepath) { + + int max_beat_num = 0; + std::map beat_total_weights; + std::map,int> onset_weights; + for (const auto &track : x->tracks()) { + int total_beat_num = 0; + for (const auto &bar : track.bars()) { + for (const auto &event_index : bar.events()) { + auto event = x->events(event_index); + int beat_num = total_beat_num + event.time() / 12; + if (event.velocity()) { + // try extra weight for drums + onset_weights[std::make_tuple(beat_num,event.time() % 12)] += event.velocity(); // * (is_drum ? 2 : 1); + beat_total_weights[beat_num] += event.velocity(); + } + } + if (abs(bar.internal_beat_length() - std::round(bar.internal_beat_length())) > 1e-4) { + return; // the piece is invalid and we cannot compute + } + total_beat_num += bar.internal_beat_length(); + } + max_beat_num = std::max(max_beat_num, total_beat_num); + } + + double max_weight = 0; + for (auto &kv : beat_total_weights) { + max_weight = std::max(max_weight, kv.second); + } + + std::vector bar_weights; + for (int i=0; iadd_beat_stability(); + fm->set_filepath(filepath); + fm->set_beat_stability_stdev(standardDeviation(bar_weights)); + fm->set_beat_stability_median(median(bar_weights)); + + } + +}; + + +std::unique_ptr getFeature(std::string feature_name) { + if (feature_name == "BEAT_STABILITY") return std::make_unique(); + if (feature_name == "PITCH_RANGE") return std::make_unique(); + if (feature_name == "METRIC_DEPTH") return std::make_unique(); + if (feature_name == "MEDIAN_METRIC_DEPTH") return std::make_unique(); + if (feature_name == "MOST_FREQUENT_METRIC_DEPTH") return std::make_unique(); + if (feature_name == "DOWNBEAT_PROPORTION") return std::make_unique(); + if (feature_name == "ALIGNED_METRIC_DEPTH") return std::make_unique(); + if (feature_name == "SIMULTANEOUS_ONSET") return std::make_unique(); + if (feature_name == "DRUM_PRESENCE") return std::make_unique(); + + throw std::runtime_error("feature_extraction::getFeature() switch statement missing case."); +} + + +std::string compute_features(std::string &filepath, std::vector &feature_names) { + + + std::map> features; + + for (const auto &feature_name : feature_names) { + features[(int)getFeature(feature_name)->quantized].push_back(feature_name); + } + + auto encoder_config = std::make_shared(); + encoder_config->resolution = 12; + + midi::Features f; + + for (int i=0; i<2; i++) { + if (features.find(i) != features.end()) { + midi::Piece p; + encoder_config->unquantized = 1-i; + midi_io::ParseSong(filepath, &p, encoder_config); + for (const auto &feature_name : features[i]) { + getFeature(feature_name)->compute_feature(&p, &f, filepath); + } + } + } + + return util_protobuf::protobuf_to_string(&f); +} + + +} \ No newline at end of file diff --git a/src/common/midi_parsing/midi_io.h b/src/common/midi_parsing/midi_io.h new file mode 100644 index 0000000000000000000000000000000000000000..47eb6960765a7abbcab0a070aca4566c1011fbde --- /dev/null +++ b/src/common/midi_parsing/midi_io.h @@ -0,0 +1,408 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../../../libraries/midifile/include/Binasc.h" +#include "../../../libraries/midifile/include/MidiFile.h" + +#include "../../common/midi_parsing/util_protobuf.h" +#include "../../common/data_structures/track_type.h" +#include "../../common/data_structures/encoder_config.h" + +#include "../../common/midi_parsing/adjacent_range.h" + +#include + +// START OF NAMESPACE +namespace midi_io { + +#define QUIET_CALL(noisy) { \ + std::cout.setstate(std::ios_base::failbit);\ + std::cerr.setstate(std::ios_base::failbit);\ + (noisy);\ + std::cout.clear();\ + std::cerr.clear();\ +} + +float quantize_beat_float(double x, double TPQ, double SPQ, double cut=.5) { + return (int)((x / TPQ * SPQ) + (1.-cut)) * (TPQ / SPQ); +} + +int quantize_beat(double x, double TPQ, double SPQ, double cut=.5) { + return (int)quantize_beat_float(x, TPQ, SPQ, cut); +} + +int get_time_difference(double x, double y, double xpq, double spq, double tempo, int beats_per_note) { + return (int)(1000 * 60 * beats_per_note * (y - x) )/(4 * xpq * tempo); +} + +bool event_comparator(const midi::Event a, const midi::Event b) { + if (a.time() != b.time()) { + return a.time() < b.time(); + } + if (std::min(a.velocity(),1) != std::min(b.velocity(),1)) { + return std::min(a.velocity(),1) < std::min(b.velocity(),1); + } + return a.pitch() < b.pitch(); +} +bool event_pair_comparator(const std::pair a, const std::pair b) { + if (a.first.time() != b.first.time()) { + return a.first.time() < b.first.time(); + } + if (std::min(a.first.velocity(),1) != std::min(b.first.velocity(),1)) { + return std::min(a.first.velocity(),1) < std::min(b.first.velocity(),1); + } + return a.first.pitch() < b.first.pitch(); +} + +using TRACK_IDENTIFIER = std::tuple; + +class MidiParsedData { +public: + smf::MidiFile midi_file; + int track_count; + int ticks_per_quarter_note; + + MidiParsedData(std::string file_path) { + QUIET_CALL(midi_file.read(file_path)); + midi_file.makeAbsoluteTicks(); + midi_file.linkNotePairs(); + track_count = midi_file.getTrackCount(); + ticks_per_quarter_note = midi_file.getTPQ(); + } +}; + +class Parser { +public: + Parser(std::string filepath, midi::Piece *piece, const std::shared_ptr &config) { + Parse(filepath, piece, config); + } + static const int DRUM_CHANNEL = 9; + std::shared_ptr ec; + int track_count; + int TPQ; + int SPQ; + int current_track; + int max_tick; + int tempo; + smf::MidiEvent *mevent; + std::map track_map; + std::map rev_track_map; // transposed of track_map + std::map> timesigs; + std::map> bars; // value is (beatlength,count,num,dem) + std::vector> events; // events split into tracks + std::array instruments; // instruments on each channel. + + void SetMemberVariables(const std::shared_ptr &config, MidiParsedData* parsed_file) { + ec = config; + TPQ = parsed_file->ticks_per_quarter_note; + SPQ = ec->resolution; + if (TPQ < SPQ) { + throw std::runtime_error("MIDI FILE HAS INVALID TICKS PER QUARTER."); + } + } + + void FillPiece(midi::Piece* piece, MidiParsedData* parsed_file, const std::shared_ptr &config) { + piece->set_resolution(SPQ); + piece->set_internal_ticks_per_quarter(TPQ); + max_tick = 0; + current_track = 0; + + for (int track = 0; track < parsed_file->track_count; track++) { + current_track = track; + std::fill(instruments.begin(), instruments.end(), 0); // zero instruments + for (int event = 0; event < parsed_file->midi_file[track].size(); event++) { + mevent = &(parsed_file->midi_file[track][event]); + if (mevent->isPatchChange()) { + handle_patch_message(mevent); + } + else if (mevent->isTimeSignature()) { + handle_time_sig_message(mevent); + } + else if (mevent->isTempo()) { + tempo = mevent->getTempoBPM(); + piece->set_tempo(tempo); + } + else if (mevent->isNoteOn() || mevent->isNoteOff()) { + handle_note_message(mevent); + } + } + } + + if (max_tick <= 0) { + throw std::runtime_error("MIDI FILE HAS NO NOTES"); + } + + piece->set_internal_has_time_signatures(timesigs.size() > 0); + } + + void ProcessTimeSignatures(MidiParsedData* parsed_file){ + // add a timesig at beginning and end + // and then make a mapping from tick to bar_number and bar_length + int count = 0; + if (timesigs.find(0) == timesigs.end()) { + timesigs[0] = std::make_tuple(parsed_file->ticks_per_quarter_note * 4, 4, 4); // assume 4/4 + } + // if we do max_tick + TPQ instead we end up with an extra bar + timesigs[max_tick] = std::make_tuple(0, 0, 0); // no bar length + for (const auto& p : midi_parsing::make_adjacent_range(timesigs)) { + if (std::get<0>(p.first.second) > 0) { + for (int t = p.first.first; t < p.second.first; t += std::get<0>(p.first.second)) { + auto ts = p.first.second; + bars[t] = std::make_tuple(std::get<0>(ts), count, std::get<1>(ts), std::get<2>(ts)); + count++; + } + } + } + } + + void CreateMidiPiece(midi::Piece* piece, MidiParsedData* parsed_file) { + // construct the piece + midi::Track* track = NULL; + midi::Bar* bar = NULL; + midi::Event* event = NULL; + + for (int track_num = 0; track_num < (int)events.size(); track_num++) { + + // sort the events in each track + // at this point ticks are still absolute + std::sort(events[track_num].begin(), events[track_num].end(), event_comparator); + + // add track and track metadata + track = piece->add_tracks(); + track->set_instrument(std::get<2>(rev_track_map[track_num])); + track->set_track_type( + (midi::TRACK_TYPE)std::get<3>(rev_track_map[track_num])); + + // add bars and bar metadata + for (const auto& bar_info : bars) { + bar = track->add_bars(); + bar->set_internal_beat_length(std::get<0>(bar_info.second) / parsed_file->ticks_per_quarter_note); + bar->set_ts_numerator(std::get<2>(bar_info.second)); + bar->set_ts_denominator(std::get<3>(bar_info.second)); + } + + // add events + for (int j = 0; j < (int)events[track_num].size(); j++) { + int velocity = events[track_num][j].velocity(); + int tick = events[track_num][j].time(); + auto bar_info = get_bar_info(tick, velocity > 0); + + bar = track->mutable_bars(std::get<2>(bar_info)); // bar_num + bar->set_internal_has_notes(true); + + bar->add_events(piece->events_size()); + event = piece->add_events(); + event->CopyFrom(events[track_num][j]); + + int rel_tick = round((double)(tick - std::get<0>(bar_info)) / parsed_file->ticks_per_quarter_note * SPQ); + event->set_time(rel_tick); // relative + } + } + } + + + void Parse(std::string filepath, midi::Piece* piece, const std::shared_ptr &config) { + MidiParsedData parsed_file = MidiParsedData(filepath); + SetMemberVariables(config, &parsed_file); + FillPiece(piece, &parsed_file, config); + ProcessTimeSignatures(&parsed_file); + CreateMidiPiece(piece, &parsed_file); + } + + int infer_voice(int channel, int inst) { + int track_type = midi::STANDARD_TRACK; + if (channel == DRUM_CHANNEL) { + track_type = midi::STANDARD_DRUM_TRACK; + } + return track_type; + } + + TRACK_IDENTIFIER join_track_info(int track, int channel, int inst) { + return std::make_tuple(track, channel, inst, infer_voice(channel, inst)); + } + + std::tuple get_bar_info(int tick, bool is_onset) { + // returns bar_start, bar_length, bar_num tuple + auto it = bars.upper_bound(tick); + if (it == bars.begin()) { + throw std::runtime_error("CAN'T GET BAR INFO FOR TICK!"); + } + it = prev(it); + if ((it->first == tick) && (!is_onset)) { + // if the note is an offset and the time == the start of the bar + // push it back to the previous bar + if (it == bars.begin()) { + throw std::runtime_error("CAN'T GET BAR INFO FOR TICK!"); + } + it = prev(it); + } + return std::make_tuple(it->first, std::get<0>(it->second), std::get<1>(it->second)); + } + + void handle_patch_message(smf::MidiEvent *mevent) { + int channel = mevent->getChannelNibble(); + instruments[channel] = (int)((*mevent)[1]); + } + + void handle_time_sig_message(smf::MidiEvent *mevent) { + int numerator = (*mevent)[3]; + int denominator = 1<<(*mevent)[4]; + int barlength = (double)(TPQ * 4 * numerator / denominator); + + if (barlength >= 0) { + timesigs[mevent->tick] = std::make_tuple(barlength, numerator, denominator); + } + } + + std::tuple get_time_sig(double tick) { + if (timesigs.empty()) { + return std::make_tuple(TPQ * 4, 4, 4); + } + + auto it = timesigs.lower_bound(tick); + + if (it == timesigs.begin()) { + return std::make_tuple(TPQ * 4, 4, 4); + } + + --it; + return it->second; + } + + int beats_per_note(double tick) { + std::tuple time_sig = get_time_sig(tick); + return std::get<2>(time_sig); + } + + bool is_event_offset(smf::MidiEvent *mevent) { + return ((*mevent)[2]==0) || (mevent->isNoteOff()); + } + + void add_event(TRACK_IDENTIFIER &track_info, int tick, int pitch, int velocity, int delta) { + midi::Event event; + event.set_time( tick ); + event.set_pitch( pitch ); + event.set_velocity( velocity ); + event.set_delta( delta ); + events[track_map[track_info]].push_back( event ); + } + + void handle_note_message(smf::MidiEvent *mevent) { + int channel = mevent->getChannelNibble(); + int pitch = (int)(*mevent)[1]; + int velocity = (int)(*mevent)[2]; + + if ((!mevent->isLinked()) && (channel != 9)) { + // we do not include unlinked notes unless they are drum + return; + } + + if (mevent->isNoteOff()) { + velocity = 0; // sometimes this is not the case + } + + int tick = mevent->tick; + float float_tick = (float)mevent->tick; + int unquantized_tick = mevent->tick; + if (!ec->unquantized) { + tick = quantize_beat(mevent->tick, TPQ, SPQ); + float_tick = quantize_beat_float(mevent->tick, TPQ, SPQ); + } + + bool is_offset = is_event_offset(mevent); + + // ignore note offsets at start of file + if (is_offset && (tick==0)) { + return; + } + + int delta = 0; + if (ec->use_microtiming) { + delta = ec->step_to_delta(unquantized_tick - float_tick, TPQ); + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, data_structures::to_str("Using delta :: ", delta)); + } + + TRACK_IDENTIFIER track_info = join_track_info(current_track,channel,instruments[channel]); + + // track_info has info for new tracks per channel. If we can't find that info, we update track_map indicating there's a new + // track, and then we push a vector of events (preparing to fill that vector with events in the future). + // update track map + if (track_map.find(track_info) == track_map.end()) { + int current_size = track_map.size(); + track_map[track_info] = current_size; + rev_track_map[current_size] = track_info; + events.push_back( std::vector() ); + } + + // make all drum notes really short + if (channel == 9) { + if (!is_offset) { + add_event(track_info, tick, pitch, velocity, delta); + add_event(track_info, tick + (TPQ/SPQ), pitch, 0, delta); + } + } + else { + add_event(track_info, tick, pitch, velocity, delta); + } + + max_tick = std::max(max_tick, mevent->tick); + } +}; + +void ParseSong(std::string filepath, midi::Piece *midi_piece, const std::shared_ptr &encoder_config) { + Parser parser(filepath, midi_piece, encoder_config); +} + +void write_midi(midi::Piece* p, std::string& path, int single_track = -1) { + static const int DRUM_CHANNEL = 9; + + if (p->tracks_size() >= 15) { + throw std::runtime_error("TOO MANY TRACKS FOR MIDI OUTPUT"); + } + smf::MidiFile outputfile; + outputfile.absoluteTicks(); + outputfile.setTicksPerQuarterNote(p->resolution()); + outputfile.addTempo(0, 0, p->tempo()); + outputfile.addTrack(16); // ensure drum channel + + int track_num = 0; + for (const auto &track : p->tracks()) { + if ((single_track < 0) || (track_num == single_track)) { + int bar_start_time = 0; + int patch = track.instrument(); + int channel = enums::SAFE_TRACK_MAP[track_num]; + if (data_structures::is_drum_track(track.track_type())) { + channel = DRUM_CHANNEL; + } + outputfile.addPatchChange(channel, 0, channel, patch); + + for (const auto &bar : track.bars()) { + for (const auto &event_index : bar.events()) { + const midi::Event e = p->events(event_index); + outputfile.addNoteOn( + channel, // same as channel + bar_start_time + e.time(), // time + channel, // channel + e.pitch(), // pitch + e.velocity()); // velocity + } + bar_start_time += bar.internal_beat_length() * p->resolution(); + } + } + track_num++; + } + + outputfile.sortTracks(); // make sure data is in correct order + outputfile.write(path.c_str()); // write Standard MIDI File twinkle.mid +} +} +// END OF NAMESPACE diff --git a/src/common/midi_parsing/util_protobuf.h b/src/common/midi_parsing/util_protobuf.h new file mode 100644 index 0000000000000000000000000000000000000000..4832a29f67be204613ad8562b3f2e863bb819c17 --- /dev/null +++ b/src/common/midi_parsing/util_protobuf.h @@ -0,0 +1,749 @@ +#pragma once + +#include + +#include +#include +#include +#include "../../common/data_structures/track_type.h" +#include "../../common/data_structures/encoder_config.h" +#include "../../common/data_structures/verbosity.h" + +#include "../../inference/enum/density.h" +#include "../../inference/enum/constants.h" +#include "../../inference/enum/gm.h" +#include "../../inference/random.h" + +#ifndef M_LOG2E +#define M_LOG2E 1.4426950408889634074 +#endif + +// START OF NAMESPACE +namespace util_protobuf { + + // Checks if bar has features and returns them + midi::BarFeatures* GetBarFeatures(midi::Track *track, int bar_num) { + if ((bar_num < 0) || (bar_num >= track->bars_size())) { + throw std::runtime_error("BAR FEATURE REQUEST OUT OF RANGE"); + } + midi::Bar* bar = track->mutable_bars(bar_num); + if (bar->internal_features_size() == 0) { + return bar->add_internal_features(); + } + return bar->mutable_internal_features(0); + } + + // Checks if tracks has features and returns them + midi::TrackFeatures* GetTrackFeatures(midi::Piece* midi_piece, int track_num) { + if ((track_num < 0) || (track_num >= midi_piece->tracks_size())) { + throw std::runtime_error("TRACK FEATURE REQUEST OUT OF RANGE"); + } + //we return a pointer to the mutable track object with index track_num and we store the pointer in midi_track + midi::Track* midi_track = midi_piece->mutable_tracks(track_num); + if (midi_track->internal_features_size() == 0) { + //adds new element to end of field and returns a pointer. The returned track features is mutable and will have none of its fields set. + return midi_track->add_internal_features(); + } + //returns a pointer to the underlying mutable track object with index track_num and we return the pointer + return midi_track->mutable_internal_features(0); + } + + midi::PieceFeatures* GetPieceFeatures(midi::Piece* midi_piece) { + if (midi_piece->internal_features_size() == 0) { + return midi_piece->add_internal_features(); + } + return midi_piece->mutable_internal_features(0); + } + + // Get the number of bars in a piece + int GetNumBars(midi::Piece* midi_piece) { + if (midi_piece->tracks_size() == 0) { + return 0; + } + std::set track_num_bars; + for (const auto &track : midi_piece->tracks()) { + track_num_bars.insert(track.bars_size()); + } + if (track_num_bars.size() > 1) { + throw std::runtime_error("Each track must have the same number of bars!"); + } + //we dereference the pointer to the first element in the set (in this case the only element) + return *track_num_bars.begin(); + } + + // ================================================================ + // Functions to update the note_polyphony field in the midi::Tracks of a midi::Piece + // ================================================================ + + midi::Note CreateNote(int start, int end, int pitch) { + midi::Note note; + note.set_start(start); + note.set_end(end); + note.set_pitch(pitch); + return note; + } + + // slightly different way to get notes + std::vector getNotes(midi::Piece* piece, int track_start, int track_end, int bar_start, int bar_end, bool onset_only_drums) { + midi::Event current_midi_event; + std::vector notes; + std::map onsets; // key = pitch, value = start time + for (int track_num=track_start; track_numtracks_size()); + assert(bar_num < piece->tracks(track_num).bars_size()); + const midi::Track track = piece->tracks(track_num); + const midi::Bar bar = track.bars(bar_num); + for (const int event_id : bar.events()) { + current_midi_event = piece->events(event_id); + if (current_midi_event.velocity() > 0) { + // need to account for bar offset to get correct start time + int start_time = current_time + current_midi_event.time(); + if ((data_structures::is_drum_track(track.track_type())) && (onset_only_drums)) { + notes.push_back(util_protobuf::CreateNote(start_time, start_time + 1, current_midi_event.pitch())); + } + else { + onsets[current_midi_event.pitch()] = start_time; + } + } + else { + auto last_event_with_pitch = onsets.find(current_midi_event.pitch()); + int end_time = current_time + current_midi_event.time(); + if (last_event_with_pitch != onsets.end()) { + notes.push_back(util_protobuf::CreateNote(last_event_with_pitch->second, end_time, last_event_with_pitch->first)); + onsets.erase(last_event_with_pitch); + } + } + } + current_time += piece->resolution() * bar.internal_beat_length(); + } + } + return notes; + } + + // Go over all the bars and convert midi::events to midi::notes + std::vector IterateAndConvert(midi::Piece* midi_piece, const midi::Track* current_track, bool bool_drum_track, int* duration_in_ticks) { + midi::Event current_midi_event; + std::vector notes; + std::map onsets; + int bar_start = 0; + for (int bar_num = 0; bar_num < current_track->bars_size(); bar_num++) { + const midi::Bar bar = current_track->bars(bar_num); + for (auto event_id : bar.events()) { + current_midi_event = midi_piece->events(event_id); + if (current_midi_event.velocity() > 0) { + // need to account for bar offset to get correct start time + onsets[current_midi_event.pitch()] = bar_start + current_midi_event.time(); + } + else { + auto last_event_with_pitch = onsets.find(current_midi_event.pitch()); + // need to account for bar offset to get correct end time + int end_time = bool_drum_track ? last_event_with_pitch->second + 1 : bar_start + current_midi_event.time(); + if (last_event_with_pitch != onsets.end()) { + midi::Note note = CreateNote(last_event_with_pitch->second, end_time, last_event_with_pitch->first); + notes.push_back(note); + onsets.erase(last_event_with_pitch); + } + } + *duration_in_ticks = std::max(*duration_in_ticks, bar_start + current_midi_event.time()); + } + bar_start += midi_piece->resolution() * bar.internal_beat_length(); + } + return notes; + } + + // Get a specific track from a midi piece and convert its midi::events to midi::notes + std::vector TrackEventsToNotes(midi::Piece* midi_piece, int track_num, int* duration_in_ticks) { + bool bool_drum_track = data_structures::is_drum_track(midi_piece->tracks(track_num).track_type()); //TODO: this should be renamed is_drum_track = check_if_drum_track()... refactor + const midi::Track* current_track = &(midi_piece->tracks(track_num)); + std::vector notes = IterateAndConvert(midi_piece, current_track, bool_drum_track, duration_in_ticks); //TODO: This is a mayor change, but maybe the .proto shouldn't keep the events int the Piece, and instead keep them in the track message type + return notes; + } + + // Get the notes playing simultaneously per tick and return the tick with most note count. + int GetTrackMaxPolyphony(std::vector& notes, int duration_in_ticks) { + int max_polyphony = 0; + std::vector flat_roll(duration_in_ticks, 0); + for (const auto ¬e : notes) { + for (int tick = note.start(); tick < note.end(); tick++) { + flat_roll[tick]++; + max_polyphony = std::max(flat_roll[tick], max_polyphony); + } + } + return max_polyphony; + } + + // ================================================================ + // Functions to convert a polyphonic track to a monophonic one + // ================================================================ + + // We create an array of monophonic events + // we iterate over events + // if an event starts, we flag it. + // if another event starts before the flag is down, we force the first event to end and + // be pushed in the array. We then flag the new event as being played + // if the event ends before another starts, we just push it in the array + + void UpdateHasNotes(midi::Piece* midi_piece) { + int track_num = 0; + for (const auto &track : midi_piece->tracks()) { + int bar_num = 0; + for (const auto &bar : track.bars()) { + bool has_notes = false; + for (const auto &event_index : bar.events()) { + if (midi_piece->events(event_index).velocity() > 0) { + has_notes = true; + break; + } + } + midi_piece->mutable_tracks(track_num)->mutable_bars(bar_num)->set_internal_has_notes(has_notes); + bar_num++; + } + track_num++; + } + } + + // ======================================================================== + // RANDOM SEGMENT SELECTION FOR TRAINING + // + // 1. we select an index of a random segment + + + void UpdateValidSegments(midi::Piece* midi_piece, int seglen, int min_tracks) { + UpdateHasNotes(midi_piece); + midi_piece->clear_internal_valid_segments(); + midi_piece->clear_internal_valid_tracks(); + + if (midi_piece->tracks_size() < min_tracks) { return; } // no valid tracks + + int min_non_empty_bars = round(seglen * .75); + int num_bars = GetNumBars(midi_piece); + + for (int start = 0; start < num_bars - seglen + 1; start++) { + + // check that all time sigs are supported + bool is_four_four = true; + + // check which tracks are valid + midi::ValidTrack vtracks; + std::map used_track_types; + for (int track_num = 0; track_num < midi_piece->tracks_size(); track_num++) { + int non_empty_bars = 0; + for (int k = 0; k < seglen; k++) { + if (midi_piece->tracks(track_num).bars(start + k).internal_has_notes()) { + non_empty_bars++; + } + } + if (non_empty_bars >= min_non_empty_bars) { + vtracks.add_tracks(track_num); + } + } + + // check if there are enough tracks + bool enough_tracks = vtracks.tracks_size() >= min_tracks; + + if (enough_tracks && is_four_four) { + midi::ValidTrack* v = midi_piece->add_internal_valid_tracks_v2(); + v->CopyFrom(vtracks); + midi_piece->add_internal_valid_segments(start); + } + } + } + + // ================================================================ + // Non-factorized functions for inference + // ================================================================ + + inline double midigpt_log2(const long double x) { + return std::log(x) * M_LOG2E; + } + + template + T clip(const T& n, const T& lower, const T& upper) { + return std::max(lower, std::min(n, upper)); + } + + template + std::vector quantile(std::vector& x, std::vector qs) { + std::vector vals; + for (const auto &q : qs) { + if (x.size()) { + int index = std::min((int)round((double)x.size() * q), (int)x.size() - 1); + std::nth_element(x.begin(), x.begin() + index, x.end()); + vals.push_back(x[index]); + } + else { + vals.push_back(0); + } + } + return vals; + } + + template + T min_value(std::vector& x) { + auto result = std::min_element(x.begin(), x.end()); + if (result == x.end()) { + return 0; + } + return *result; + } + + template + T max_value(std::vector& x) { + auto result = std::max_element(x.begin(), x.end()); + if (result == x.end()) { + return 0; + } + return *result; + } + + template + std::string protobuf_to_string(T* x) { + std::string output; + google::protobuf::util::JsonPrintOptions opt; + opt.add_whitespace = true; + google::protobuf::util::MessageToJsonString(*x, &output, opt); + return output; + } + + std::vector get_note_durations(std::vector& notes) { + std::vector durations; + for (const auto ¬e : notes) { + double d = note.end() - note.start(); + durations.push_back((int)clip(midigpt_log2(std::max(d / 3., 1e-6)) + 1, 0., 5.)); + } + return durations; + } + + std::tuple av_polyphony_inner(std::vector& notes, int max_tick, midi::TrackFeatures* f) { + int nonzero_count = 0; + double count = 0; + std::vector flat_roll(max_tick, 0); + for (const auto ¬e : notes) { + for (int t = note.start(); t < std::min(note.end(), max_tick - 1); t++) { + if (flat_roll[t] == 0) { + nonzero_count += 1; + } + flat_roll[t]++; + count++; + } + } + + std::vector nz; + for (const auto &x : flat_roll) { + if (x > 0) { + nz.push_back(x); + if (f) { + f->add_polyphony_distribution(x); + } + } + } + + double silence = max_tick - nonzero_count; + + std::vector poly_qs = quantile(nz, { .15,.85 }); + + double min_polyphony = min_value(nz); + double max_polyphony = max_value(nz); + + double av_polyphony = count / std::max(nonzero_count, 1); + double av_silence = silence / std::max(max_tick, 1); + return std::make_tuple(av_polyphony, av_silence, poly_qs[0], poly_qs[1], min_polyphony, max_polyphony); + } + + double note_duration_inner(std::vector& notes) { + double total_diff = 0; + for (const auto ¬e : notes) { + total_diff += (note.end() - note.start()); + } + return total_diff / std::max((int)notes.size(), 1); + } + + // function to get note density value + int get_note_density_target(midi::Track* track, int bin) { + int qindex = track->instrument(); + int tt = track->track_type(); + if (data_structures::is_drum_track(tt)) { + qindex = 128; + } + return enums::DENSITY_QUANTILES[qindex][bin]; + } + + void update_note_density(midi::Piece* x) { + + int track_num = 0; + int num_notes; + for (const auto &track : x->tracks()) { + + // calculate average notes per bar + num_notes = 0; + int bar_num = 0; + std::set valid_bars; + for (const auto &bar : track.bars()) { + for (const auto &event_index : bar.events()) { + if (x->events(event_index).velocity()) { + valid_bars.insert(bar_num); + num_notes++; + } + } + bar_num++; + } + int num_bars = std::max((int)valid_bars.size(), 1); + double av_notes_fp = (double)num_notes / num_bars; + int av_notes = round(av_notes_fp); + + // calculate the density bin + int qindex = track.instrument(); + int bin = 0; + + if (data_structures::is_drum_track(track.track_type())) { + qindex = 128; + } + while (av_notes > enums::DENSITY_QUANTILES[qindex][bin]) { + bin++; + } + + // update protobuf + midi::TrackFeatures* tf = GetTrackFeatures(x, track_num); + tf->set_note_density_v2(bin); + tf->set_note_density_value(av_notes_fp); + track_num++; + + + } + } + + // adding note durations to events + void calculate_note_durations(midi::Piece* p) { + // to start set all durations == 0 + for (int i = 0; i < p->events_size(); i++) { + p->mutable_events(i)->set_internal_duration(0); + } + + for (const auto &track : p->tracks()) { + // pitches to (abs_time, event_index) + std::map> onsets; + int bar_start = 0; + for (const auto &bar : track.bars()) { + for (auto event_id : bar.events()) { + midi::Event e = p->events(event_id); + //data_structures::LOGGER( "PROC EVENT :: " , e.pitch() , " " , e.velocity() , " " , e.time() ); + if (e.velocity() > 0) { + if (data_structures::is_drum_track(track.track_type())) { + // drums always have duration of 1 timestep + p->mutable_events(event_id)->set_internal_duration(1); + } + else { + onsets[e.pitch()] = std::make_tuple(bar_start + e.time(), event_id); + } + } + else { + auto it = onsets.find(e.pitch()); + if (it != onsets.end()) { + int index = std::get<1>(it->second); + int duration = (bar_start + e.time()) - std::get<0>(it->second); + p->mutable_events(index)->set_internal_duration(duration); + } + } + } + // move forward a bar + bar_start += p->resolution() * bar.internal_beat_length(); + } + } + } + + void update_av_polyphony_and_note_duration(midi::Piece* p) { + for (int track_num = 0; track_num < p->tracks_size(); track_num++) { + int max_tick = 0; + std::vector notes = TrackEventsToNotes( + p, track_num, &max_tick); + std::vector durations = get_note_durations(notes); + midi::TrackFeatures* f = GetTrackFeatures(p, track_num); + auto stat = av_polyphony_inner(notes, max_tick, f); + f->set_note_duration(note_duration_inner(notes)); + f->set_av_polyphony(std::get<0>(stat)); + f->set_min_polyphony_q( + std::max(std::min((int)std::get<2>(stat), 10), 1) - 1); + f->set_max_polyphony_q( + std::max(std::min((int)std::get<3>(stat), 10), 1) - 1); + + std::vector dur_qs = quantile(durations, { .15,.85 }); + f->set_min_note_duration_q(dur_qs[0]); + f->set_max_note_duration_q(dur_qs[1]); + + // new hard upper lower limits + f->set_min_polyphony_hard(std::get<4>(stat)); + f->set_max_polyphony_hard(std::get<5>(stat)); + + f->set_min_note_duration_hard(min_value(durations)); + f->set_max_note_duration_hard(max_value(durations)); + + } + } + + std::tuple get_pitch_extents(midi::Piece* x) { + int min_pitch = INT_MAX; + int max_pitch = 0; + for (const auto &track : x->tracks()) { + if (!data_structures::is_drum_track(track.track_type())) { + for (const auto &bar : track.bars()) { + for (const auto &event_index : bar.events()) { + int pitch = x->events(event_index).pitch(); + min_pitch = std::min(pitch, min_pitch); + max_pitch = std::max(pitch, max_pitch); + } + } + } + } + return std::make_pair(min_pitch, max_pitch); + } + + void select_random_segment_indices(midi::Piece* x, int num_bars, int min_tracks, int max_tracks, std::mt19937* engine, std::vector& valid_tracks, int* start) { + UpdateValidSegments(x, num_bars, min_tracks); + + if (x->internal_valid_segments_size() == 0) { + throw std::runtime_error("NO VALID SEGMENTS"); + } + + int index = random_on_range(x->internal_valid_segments_size(), engine); + (*start) = x->internal_valid_segments(index); + for (const auto &track_num : x->internal_valid_tracks_v2(index).tracks()) { + valid_tracks.push_back(track_num); + } + shuffle(valid_tracks.begin(), valid_tracks.end(), *engine); + + // limit the tracks + int ntracks = std::min((int)valid_tracks.size(), max_tracks); + valid_tracks.resize(ntracks); + } + + void prune_tracks(midi::Piece* x, std::vector tracks, std::vector bars) { + + if (x->tracks_size() == 0) { + return; + } + + midi::Piece tmp(*x); + + int num_bars = GetNumBars(x); + bool remove_bars = (int)bars.size() > 0; + x->clear_tracks(); + x->clear_events(); + + std::vector tracks_to_keep; + for (const auto &track_num : tracks) { + if ((track_num >= 0) && (track_num < tmp.tracks_size())) { + tracks_to_keep.push_back(track_num); + } + } + + std::vector bars_to_keep; + for (const auto &bar_num : bars) { + if ((bar_num >= 0) && (bar_num < num_bars)) { + bars_to_keep.push_back(bar_num); + } + } + + for (const auto &track_num : tracks_to_keep) { + const midi::Track track = tmp.tracks(track_num); + midi::Track* t = x->add_tracks(); + t->CopyFrom(track); + if (remove_bars) { + t->clear_bars(); + for (const auto &bar_num : bars_to_keep) { + const midi::Bar bar = track.bars(bar_num); + midi::Bar* b = t->add_bars(); + b->CopyFrom(bar); + b->clear_events(); + for (const auto &event_index : bar.events()) { + b->add_events(x->events_size()); + midi::Event* e = x->add_events(); + e->CopyFrom(tmp.events(event_index)); + } + } + } + } + } + + void select_random_segment(midi::Piece* x, int num_bars, int min_tracks, int max_tracks, std::mt19937* engine) { + int start; + std::vector valid_tracks; + select_random_segment_indices( + x, num_bars, min_tracks, max_tracks, engine, valid_tracks, &start); + std::vector bars = arange(start, start + num_bars, 1); + prune_tracks(x, valid_tracks, bars); + } + + std::set> make_bar_mask(midi::Piece* x, float proportion, std::mt19937* engine) { + int num_tracks = x->tracks_size(); + int num_bars = GetNumBars(x); + int max_filled_bars = (int)round(num_tracks * num_bars * proportion); + int n_fill = random_on_range(max_filled_bars, engine); + std::vector> choices; + for (int track_num = 0; track_num < num_tracks; track_num++) { + for (int bar_num = 0; bar_num < num_bars; bar_num++) { + choices.push_back(std::make_pair(track_num, bar_num)); + } + } + std::set> mask; + shuffle(choices.begin(), choices.end(), *engine); + for (int i = 0; i < n_fill; i++) { + mask.insert(choices[i]); + } + return mask; + } + + std::string get_piece_string(midi::Piece* x) { + std::string output; + google::protobuf::util::JsonPrintOptions opt; + opt.add_whitespace = true; + google::protobuf::util::MessageToJsonString(*x, &output, opt); + return output; + } + + void print_piece(midi::Piece* x) { + data_structures::LOGGER( get_piece_string(x) ); + } + + void print_piece_summary(midi::Piece* x) { + midi::Piece c(*x); + c.clear_events(); + for (int track_num = 0; track_num < c.tracks_size(); track_num++) { + c.mutable_tracks(track_num)->clear_bars(); + } + print_piece(&c); + } + + void reorder_tracks(midi::Piece* x, std::vector track_order) { + int num_tracks = x->tracks_size(); + if (num_tracks != (int)track_order.size()) { + data_structures::LOGGER(data_structures::to_str( num_tracks , " " , track_order.size() )); + throw std::runtime_error("Track order does not match midi::Piece."); + } + for (int track_num = 0; track_num < num_tracks; track_num++) { + GetTrackFeatures(x, track_num)->set_order(track_order[track_num]); + } + std::sort( + x->mutable_tracks()->begin(), + x->mutable_tracks()->end(), + [](const midi::Track& a, const midi::Track& b) { + return a.internal_features(0).order() < b.internal_features(0).order(); + } + ); + } + + template + void string_to_protobuf(std::string& s, T* x) { + google::protobuf::util::JsonParseOptions opt; + opt.ignore_unknown_fields = true; + google::protobuf::util::JsonStringToMessage(s, x, opt); + } + + template + std::string enum_to_string(const T &value) { + const google::protobuf::EnumDescriptor *descriptor = google::protobuf::GetEnumDescriptor(); + return descriptor->FindValueByNumber(value)->name(); + } + + template + T string_to_enum(const std::string &name) { + const google::protobuf::EnumDescriptor *descriptor = google::protobuf::GetEnumDescriptor(); + return static_cast(descriptor->FindValueByName(name)->number()); + } + + template + void print_protobuf(T* x) { + data_structures::LOGGER( protobuf_to_string(x) ); + } + + void pad_piece_with_status(midi::Piece* p, midi::Status* s, int min_bars) { + // add tracks when status references ones that do not exist + for (const auto &track : s->tracks()) { + midi::Track* t = NULL; + if (track.track_id() >= p->tracks_size()) { + t = p->add_tracks(); + t->set_track_type(track.track_type()); + data_structures::LOGGER(data_structures::to_str( "adding track " , track.track_id() )); + } + else { + data_structures::LOGGER(data_structures::to_str( "using track " , track.track_id() )); + t = p->mutable_tracks(track.track_id()); + } + for (int i = t->bars_size(); i < 5; i++) {} // WHAT IS THIS ??? + data_structures::LOGGER(data_structures::to_str( "track " , track.track_id() , " has " , t->bars_size() , " bars" )); + int num_bars = std::max(track.selected_bars_size(), min_bars); + data_structures::LOGGER(data_structures::to_str( "adding " , num_bars , " bars" )); + for (int i = t->bars_size(); i < num_bars; i++) { + data_structures::LOGGER(data_structures::to_str( "adding bar " , i )); + midi::Bar* b = t->add_bars(); + data_structures::LOGGER(data_structures::to_str( "check " , i )); + b->set_internal_beat_length(4); + b->set_ts_numerator(4); + b->set_ts_denominator(4); + } + data_structures::LOGGER( "end" ); + } + } + + + midi::GM_TYPE gm_inst_to_string(int track_type, int instrument) { + return enums::GM_REV[data_structures::is_drum_track(track_type) * 128 + instrument]; + } + + void status_from_piece(midi::Piece *piece, midi::Status *status) { + status->Clear(); + int track_num = 0; + for (const auto &track : piece->tracks()) { + midi::StatusTrack *strack = status->add_tracks(); + strack->set_track_id( track_num ); + strack->set_track_type( track.track_type() ); + strack->set_density( midi::DENSITY_ANY ); // + strack->set_instrument(gm_inst_to_string(track.track_type(),track.instrument())); + strack->set_polyphony_hard_limit( 10 ); + strack->set_temperature( 1. ); + for (int i=0; iadd_selected_bars( false ); + strack->add_bars(); // for bar level controls + } + track_num++; + } + } + + std::string status_from_piece_py(std::string &piece_str) { + midi::Piece p; + midi::Status s; + string_to_protobuf(piece_str, &p); + status_from_piece(&p, &s); + return protobuf_to_string(&s); + } + + midi::HyperParam default_sample_param() { + midi::HyperParam param; + param.set_tracks_per_step( 1 ); + param.set_bars_per_step( 2 ); + param.set_model_dim( 4 ); + param.set_shuffle( true ); + param.set_percentage( 100 ); + param.set_temperature( 1. ); + param.set_batch_size( 1 ); + param.set_max_steps( 0 ); + param.set_verbose( false ); + param.set_polyphony_hard_limit( 5 ); + return param; + } + + std::string default_sample_param_py() { + midi::HyperParam param = default_sample_param(); + return protobuf_to_string(¶m); + } + + std::string prune_tracks_py(std::string json_string, std::vector tracks, std::vector bars) { + midi::Piece p; + string_to_protobuf(json_string, &p); + prune_tracks(&p, tracks, bars); + return protobuf_to_string(&p); + } + +} +// END OF NAMESPACE diff --git a/src/dataset_creation/compression/lz4.c b/src/dataset_creation/compression/lz4.c new file mode 100644 index 0000000000000000000000000000000000000000..8b28de446966443073d4e9d47a7f345681d0b8d4 --- /dev/null +++ b/src/dataset_creation/compression/lz4.c @@ -0,0 +1,2402 @@ +/* + LZ4 - Fast LZ compression algorithm + Copyright (C) 2011-present, Yann Collet. + + BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php) + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following disclaimer + in the documentation and/or other materials provided with the + distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + You can contact the author at : + - LZ4 homepage : http://www.lz4.org + - LZ4 source repository : https://github.com/lz4/lz4 +*/ + +/*-************************************ +* Tuning parameters +**************************************/ +/* + * LZ4_HEAPMODE : + * Select how default compression functions will allocate memory for their hash table, + * in memory stack (0:default, fastest), or in memory heap (1:requires malloc()). + */ +#ifndef LZ4_HEAPMODE +# define LZ4_HEAPMODE 0 +#endif + +/* + * ACCELERATION_DEFAULT : + * Select "acceleration" for LZ4_compress_fast() when parameter value <= 0 + */ +#define ACCELERATION_DEFAULT 1 + + +/*-************************************ +* CPU Feature Detection +**************************************/ +/* LZ4_FORCE_MEMORY_ACCESS + * By default, access to unaligned memory is controlled by `memcpy()`, which is safe and portable. + * Unfortunately, on some target/compiler combinations, the generated assembly is sub-optimal. + * The below switch allow to select different access method for improved performance. + * Method 0 (default) : use `memcpy()`. Safe and portable. + * Method 1 : `__packed` statement. It depends on compiler extension (ie, not portable). + * This method is safe if your compiler supports it, and *generally* as fast or faster than `memcpy`. + * Method 2 : direct access. This method is portable but violate C standard. + * It can generate buggy code on targets which assembly generation depends on alignment. + * But in some circumstances, it's the only known way to get the most performance (ie GCC + ARMv6) + * See https://fastcompression.blogspot.fr/2015/08/accessing-unaligned-memory.html for details. + * Prefer these methods in priority order (0 > 1 > 2) + */ +#ifndef LZ4_FORCE_MEMORY_ACCESS /* can be defined externally */ +# if defined(__GNUC__) && \ + ( defined(__ARM_ARCH_6__) || defined(__ARM_ARCH_6J__) || defined(__ARM_ARCH_6K__) \ + || defined(__ARM_ARCH_6Z__) || defined(__ARM_ARCH_6ZK__) || defined(__ARM_ARCH_6T2__) ) +# define LZ4_FORCE_MEMORY_ACCESS 2 +# elif (defined(__INTEL_COMPILER) && !defined(_WIN32)) || defined(__GNUC__) +# define LZ4_FORCE_MEMORY_ACCESS 1 +# endif +#endif + +/* + * LZ4_FORCE_SW_BITCOUNT + * Define this parameter if your target system or compiler does not support hardware bit count + */ +#if defined(_MSC_VER) && defined(_WIN32_WCE) /* Visual Studio for WinCE doesn't support Hardware bit count */ +# define LZ4_FORCE_SW_BITCOUNT +#endif + + + +/*-************************************ +* Dependency +**************************************/ +/* + * LZ4_SRC_INCLUDED: + * Amalgamation flag, whether lz4.c is included + */ +#ifndef LZ4_SRC_INCLUDED +# define LZ4_SRC_INCLUDED 1 +#endif + +#ifndef LZ4_STATIC_LINKING_ONLY +#define LZ4_STATIC_LINKING_ONLY +#endif + +#ifndef LZ4_DISABLE_DEPRECATE_WARNINGS +#define LZ4_DISABLE_DEPRECATE_WARNINGS /* due to LZ4_decompress_safe_withPrefix64k */ +#endif + +#define LZ4_STATIC_LINKING_ONLY /* LZ4_DISTANCE_MAX */ +#include "../../../include/dataset_creation/compression/lz4.h" +/* see also "memory routines" below */ + + +/*-************************************ +* Compiler Options +**************************************/ +#ifdef _MSC_VER /* Visual Studio */ +# include +# pragma warning(disable : 4127) /* disable: C4127: conditional expression is constant */ +# pragma warning(disable : 4293) /* disable: C4293: too large shift (32-bits) */ +#endif /* _MSC_VER */ + +#ifndef LZ4_FORCE_INLINE +# ifdef _MSC_VER /* Visual Studio */ +# define LZ4_FORCE_INLINE static __forceinline +# else +# if defined (__cplusplus) || defined (__STDC_VERSION__) && __STDC_VERSION__ >= 199901L /* C99 */ +# ifdef __GNUC__ +# define LZ4_FORCE_INLINE static inline __attribute__((always_inline)) +# else +# define LZ4_FORCE_INLINE static inline +# endif +# else +# define LZ4_FORCE_INLINE static +# endif /* __STDC_VERSION__ */ +# endif /* _MSC_VER */ +#endif /* LZ4_FORCE_INLINE */ + +/* LZ4_FORCE_O2_GCC_PPC64LE and LZ4_FORCE_O2_INLINE_GCC_PPC64LE + * gcc on ppc64le generates an unrolled SIMDized loop for LZ4_wildCopy8, + * together with a simple 8-byte copy loop as a fall-back path. + * However, this optimization hurts the decompression speed by >30%, + * because the execution does not go to the optimized loop + * for typical compressible data, and all of the preamble checks + * before going to the fall-back path become useless overhead. + * This optimization happens only with the -O3 flag, and -O2 generates + * a simple 8-byte copy loop. + * With gcc on ppc64le, all of the LZ4_decompress_* and LZ4_wildCopy8 + * functions are annotated with __attribute__((optimize("O2"))), + * and also LZ4_wildCopy8 is forcibly inlined, so that the O2 attribute + * of LZ4_wildCopy8 does not affect the compression speed. + */ +#if defined(__PPC64__) && defined(__LITTLE_ENDIAN__) && defined(__GNUC__) && !defined(__clang__) +# define LZ4_FORCE_O2_GCC_PPC64LE __attribute__((optimize("O2"))) +# define LZ4_FORCE_O2_INLINE_GCC_PPC64LE __attribute__((optimize("O2"))) LZ4_FORCE_INLINE +#else +# define LZ4_FORCE_O2_GCC_PPC64LE +# define LZ4_FORCE_O2_INLINE_GCC_PPC64LE static +#endif + +#if (defined(__GNUC__) && (__GNUC__ >= 3)) || (defined(__INTEL_COMPILER) && (__INTEL_COMPILER >= 800)) || defined(__clang__) +# define expect(expr,value) (__builtin_expect ((expr),(value)) ) +#else +# define expect(expr,value) (expr) +#endif + +#ifndef likely +#define likely(expr) expect((expr) != 0, 1) +#endif +#ifndef unlikely +#define unlikely(expr) expect((expr) != 0, 0) +#endif + + +/*-************************************ +* Memory routines +**************************************/ +#include /* malloc, calloc, free */ +#define ALLOC(s) malloc(s) +#define ALLOC_AND_ZERO(s) calloc(1,s) +#define FREEMEM(p) free(p) +#include /* memset, memcpy */ +#define MEM_INIT(p,v,s) memset((p),(v),(s)) + + +/*-************************************ +* Common Constants +**************************************/ +#define MINMATCH 4 + +#define WILDCOPYLENGTH 8 +#define LASTLITERALS 5 /* see ../doc/lz4_Block_format.md#parsing-restrictions */ +#define MFLIMIT 12 /* see ../doc/lz4_Block_format.md#parsing-restrictions */ +#define MATCH_SAFEGUARD_DISTANCE ((2*WILDCOPYLENGTH) - MINMATCH) /* ensure it's possible to write 2 x wildcopyLength without overflowing output buffer */ +#define FASTLOOP_SAFE_DISTANCE 64 +static const int LZ4_minLength = (MFLIMIT+1); + +#define KB *(1 <<10) +#define MB *(1 <<20) +#define GB *(1U<<30) + +#define LZ4_DISTANCE_ABSOLUTE_MAX 65535 +#if (LZ4_DISTANCE_MAX > LZ4_DISTANCE_ABSOLUTE_MAX) /* max supported by LZ4 format */ +# error "LZ4_DISTANCE_MAX is too big : must be <= 65535" +#endif + +#define ML_BITS 4 +#define ML_MASK ((1U<=1) +# include +#else +# ifndef assert +# define assert(condition) ((void)0) +# endif +#endif + +#define LZ4_STATIC_ASSERT(c) { enum { LZ4_static_assert = 1/(int)(!!(c)) }; } /* use after variable declarations */ + +#if defined(LZ4_DEBUG) && (LZ4_DEBUG>=2) +# include + static int g_debuglog_enable = 1; +# define DEBUGLOG(l, ...) { \ + if ((g_debuglog_enable) && (l<=LZ4_DEBUG)) { \ + fprintf(stderr, __FILE__ ": "); \ + fprintf(stderr, __VA_ARGS__); \ + fprintf(stderr, " \n"); \ + } } +#else +# define DEBUGLOG(l, ...) {} /* disabled */ +#endif + + +/*-************************************ +* Types +**************************************/ +#if defined(__cplusplus) || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */) +# include + typedef uint8_t BYTE; + typedef uint16_t U16; + typedef uint32_t U32; + typedef int32_t S32; + typedef uint64_t U64; + typedef uintptr_t uptrval; +#else +# include +# if UINT_MAX != 4294967295UL +# error "LZ4 code (when not C++ or C99) assumes that sizeof(int) == 4" +# endif + typedef unsigned char BYTE; + typedef unsigned short U16; + typedef unsigned int U32; + typedef signed int S32; + typedef unsigned long long U64; + typedef size_t uptrval; /* generally true, except OpenVMS-64 */ +#endif + +#if defined(__x86_64__) + typedef U64 reg_t; /* 64-bits in x32 mode */ +#else + typedef size_t reg_t; /* 32-bits in x32 mode */ +#endif + +typedef enum { + notLimited = 0, + limitedOutput = 1, + fillOutput = 2 +} limitedOutput_directive; + + +/*-************************************ +* Reading and writing into memory +**************************************/ +static unsigned LZ4_isLittleEndian(void) +{ + const union { U32 u; BYTE c[4]; } one = { 1 }; /* don't use static : performance detrimental */ + return one.c[0]; +} + + +#if defined(LZ4_FORCE_MEMORY_ACCESS) && (LZ4_FORCE_MEMORY_ACCESS==2) +/* lie to the compiler about data alignment; use with caution */ + +static U16 LZ4_read16(const void* memPtr) { return *(const U16*) memPtr; } +static U32 LZ4_read32(const void* memPtr) { return *(const U32*) memPtr; } +static reg_t LZ4_read_ARCH(const void* memPtr) { return *(const reg_t*) memPtr; } + +static void LZ4_write16(void* memPtr, U16 value) { *(U16*)memPtr = value; } +static void LZ4_write32(void* memPtr, U32 value) { *(U32*)memPtr = value; } + +#elif defined(LZ4_FORCE_MEMORY_ACCESS) && (LZ4_FORCE_MEMORY_ACCESS==1) + +/* __pack instructions are safer, but compiler specific, hence potentially problematic for some compilers */ +/* currently only defined for gcc and icc */ +typedef union { U16 u16; U32 u32; reg_t uArch; } __attribute__((packed)) unalign; + +static U16 LZ4_read16(const void* ptr) { return ((const unalign*)ptr)->u16; } +static U32 LZ4_read32(const void* ptr) { return ((const unalign*)ptr)->u32; } +static reg_t LZ4_read_ARCH(const void* ptr) { return ((const unalign*)ptr)->uArch; } + +static void LZ4_write16(void* memPtr, U16 value) { ((unalign*)memPtr)->u16 = value; } +static void LZ4_write32(void* memPtr, U32 value) { ((unalign*)memPtr)->u32 = value; } + +#else /* safe and portable access using memcpy() */ + +static U16 LZ4_read16(const void* memPtr) +{ + U16 val; memcpy(&val, memPtr, sizeof(val)); return val; +} + +static U32 LZ4_read32(const void* memPtr) +{ + U32 val; memcpy(&val, memPtr, sizeof(val)); return val; +} + +static reg_t LZ4_read_ARCH(const void* memPtr) +{ + reg_t val; memcpy(&val, memPtr, sizeof(val)); return val; +} + +static void LZ4_write16(void* memPtr, U16 value) +{ + memcpy(memPtr, &value, sizeof(value)); +} + +static void LZ4_write32(void* memPtr, U32 value) +{ + memcpy(memPtr, &value, sizeof(value)); +} + +#endif /* LZ4_FORCE_MEMORY_ACCESS */ + + +static U16 LZ4_readLE16(const void* memPtr) +{ + if (LZ4_isLittleEndian()) { + return LZ4_read16(memPtr); + } else { + const BYTE* p = (const BYTE*)memPtr; + return (U16)((U16)p[0] + (p[1]<<8)); + } +} + +static void LZ4_writeLE16(void* memPtr, U16 value) +{ + if (LZ4_isLittleEndian()) { + LZ4_write16(memPtr, value); + } else { + BYTE* p = (BYTE*)memPtr; + p[0] = (BYTE) value; + p[1] = (BYTE)(value>>8); + } +} + +/* customized variant of memcpy, which can overwrite up to 8 bytes beyond dstEnd */ +LZ4_FORCE_O2_INLINE_GCC_PPC64LE +void LZ4_wildCopy8(void* dstPtr, const void* srcPtr, void* dstEnd) +{ + BYTE* d = (BYTE*)dstPtr; + const BYTE* s = (const BYTE*)srcPtr; + BYTE* const e = (BYTE*)dstEnd; + + do { memcpy(d,s,8); d+=8; s+=8; } while (d= 16. */ +LZ4_FORCE_O2_INLINE_GCC_PPC64LE void +LZ4_wildCopy32(void* dstPtr, const void* srcPtr, void* dstEnd) +{ + BYTE* d = (BYTE*)dstPtr; + const BYTE* s = (const BYTE*)srcPtr; + BYTE* const e = (BYTE*)dstEnd; + + do { memcpy(d,s,16); memcpy(d+16,s+16,16); d+=32; s+=32; } while (d= dstPtr + MINMATCH + * - there is at least 8 bytes available to write after dstEnd */ +LZ4_FORCE_O2_INLINE_GCC_PPC64LE void +LZ4_memcpy_using_offset(BYTE* dstPtr, const BYTE* srcPtr, BYTE* dstEnd, const size_t offset) +{ + BYTE v[8]; + + assert(dstEnd >= dstPtr + MINMATCH); + LZ4_write32(dstPtr, 0); /* silence an msan warning when offset==0 */ + + switch(offset) { + case 1: + memset(v, *srcPtr, 8); + break; + case 2: + memcpy(v, srcPtr, 2); + memcpy(&v[2], srcPtr, 2); + memcpy(&v[4], &v[0], 4); + break; + case 4: + memcpy(v, srcPtr, 4); + memcpy(&v[4], srcPtr, 4); + break; + default: + LZ4_memcpy_using_offset_base(dstPtr, srcPtr, dstEnd, offset); + return; + } + + memcpy(dstPtr, v, 8); + dstPtr += 8; + while (dstPtr < dstEnd) { + memcpy(dstPtr, v, 8); + dstPtr += 8; + } +} +#endif + + +/*-************************************ +* Common functions +**************************************/ +static unsigned LZ4_NbCommonBytes (reg_t val) +{ + if (LZ4_isLittleEndian()) { + if (sizeof(val)==8) { +# if defined(_MSC_VER) && defined(_WIN64) && !defined(LZ4_FORCE_SW_BITCOUNT) + unsigned long r = 0; + _BitScanForward64( &r, (U64)val ); + return (int)(r>>3); +# elif (defined(__clang__) || (defined(__GNUC__) && (__GNUC__>=3))) && !defined(LZ4_FORCE_SW_BITCOUNT) + return (unsigned)__builtin_ctzll((U64)val) >> 3; +# else + static const int DeBruijnBytePos[64] = { 0, 0, 0, 0, 0, 1, 1, 2, + 0, 3, 1, 3, 1, 4, 2, 7, + 0, 2, 3, 6, 1, 5, 3, 5, + 1, 3, 4, 4, 2, 5, 6, 7, + 7, 0, 1, 2, 3, 3, 4, 6, + 2, 6, 5, 5, 3, 4, 5, 6, + 7, 1, 2, 4, 6, 4, 4, 5, + 7, 2, 6, 5, 7, 6, 7, 7 }; + return DeBruijnBytePos[((U64)((val & -(long long)val) * 0x0218A392CDABBD3FULL)) >> 58]; +# endif + } else /* 32 bits */ { +# if defined(_MSC_VER) && !defined(LZ4_FORCE_SW_BITCOUNT) + unsigned long r; + _BitScanForward( &r, (U32)val ); + return (int)(r>>3); +# elif (defined(__clang__) || (defined(__GNUC__) && (__GNUC__>=3))) && !defined(LZ4_FORCE_SW_BITCOUNT) + return (unsigned)__builtin_ctz((U32)val) >> 3; +# else + static const int DeBruijnBytePos[32] = { 0, 0, 3, 0, 3, 1, 3, 0, + 3, 2, 2, 1, 3, 2, 0, 1, + 3, 3, 1, 2, 2, 2, 2, 0, + 3, 1, 2, 0, 1, 0, 1, 1 }; + return DeBruijnBytePos[((U32)((val & -(S32)val) * 0x077CB531U)) >> 27]; +# endif + } + } else /* Big Endian CPU */ { + if (sizeof(val)==8) { /* 64-bits */ +# if defined(_MSC_VER) && defined(_WIN64) && !defined(LZ4_FORCE_SW_BITCOUNT) + unsigned long r = 0; + _BitScanReverse64( &r, val ); + return (unsigned)(r>>3); +# elif (defined(__clang__) || (defined(__GNUC__) && (__GNUC__>=3))) && !defined(LZ4_FORCE_SW_BITCOUNT) + return (unsigned)__builtin_clzll((U64)val) >> 3; +# else + static const U32 by32 = sizeof(val)*4; /* 32 on 64 bits (goal), 16 on 32 bits. + Just to avoid some static analyzer complaining about shift by 32 on 32-bits target. + Note that this code path is never triggered in 32-bits mode. */ + unsigned r; + if (!(val>>by32)) { r=4; } else { r=0; val>>=by32; } + if (!(val>>16)) { r+=2; val>>=8; } else { val>>=24; } + r += (!val); + return r; +# endif + } else /* 32 bits */ { +# if defined(_MSC_VER) && !defined(LZ4_FORCE_SW_BITCOUNT) + unsigned long r = 0; + _BitScanReverse( &r, (unsigned long)val ); + return (unsigned)(r>>3); +# elif (defined(__clang__) || (defined(__GNUC__) && (__GNUC__>=3))) && !defined(LZ4_FORCE_SW_BITCOUNT) + return (unsigned)__builtin_clz((U32)val) >> 3; +# else + unsigned r; + if (!(val>>16)) { r=2; val>>=8; } else { r=0; val>>=24; } + r += (!val); + return r; +# endif + } + } +} + +#define STEPSIZE sizeof(reg_t) +LZ4_FORCE_INLINE +unsigned LZ4_count(const BYTE* pIn, const BYTE* pMatch, const BYTE* pInLimit) +{ + const BYTE* const pStart = pIn; + + if (likely(pIn < pInLimit-(STEPSIZE-1))) { + reg_t const diff = LZ4_read_ARCH(pMatch) ^ LZ4_read_ARCH(pIn); + if (!diff) { + pIn+=STEPSIZE; pMatch+=STEPSIZE; + } else { + return LZ4_NbCommonBytes(diff); + } } + + while (likely(pIn < pInLimit-(STEPSIZE-1))) { + reg_t const diff = LZ4_read_ARCH(pMatch) ^ LZ4_read_ARCH(pIn); + if (!diff) { pIn+=STEPSIZE; pMatch+=STEPSIZE; continue; } + pIn += LZ4_NbCommonBytes(diff); + return (unsigned)(pIn - pStart); + } + + if ((STEPSIZE==8) && (pIn<(pInLimit-3)) && (LZ4_read32(pMatch) == LZ4_read32(pIn))) { pIn+=4; pMatch+=4; } + if ((pIn<(pInLimit-1)) && (LZ4_read16(pMatch) == LZ4_read16(pIn))) { pIn+=2; pMatch+=2; } + if ((pIn compression run slower on incompressible data */ + + +/*-************************************ +* Local Structures and types +**************************************/ +typedef enum { clearedTable = 0, byPtr, byU32, byU16 } tableType_t; + +/** + * This enum distinguishes several different modes of accessing previous + * content in the stream. + * + * - noDict : There is no preceding content. + * - withPrefix64k : Table entries up to ctx->dictSize before the current blob + * blob being compressed are valid and refer to the preceding + * content (of length ctx->dictSize), which is available + * contiguously preceding in memory the content currently + * being compressed. + * - usingExtDict : Like withPrefix64k, but the preceding content is somewhere + * else in memory, starting at ctx->dictionary with length + * ctx->dictSize. + * - usingDictCtx : Like usingExtDict, but everything concerning the preceding + * content is in a separate context, pointed to by + * ctx->dictCtx. ctx->dictionary, ctx->dictSize, and table + * entries in the current context that refer to positions + * preceding the beginning of the current compression are + * ignored. Instead, ctx->dictCtx->dictionary and ctx->dictCtx + * ->dictSize describe the location and size of the preceding + * content, and matches are found by looking in the ctx + * ->dictCtx->hashTable. + */ +typedef enum { noDict = 0, withPrefix64k, usingExtDict, usingDictCtx } dict_directive; +typedef enum { noDictIssue = 0, dictSmall } dictIssue_directive; + + +/*-************************************ +* Local Utils +**************************************/ +int LZ4_versionNumber (void) { return LZ4_VERSION_NUMBER; } +const char* LZ4_versionString(void) { return LZ4_VERSION_STRING; } +int LZ4_compressBound(int isize) { return LZ4_COMPRESSBOUND(isize); } +int LZ4_sizeofState() { return LZ4_STREAMSIZE; } + + +/*-************************************ +* Internal Definitions used in Tests +**************************************/ +#if defined (__cplusplus) +extern "C" { +#endif + +int LZ4_compress_forceExtDict (LZ4_stream_t* LZ4_dict, const char* source, char* dest, int srcSize); + +int LZ4_decompress_safe_forceExtDict(const char* source, char* dest, + int compressedSize, int maxOutputSize, + const void* dictStart, size_t dictSize); + +#if defined (__cplusplus) +} +#endif + +/*-****************************** +* Compression functions +********************************/ +LZ4_FORCE_INLINE U32 LZ4_hash4(U32 sequence, tableType_t const tableType) +{ + if (tableType == byU16) + return ((sequence * 2654435761U) >> ((MINMATCH*8)-(LZ4_HASHLOG+1))); + else + return ((sequence * 2654435761U) >> ((MINMATCH*8)-LZ4_HASHLOG)); +} + +LZ4_FORCE_INLINE U32 LZ4_hash5(U64 sequence, tableType_t const tableType) +{ + const U32 hashLog = (tableType == byU16) ? LZ4_HASHLOG+1 : LZ4_HASHLOG; + if (LZ4_isLittleEndian()) { + const U64 prime5bytes = 889523592379ULL; + return (U32)(((sequence << 24) * prime5bytes) >> (64 - hashLog)); + } else { + const U64 prime8bytes = 11400714785074694791ULL; + return (U32)(((sequence >> 24) * prime8bytes) >> (64 - hashLog)); + } +} + +LZ4_FORCE_INLINE U32 LZ4_hashPosition(const void* const p, tableType_t const tableType) +{ + if ((sizeof(reg_t)==8) && (tableType != byU16)) return LZ4_hash5(LZ4_read_ARCH(p), tableType); + return LZ4_hash4(LZ4_read32(p), tableType); +} + +LZ4_FORCE_INLINE void LZ4_clearHash(U32 h, void* tableBase, tableType_t const tableType) +{ + switch (tableType) + { + default: /* fallthrough */ + case clearedTable: { /* illegal! */ assert(0); return; } + case byPtr: { const BYTE** hashTable = (const BYTE**)tableBase; hashTable[h] = NULL; return; } + case byU32: { U32* hashTable = (U32*) tableBase; hashTable[h] = 0; return; } + case byU16: { U16* hashTable = (U16*) tableBase; hashTable[h] = 0; return; } + } +} + +LZ4_FORCE_INLINE void LZ4_putIndexOnHash(U32 idx, U32 h, void* tableBase, tableType_t const tableType) +{ + switch (tableType) + { + default: /* fallthrough */ + case clearedTable: /* fallthrough */ + case byPtr: { /* illegal! */ assert(0); return; } + case byU32: { U32* hashTable = (U32*) tableBase; hashTable[h] = idx; return; } + case byU16: { U16* hashTable = (U16*) tableBase; assert(idx < 65536); hashTable[h] = (U16)idx; return; } + } +} + +LZ4_FORCE_INLINE void LZ4_putPositionOnHash(const BYTE* p, U32 h, + void* tableBase, tableType_t const tableType, + const BYTE* srcBase) +{ + switch (tableType) + { + case clearedTable: { /* illegal! */ assert(0); return; } + case byPtr: { const BYTE** hashTable = (const BYTE**)tableBase; hashTable[h] = p; return; } + case byU32: { U32* hashTable = (U32*) tableBase; hashTable[h] = (U32)(p-srcBase); return; } + case byU16: { U16* hashTable = (U16*) tableBase; hashTable[h] = (U16)(p-srcBase); return; } + } +} + +LZ4_FORCE_INLINE void LZ4_putPosition(const BYTE* p, void* tableBase, tableType_t tableType, const BYTE* srcBase) +{ + U32 const h = LZ4_hashPosition(p, tableType); + LZ4_putPositionOnHash(p, h, tableBase, tableType, srcBase); +} + +/* LZ4_getIndexOnHash() : + * Index of match position registered in hash table. + * hash position must be calculated by using base+index, or dictBase+index. + * Assumption 1 : only valid if tableType == byU32 or byU16. + * Assumption 2 : h is presumed valid (within limits of hash table) + */ +LZ4_FORCE_INLINE U32 LZ4_getIndexOnHash(U32 h, const void* tableBase, tableType_t tableType) +{ + LZ4_STATIC_ASSERT(LZ4_MEMORY_USAGE > 2); + if (tableType == byU32) { + const U32* const hashTable = (const U32*) tableBase; + assert(h < (1U << (LZ4_MEMORY_USAGE-2))); + return hashTable[h]; + } + if (tableType == byU16) { + const U16* const hashTable = (const U16*) tableBase; + assert(h < (1U << (LZ4_MEMORY_USAGE-1))); + return hashTable[h]; + } + assert(0); return 0; /* forbidden case */ +} + +static const BYTE* LZ4_getPositionOnHash(U32 h, const void* tableBase, tableType_t tableType, const BYTE* srcBase) +{ + if (tableType == byPtr) { const BYTE* const* hashTable = (const BYTE* const*) tableBase; return hashTable[h]; } + if (tableType == byU32) { const U32* const hashTable = (const U32*) tableBase; return hashTable[h] + srcBase; } + { const U16* const hashTable = (const U16*) tableBase; return hashTable[h] + srcBase; } /* default, to ensure a return */ +} + +LZ4_FORCE_INLINE const BYTE* +LZ4_getPosition(const BYTE* p, + const void* tableBase, tableType_t tableType, + const BYTE* srcBase) +{ + U32 const h = LZ4_hashPosition(p, tableType); + return LZ4_getPositionOnHash(h, tableBase, tableType, srcBase); +} + +LZ4_FORCE_INLINE void +LZ4_prepareTable(LZ4_stream_t_internal* const cctx, + const int inputSize, + const tableType_t tableType) { + /* If compression failed during the previous step, then the context + * is marked as dirty, therefore, it has to be fully reset. + */ + if (cctx->dirty) { + DEBUGLOG(5, "LZ4_prepareTable: Full reset for %p", cctx); + MEM_INIT(cctx, 0, sizeof(LZ4_stream_t_internal)); + return; + } + + /* If the table hasn't been used, it's guaranteed to be zeroed out, and is + * therefore safe to use no matter what mode we're in. Otherwise, we figure + * out if it's safe to leave as is or whether it needs to be reset. + */ + if (cctx->tableType != clearedTable) { + assert(inputSize >= 0); + if (cctx->tableType != tableType + || ((tableType == byU16) && cctx->currentOffset + (unsigned)inputSize >= 0xFFFFU) + || ((tableType == byU32) && cctx->currentOffset > 1 GB) + || tableType == byPtr + || inputSize >= 4 KB) + { + DEBUGLOG(4, "LZ4_prepareTable: Resetting table in %p", cctx); + MEM_INIT(cctx->hashTable, 0, LZ4_HASHTABLESIZE); + cctx->currentOffset = 0; + cctx->tableType = clearedTable; + } else { + DEBUGLOG(4, "LZ4_prepareTable: Re-use hash table (no reset)"); + } + } + + /* Adding a gap, so all previous entries are > LZ4_DISTANCE_MAX back, is faster + * than compressing without a gap. However, compressing with + * currentOffset == 0 is faster still, so we preserve that case. + */ + if (cctx->currentOffset != 0 && tableType == byU32) { + DEBUGLOG(5, "LZ4_prepareTable: adding 64KB to currentOffset"); + cctx->currentOffset += 64 KB; + } + + /* Finally, clear history */ + cctx->dictCtx = NULL; + cctx->dictionary = NULL; + cctx->dictSize = 0; +} + +/** LZ4_compress_generic() : + inlined, to ensure branches are decided at compilation time */ +LZ4_FORCE_INLINE int LZ4_compress_generic( + LZ4_stream_t_internal* const cctx, + const char* const source, + char* const dest, + const int inputSize, + int *inputConsumed, /* only written when outputDirective == fillOutput */ + const int maxOutputSize, + const limitedOutput_directive outputDirective, + const tableType_t tableType, + const dict_directive dictDirective, + const dictIssue_directive dictIssue, + const int acceleration) +{ + int result; + const BYTE* ip = (const BYTE*) source; + + U32 const startIndex = cctx->currentOffset; + const BYTE* base = (const BYTE*) source - startIndex; + const BYTE* lowLimit; + + const LZ4_stream_t_internal* dictCtx = (const LZ4_stream_t_internal*) cctx->dictCtx; + const BYTE* const dictionary = + dictDirective == usingDictCtx ? dictCtx->dictionary : cctx->dictionary; + const U32 dictSize = + dictDirective == usingDictCtx ? dictCtx->dictSize : cctx->dictSize; + const U32 dictDelta = (dictDirective == usingDictCtx) ? startIndex - dictCtx->currentOffset : 0; /* make indexes in dictCtx comparable with index in current context */ + + int const maybe_extMem = (dictDirective == usingExtDict) || (dictDirective == usingDictCtx); + U32 const prefixIdxLimit = startIndex - dictSize; /* used when dictDirective == dictSmall */ + const BYTE* const dictEnd = dictionary + dictSize; + const BYTE* anchor = (const BYTE*) source; + const BYTE* const iend = ip + inputSize; + const BYTE* const mflimitPlusOne = iend - MFLIMIT + 1; + const BYTE* const matchlimit = iend - LASTLITERALS; + + /* the dictCtx currentOffset is indexed on the start of the dictionary, + * while a dictionary in the current context precedes the currentOffset */ + const BYTE* dictBase = (dictDirective == usingDictCtx) ? + dictionary + dictSize - dictCtx->currentOffset : + dictionary + dictSize - startIndex; + + BYTE* op = (BYTE*) dest; + BYTE* const olimit = op + maxOutputSize; + + U32 offset = 0; + U32 forwardH; + + DEBUGLOG(5, "LZ4_compress_generic: srcSize=%i, tableType=%u", inputSize, tableType); + /* If init conditions are not met, we don't have to mark stream + * as having dirty context, since no action was taken yet */ + if (outputDirective == fillOutput && maxOutputSize < 1) { return 0; } /* Impossible to store anything */ + if ((U32)inputSize > (U32)LZ4_MAX_INPUT_SIZE) { return 0; } /* Unsupported inputSize, too large (or negative) */ + if ((tableType == byU16) && (inputSize>=LZ4_64Klimit)) { return 0; } /* Size too large (not within 64K limit) */ + if (tableType==byPtr) assert(dictDirective==noDict); /* only supported use case with byPtr */ + assert(acceleration >= 1); + + lowLimit = (const BYTE*)source - (dictDirective == withPrefix64k ? dictSize : 0); + + /* Update context state */ + if (dictDirective == usingDictCtx) { + /* Subsequent linked blocks can't use the dictionary. */ + /* Instead, they use the block we just compressed. */ + cctx->dictCtx = NULL; + cctx->dictSize = (U32)inputSize; + } else { + cctx->dictSize += (U32)inputSize; + } + cctx->currentOffset += (U32)inputSize; + cctx->tableType = (U16)tableType; + + if (inputSizehashTable, tableType, base); + ip++; forwardH = LZ4_hashPosition(ip, tableType); + + /* Main Loop */ + for ( ; ; ) { + const BYTE* match; + BYTE* token; + const BYTE* filledIp; + + /* Find a match */ + if (tableType == byPtr) { + const BYTE* forwardIp = ip; + int step = 1; + int searchMatchNb = acceleration << LZ4_skipTrigger; + do { + U32 const h = forwardH; + ip = forwardIp; + forwardIp += step; + step = (searchMatchNb++ >> LZ4_skipTrigger); + + if (unlikely(forwardIp > mflimitPlusOne)) goto _last_literals; + assert(ip < mflimitPlusOne); + + match = LZ4_getPositionOnHash(h, cctx->hashTable, tableType, base); + forwardH = LZ4_hashPosition(forwardIp, tableType); + LZ4_putPositionOnHash(ip, h, cctx->hashTable, tableType, base); + + } while ( (match+LZ4_DISTANCE_MAX < ip) + || (LZ4_read32(match) != LZ4_read32(ip)) ); + + } else { /* byU32, byU16 */ + + const BYTE* forwardIp = ip; + int step = 1; + int searchMatchNb = acceleration << LZ4_skipTrigger; + do { + U32 const h = forwardH; + U32 const current = (U32)(forwardIp - base); + U32 matchIndex = LZ4_getIndexOnHash(h, cctx->hashTable, tableType); + assert(matchIndex <= current); + assert(forwardIp - base < (ptrdiff_t)(2 GB - 1)); + ip = forwardIp; + forwardIp += step; + step = (searchMatchNb++ >> LZ4_skipTrigger); + + if (unlikely(forwardIp > mflimitPlusOne)) goto _last_literals; + assert(ip < mflimitPlusOne); + + if (dictDirective == usingDictCtx) { + if (matchIndex < startIndex) { + /* there was no match, try the dictionary */ + assert(tableType == byU32); + matchIndex = LZ4_getIndexOnHash(h, dictCtx->hashTable, byU32); + match = dictBase + matchIndex; + matchIndex += dictDelta; /* make dictCtx index comparable with current context */ + lowLimit = dictionary; + } else { + match = base + matchIndex; + lowLimit = (const BYTE*)source; + } + } else if (dictDirective==usingExtDict) { + if (matchIndex < startIndex) { + DEBUGLOG(7, "extDict candidate: matchIndex=%5u < startIndex=%5u", matchIndex, startIndex); + assert(startIndex - matchIndex >= MINMATCH); + match = dictBase + matchIndex; + lowLimit = dictionary; + } else { + match = base + matchIndex; + lowLimit = (const BYTE*)source; + } + } else { /* single continuous memory segment */ + match = base + matchIndex; + } + forwardH = LZ4_hashPosition(forwardIp, tableType); + LZ4_putIndexOnHash(current, h, cctx->hashTable, tableType); + + DEBUGLOG(7, "candidate at pos=%u (offset=%u \n", matchIndex, current - matchIndex); + if ((dictIssue == dictSmall) && (matchIndex < prefixIdxLimit)) { continue; } /* match outside of valid area */ + assert(matchIndex < current); + if ( ((tableType != byU16) || (LZ4_DISTANCE_MAX < LZ4_DISTANCE_ABSOLUTE_MAX)) + && (matchIndex+LZ4_DISTANCE_MAX < current)) { + continue; + } /* too far */ + assert((current - matchIndex) <= LZ4_DISTANCE_MAX); /* match now expected within distance */ + + if (LZ4_read32(match) == LZ4_read32(ip)) { + if (maybe_extMem) offset = current - matchIndex; + break; /* match found */ + } + + } while(1); + } + + /* Catch up */ + filledIp = ip; + while (((ip>anchor) & (match > lowLimit)) && (unlikely(ip[-1]==match[-1]))) { ip--; match--; } + + /* Encode Literals */ + { unsigned const litLength = (unsigned)(ip - anchor); + token = op++; + if ((outputDirective == limitedOutput) && /* Check output buffer overflow */ + (unlikely(op + litLength + (2 + 1 + LASTLITERALS) + (litLength/255) > olimit)) ) { + return 0; /* cannot compress within `dst` budget. Stored indexes in hash table are nonetheless fine */ + } + if ((outputDirective == fillOutput) && + (unlikely(op + (litLength+240)/255 /* litlen */ + litLength /* literals */ + 2 /* offset */ + 1 /* token */ + MFLIMIT - MINMATCH /* min last literals so last match is <= end - MFLIMIT */ > olimit))) { + op--; + goto _last_literals; + } + if (litLength >= RUN_MASK) { + int len = (int)(litLength - RUN_MASK); + *token = (RUN_MASK<= 255 ; len-=255) *op++ = 255; + *op++ = (BYTE)len; + } + else *token = (BYTE)(litLength< olimit)) { + /* the match was too close to the end, rewind and go to last literals */ + op = token; + goto _last_literals; + } + + /* Encode Offset */ + if (maybe_extMem) { /* static test */ + DEBUGLOG(6, " with offset=%u (ext if > %i)", offset, (int)(ip - (const BYTE*)source)); + assert(offset <= LZ4_DISTANCE_MAX && offset > 0); + LZ4_writeLE16(op, (U16)offset); op+=2; + } else { + DEBUGLOG(6, " with offset=%u (same segment)", (U32)(ip - match)); + assert(ip-match <= LZ4_DISTANCE_MAX); + LZ4_writeLE16(op, (U16)(ip - match)); op+=2; + } + + /* Encode MatchLength */ + { unsigned matchCode; + + if ( (dictDirective==usingExtDict || dictDirective==usingDictCtx) + && (lowLimit==dictionary) /* match within extDict */ ) { + const BYTE* limit = ip + (dictEnd-match); + assert(dictEnd > match); + if (limit > matchlimit) limit = matchlimit; + matchCode = LZ4_count(ip+MINMATCH, match+MINMATCH, limit); + ip += (size_t)matchCode + MINMATCH; + if (ip==limit) { + unsigned const more = LZ4_count(limit, (const BYTE*)source, matchlimit); + matchCode += more; + ip += more; + } + DEBUGLOG(6, " with matchLength=%u starting in extDict", matchCode+MINMATCH); + } else { + matchCode = LZ4_count(ip+MINMATCH, match+MINMATCH, matchlimit); + ip += (size_t)matchCode + MINMATCH; + DEBUGLOG(6, " with matchLength=%u", matchCode+MINMATCH); + } + + if ((outputDirective) && /* Check output buffer overflow */ + (unlikely(op + (1 + LASTLITERALS) + (matchCode+240)/255 > olimit)) ) { + if (outputDirective == fillOutput) { + /* Match description too long : reduce it */ + U32 newMatchCode = 15 /* in token */ - 1 /* to avoid needing a zero byte */ + ((U32)(olimit - op) - 1 - LASTLITERALS) * 255; + ip -= matchCode - newMatchCode; + assert(newMatchCode < matchCode); + matchCode = newMatchCode; + if (unlikely(ip <= filledIp)) { + /* We have already filled up to filledIp so if ip ends up less than filledIp + * we have positions in the hash table beyond the current position. This is + * a problem if we reuse the hash table. So we have to remove these positions + * from the hash table. + */ + const BYTE* ptr; + DEBUGLOG(5, "Clearing %u positions", (U32)(filledIp - ip)); + for (ptr = ip; ptr <= filledIp; ++ptr) { + U32 const h = LZ4_hashPosition(ptr, tableType); + LZ4_clearHash(h, cctx->hashTable, tableType); + } + } + } else { + assert(outputDirective == limitedOutput); + return 0; /* cannot compress within `dst` budget. Stored indexes in hash table are nonetheless fine */ + } + } + if (matchCode >= ML_MASK) { + *token += ML_MASK; + matchCode -= ML_MASK; + LZ4_write32(op, 0xFFFFFFFF); + while (matchCode >= 4*255) { + op+=4; + LZ4_write32(op, 0xFFFFFFFF); + matchCode -= 4*255; + } + op += matchCode / 255; + *op++ = (BYTE)(matchCode % 255); + } else + *token += (BYTE)(matchCode); + } + /* Ensure we have enough space for the last literals. */ + assert(!(outputDirective == fillOutput && op + 1 + LASTLITERALS > olimit)); + + anchor = ip; + + /* Test end of chunk */ + if (ip >= mflimitPlusOne) break; + + /* Fill table */ + LZ4_putPosition(ip-2, cctx->hashTable, tableType, base); + + /* Test next position */ + if (tableType == byPtr) { + + match = LZ4_getPosition(ip, cctx->hashTable, tableType, base); + LZ4_putPosition(ip, cctx->hashTable, tableType, base); + if ( (match+LZ4_DISTANCE_MAX >= ip) + && (LZ4_read32(match) == LZ4_read32(ip)) ) + { token=op++; *token=0; goto _next_match; } + + } else { /* byU32, byU16 */ + + U32 const h = LZ4_hashPosition(ip, tableType); + U32 const current = (U32)(ip-base); + U32 matchIndex = LZ4_getIndexOnHash(h, cctx->hashTable, tableType); + assert(matchIndex < current); + if (dictDirective == usingDictCtx) { + if (matchIndex < startIndex) { + /* there was no match, try the dictionary */ + matchIndex = LZ4_getIndexOnHash(h, dictCtx->hashTable, byU32); + match = dictBase + matchIndex; + lowLimit = dictionary; /* required for match length counter */ + matchIndex += dictDelta; + } else { + match = base + matchIndex; + lowLimit = (const BYTE*)source; /* required for match length counter */ + } + } else if (dictDirective==usingExtDict) { + if (matchIndex < startIndex) { + match = dictBase + matchIndex; + lowLimit = dictionary; /* required for match length counter */ + } else { + match = base + matchIndex; + lowLimit = (const BYTE*)source; /* required for match length counter */ + } + } else { /* single memory segment */ + match = base + matchIndex; + } + LZ4_putIndexOnHash(current, h, cctx->hashTable, tableType); + assert(matchIndex < current); + if ( ((dictIssue==dictSmall) ? (matchIndex >= prefixIdxLimit) : 1) + && (((tableType==byU16) && (LZ4_DISTANCE_MAX == LZ4_DISTANCE_ABSOLUTE_MAX)) ? 1 : (matchIndex+LZ4_DISTANCE_MAX >= current)) + && (LZ4_read32(match) == LZ4_read32(ip)) ) { + token=op++; + *token=0; + if (maybe_extMem) offset = current - matchIndex; + DEBUGLOG(6, "seq.start:%i, literals=%u, match.start:%i", + (int)(anchor-(const BYTE*)source), 0, (int)(ip-(const BYTE*)source)); + goto _next_match; + } + } + + /* Prepare next loop */ + forwardH = LZ4_hashPosition(++ip, tableType); + + } + +_last_literals: + /* Encode Last Literals */ + { size_t lastRun = (size_t)(iend - anchor); + if ( (outputDirective) && /* Check output buffer overflow */ + (op + lastRun + 1 + ((lastRun+255-RUN_MASK)/255) > olimit)) { + if (outputDirective == fillOutput) { + /* adapt lastRun to fill 'dst' */ + assert(olimit >= op); + lastRun = (size_t)(olimit-op) - 1; + lastRun -= (lastRun+240)/255; + } else { + assert(outputDirective == limitedOutput); + return 0; /* cannot compress within `dst` budget. Stored indexes in hash table are nonetheless fine */ + } + } + if (lastRun >= RUN_MASK) { + size_t accumulator = lastRun - RUN_MASK; + *op++ = RUN_MASK << ML_BITS; + for(; accumulator >= 255 ; accumulator-=255) *op++ = 255; + *op++ = (BYTE) accumulator; + } else { + *op++ = (BYTE)(lastRun< 0); + return result; +} + + +int LZ4_compress_fast_extState(void* state, const char* source, char* dest, int inputSize, int maxOutputSize, int acceleration) +{ + LZ4_stream_t_internal* const ctx = & LZ4_initStream(state, sizeof(LZ4_stream_t)) -> internal_donotuse; + assert(ctx != NULL); + if (acceleration < 1) acceleration = ACCELERATION_DEFAULT; + if (maxOutputSize >= LZ4_compressBound(inputSize)) { + if (inputSize < LZ4_64Klimit) { + return LZ4_compress_generic(ctx, source, dest, inputSize, NULL, 0, notLimited, byU16, noDict, noDictIssue, acceleration); + } else { + const tableType_t tableType = ((sizeof(void*)==4) && ((uptrval)source > LZ4_DISTANCE_MAX)) ? byPtr : byU32; + return LZ4_compress_generic(ctx, source, dest, inputSize, NULL, 0, notLimited, tableType, noDict, noDictIssue, acceleration); + } + } else { + if (inputSize < LZ4_64Klimit) { + return LZ4_compress_generic(ctx, source, dest, inputSize, NULL, maxOutputSize, limitedOutput, byU16, noDict, noDictIssue, acceleration); + } else { + const tableType_t tableType = ((sizeof(void*)==4) && ((uptrval)source > LZ4_DISTANCE_MAX)) ? byPtr : byU32; + return LZ4_compress_generic(ctx, source, dest, inputSize, NULL, maxOutputSize, limitedOutput, tableType, noDict, noDictIssue, acceleration); + } + } +} + +/** + * LZ4_compress_fast_extState_fastReset() : + * A variant of LZ4_compress_fast_extState(). + * + * Using this variant avoids an expensive initialization step. It is only safe + * to call if the state buffer is known to be correctly initialized already + * (see comment in lz4.h on LZ4_resetStream_fast() for a definition of + * "correctly initialized"). + */ +int LZ4_compress_fast_extState_fastReset(void* state, const char* src, char* dst, int srcSize, int dstCapacity, int acceleration) +{ + LZ4_stream_t_internal* ctx = &((LZ4_stream_t*)state)->internal_donotuse; + if (acceleration < 1) acceleration = ACCELERATION_DEFAULT; + + if (dstCapacity >= LZ4_compressBound(srcSize)) { + if (srcSize < LZ4_64Klimit) { + const tableType_t tableType = byU16; + LZ4_prepareTable(ctx, srcSize, tableType); + if (ctx->currentOffset) { + return LZ4_compress_generic(ctx, src, dst, srcSize, NULL, 0, notLimited, tableType, noDict, dictSmall, acceleration); + } else { + return LZ4_compress_generic(ctx, src, dst, srcSize, NULL, 0, notLimited, tableType, noDict, noDictIssue, acceleration); + } + } else { + const tableType_t tableType = ((sizeof(void*)==4) && ((uptrval)src > LZ4_DISTANCE_MAX)) ? byPtr : byU32; + LZ4_prepareTable(ctx, srcSize, tableType); + return LZ4_compress_generic(ctx, src, dst, srcSize, NULL, 0, notLimited, tableType, noDict, noDictIssue, acceleration); + } + } else { + if (srcSize < LZ4_64Klimit) { + const tableType_t tableType = byU16; + LZ4_prepareTable(ctx, srcSize, tableType); + if (ctx->currentOffset) { + return LZ4_compress_generic(ctx, src, dst, srcSize, NULL, dstCapacity, limitedOutput, tableType, noDict, dictSmall, acceleration); + } else { + return LZ4_compress_generic(ctx, src, dst, srcSize, NULL, dstCapacity, limitedOutput, tableType, noDict, noDictIssue, acceleration); + } + } else { + const tableType_t tableType = ((sizeof(void*)==4) && ((uptrval)src > LZ4_DISTANCE_MAX)) ? byPtr : byU32; + LZ4_prepareTable(ctx, srcSize, tableType); + return LZ4_compress_generic(ctx, src, dst, srcSize, NULL, dstCapacity, limitedOutput, tableType, noDict, noDictIssue, acceleration); + } + } +} + + +int LZ4_compress_fast(const char* source, char* dest, int inputSize, int maxOutputSize, int acceleration) +{ + int result; +#if (LZ4_HEAPMODE) + LZ4_stream_t* ctxPtr = ALLOC(sizeof(LZ4_stream_t)); /* malloc-calloc always properly aligned */ + if (ctxPtr == NULL) return 0; +#else + LZ4_stream_t ctx; + LZ4_stream_t* const ctxPtr = &ctx; +#endif + result = LZ4_compress_fast_extState(ctxPtr, source, dest, inputSize, maxOutputSize, acceleration); + +#if (LZ4_HEAPMODE) + FREEMEM(ctxPtr); +#endif + return result; +} + + +int LZ4_compress_default(const char* src, char* dst, int srcSize, int maxOutputSize) +{ + return LZ4_compress_fast(src, dst, srcSize, maxOutputSize, 1); +} + + +/* hidden debug function */ +/* strangely enough, gcc generates faster code when this function is uncommented, even if unused */ +int LZ4_compress_fast_force(const char* src, char* dst, int srcSize, int dstCapacity, int acceleration) +{ + LZ4_stream_t ctx; + LZ4_initStream(&ctx, sizeof(ctx)); + + if (srcSize < LZ4_64Klimit) { + return LZ4_compress_generic(&ctx.internal_donotuse, src, dst, srcSize, NULL, dstCapacity, limitedOutput, byU16, noDict, noDictIssue, acceleration); + } else { + tableType_t const addrMode = (sizeof(void*) > 4) ? byU32 : byPtr; + return LZ4_compress_generic(&ctx.internal_donotuse, src, dst, srcSize, NULL, dstCapacity, limitedOutput, addrMode, noDict, noDictIssue, acceleration); + } +} + + +/* Note!: This function leaves the stream in an unclean/broken state! + * It is not safe to subsequently use the same state with a _fastReset() or + * _continue() call without resetting it. */ +static int LZ4_compress_destSize_extState (LZ4_stream_t* state, const char* src, char* dst, int* srcSizePtr, int targetDstSize) +{ + void* const s = LZ4_initStream(state, sizeof (*state)); + assert(s != NULL); (void)s; + + if (targetDstSize >= LZ4_compressBound(*srcSizePtr)) { /* compression success is guaranteed */ + return LZ4_compress_fast_extState(state, src, dst, *srcSizePtr, targetDstSize, 1); + } else { + if (*srcSizePtr < LZ4_64Klimit) { + return LZ4_compress_generic(&state->internal_donotuse, src, dst, *srcSizePtr, srcSizePtr, targetDstSize, fillOutput, byU16, noDict, noDictIssue, 1); + } else { + tableType_t const addrMode = ((sizeof(void*)==4) && ((uptrval)src > LZ4_DISTANCE_MAX)) ? byPtr : byU32; + return LZ4_compress_generic(&state->internal_donotuse, src, dst, *srcSizePtr, srcSizePtr, targetDstSize, fillOutput, addrMode, noDict, noDictIssue, 1); + } } +} + + +int LZ4_compress_destSize(const char* src, char* dst, int* srcSizePtr, int targetDstSize) +{ +#if (LZ4_HEAPMODE) + LZ4_stream_t* ctx = (LZ4_stream_t*)ALLOC(sizeof(LZ4_stream_t)); /* malloc-calloc always properly aligned */ + if (ctx == NULL) return 0; +#else + LZ4_stream_t ctxBody; + LZ4_stream_t* ctx = &ctxBody; +#endif + + int result = LZ4_compress_destSize_extState(ctx, src, dst, srcSizePtr, targetDstSize); + +#if (LZ4_HEAPMODE) + FREEMEM(ctx); +#endif + return result; +} + + + +/*-****************************** +* Streaming functions +********************************/ + +LZ4_stream_t* LZ4_createStream(void) +{ + LZ4_stream_t* const lz4s = (LZ4_stream_t*)ALLOC(sizeof(LZ4_stream_t)); + LZ4_STATIC_ASSERT(LZ4_STREAMSIZE >= sizeof(LZ4_stream_t_internal)); /* A compilation error here means LZ4_STREAMSIZE is not large enough */ + DEBUGLOG(4, "LZ4_createStream %p", lz4s); + if (lz4s == NULL) return NULL; + LZ4_initStream(lz4s, sizeof(*lz4s)); + return lz4s; +} + +#ifndef _MSC_VER /* for some reason, Visual fails the aligment test on 32-bit x86 : + it reports an aligment of 8-bytes, + while actually aligning LZ4_stream_t on 4 bytes. */ +static size_t LZ4_stream_t_alignment(void) +{ + struct { char c; LZ4_stream_t t; } t_a; + return sizeof(t_a) - sizeof(t_a.t); +} +#endif + +LZ4_stream_t* LZ4_initStream (void* buffer, size_t size) +{ + DEBUGLOG(5, "LZ4_initStream"); + if (buffer == NULL) { return NULL; } + if (size < sizeof(LZ4_stream_t)) { return NULL; } +#ifndef _MSC_VER /* for some reason, Visual fails the aligment test on 32-bit x86 : + it reports an aligment of 8-bytes, + while actually aligning LZ4_stream_t on 4 bytes. */ + if (((size_t)buffer) & (LZ4_stream_t_alignment() - 1)) { return NULL; } /* alignment check */ +#endif + MEM_INIT(buffer, 0, sizeof(LZ4_stream_t)); + return (LZ4_stream_t*)buffer; +} + +/* resetStream is now deprecated, + * prefer initStream() which is more general */ +void LZ4_resetStream (LZ4_stream_t* LZ4_stream) +{ + DEBUGLOG(5, "LZ4_resetStream (ctx:%p)", LZ4_stream); + MEM_INIT(LZ4_stream, 0, sizeof(LZ4_stream_t)); +} + +void LZ4_resetStream_fast(LZ4_stream_t* ctx) { + LZ4_prepareTable(&(ctx->internal_donotuse), 0, byU32); +} + +int LZ4_freeStream (LZ4_stream_t* LZ4_stream) +{ + if (!LZ4_stream) return 0; /* support free on NULL */ + DEBUGLOG(5, "LZ4_freeStream %p", LZ4_stream); + FREEMEM(LZ4_stream); + return (0); +} + + +#define HASH_UNIT sizeof(reg_t) +int LZ4_loadDict (LZ4_stream_t* LZ4_dict, const char* dictionary, int dictSize) +{ + LZ4_stream_t_internal* dict = &LZ4_dict->internal_donotuse; + const tableType_t tableType = byU32; + const BYTE* p = (const BYTE*)dictionary; + const BYTE* const dictEnd = p + dictSize; + const BYTE* base; + + DEBUGLOG(4, "LZ4_loadDict (%i bytes from %p into %p)", dictSize, dictionary, LZ4_dict); + + /* It's necessary to reset the context, + * and not just continue it with prepareTable() + * to avoid any risk of generating overflowing matchIndex + * when compressing using this dictionary */ + LZ4_resetStream(LZ4_dict); + + /* We always increment the offset by 64 KB, since, if the dict is longer, + * we truncate it to the last 64k, and if it's shorter, we still want to + * advance by a whole window length so we can provide the guarantee that + * there are only valid offsets in the window, which allows an optimization + * in LZ4_compress_fast_continue() where it uses noDictIssue even when the + * dictionary isn't a full 64k. */ + dict->currentOffset += 64 KB; + + if (dictSize < (int)HASH_UNIT) { + return 0; + } + + if ((dictEnd - p) > 64 KB) p = dictEnd - 64 KB; + base = dictEnd - dict->currentOffset; + dict->dictionary = p; + dict->dictSize = (U32)(dictEnd - p); + dict->tableType = tableType; + + while (p <= dictEnd-HASH_UNIT) { + LZ4_putPosition(p, dict->hashTable, tableType, base); + p+=3; + } + + return (int)dict->dictSize; +} + +void LZ4_attach_dictionary(LZ4_stream_t* workingStream, const LZ4_stream_t* dictionaryStream) { + const LZ4_stream_t_internal* dictCtx = dictionaryStream == NULL ? NULL : + &(dictionaryStream->internal_donotuse); + + DEBUGLOG(4, "LZ4_attach_dictionary (%p, %p, size %u)", + workingStream, dictionaryStream, + dictCtx != NULL ? dictCtx->dictSize : 0); + + /* Calling LZ4_resetStream_fast() here makes sure that changes will not be + * erased by subsequent calls to LZ4_resetStream_fast() in case stream was + * marked as having dirty context, e.g. requiring full reset. + */ + LZ4_resetStream_fast(workingStream); + + if (dictCtx != NULL) { + /* If the current offset is zero, we will never look in the + * external dictionary context, since there is no value a table + * entry can take that indicate a miss. In that case, we need + * to bump the offset to something non-zero. + */ + if (workingStream->internal_donotuse.currentOffset == 0) { + workingStream->internal_donotuse.currentOffset = 64 KB; + } + + /* Don't actually attach an empty dictionary. + */ + if (dictCtx->dictSize == 0) { + dictCtx = NULL; + } + } + workingStream->internal_donotuse.dictCtx = dictCtx; +} + + +static void LZ4_renormDictT(LZ4_stream_t_internal* LZ4_dict, int nextSize) +{ + assert(nextSize >= 0); + if (LZ4_dict->currentOffset + (unsigned)nextSize > 0x80000000) { /* potential ptrdiff_t overflow (32-bits mode) */ + /* rescale hash table */ + U32 const delta = LZ4_dict->currentOffset - 64 KB; + const BYTE* dictEnd = LZ4_dict->dictionary + LZ4_dict->dictSize; + int i; + DEBUGLOG(4, "LZ4_renormDictT"); + for (i=0; ihashTable[i] < delta) LZ4_dict->hashTable[i]=0; + else LZ4_dict->hashTable[i] -= delta; + } + LZ4_dict->currentOffset = 64 KB; + if (LZ4_dict->dictSize > 64 KB) LZ4_dict->dictSize = 64 KB; + LZ4_dict->dictionary = dictEnd - LZ4_dict->dictSize; + } +} + + +int LZ4_compress_fast_continue (LZ4_stream_t* LZ4_stream, + const char* source, char* dest, + int inputSize, int maxOutputSize, + int acceleration) +{ + const tableType_t tableType = byU32; + LZ4_stream_t_internal* streamPtr = &LZ4_stream->internal_donotuse; + const BYTE* dictEnd = streamPtr->dictionary + streamPtr->dictSize; + + DEBUGLOG(5, "LZ4_compress_fast_continue (inputSize=%i)", inputSize); + + if (streamPtr->dirty) { return 0; } /* Uninitialized structure detected */ + LZ4_renormDictT(streamPtr, inputSize); /* avoid index overflow */ + if (acceleration < 1) acceleration = ACCELERATION_DEFAULT; + + /* invalidate tiny dictionaries */ + if ( (streamPtr->dictSize-1 < 4-1) /* intentional underflow */ + && (dictEnd != (const BYTE*)source) ) { + DEBUGLOG(5, "LZ4_compress_fast_continue: dictSize(%u) at addr:%p is too small", streamPtr->dictSize, streamPtr->dictionary); + streamPtr->dictSize = 0; + streamPtr->dictionary = (const BYTE*)source; + dictEnd = (const BYTE*)source; + } + + /* Check overlapping input/dictionary space */ + { const BYTE* sourceEnd = (const BYTE*) source + inputSize; + if ((sourceEnd > streamPtr->dictionary) && (sourceEnd < dictEnd)) { + streamPtr->dictSize = (U32)(dictEnd - sourceEnd); + if (streamPtr->dictSize > 64 KB) streamPtr->dictSize = 64 KB; + if (streamPtr->dictSize < 4) streamPtr->dictSize = 0; + streamPtr->dictionary = dictEnd - streamPtr->dictSize; + } + } + + /* prefix mode : source data follows dictionary */ + if (dictEnd == (const BYTE*)source) { + if ((streamPtr->dictSize < 64 KB) && (streamPtr->dictSize < streamPtr->currentOffset)) + return LZ4_compress_generic(streamPtr, source, dest, inputSize, NULL, maxOutputSize, limitedOutput, tableType, withPrefix64k, dictSmall, acceleration); + else + return LZ4_compress_generic(streamPtr, source, dest, inputSize, NULL, maxOutputSize, limitedOutput, tableType, withPrefix64k, noDictIssue, acceleration); + } + + /* external dictionary mode */ + { int result; + if (streamPtr->dictCtx) { + /* We depend here on the fact that dictCtx'es (produced by + * LZ4_loadDict) guarantee that their tables contain no references + * to offsets between dictCtx->currentOffset - 64 KB and + * dictCtx->currentOffset - dictCtx->dictSize. This makes it safe + * to use noDictIssue even when the dict isn't a full 64 KB. + */ + if (inputSize > 4 KB) { + /* For compressing large blobs, it is faster to pay the setup + * cost to copy the dictionary's tables into the active context, + * so that the compression loop is only looking into one table. + */ + memcpy(streamPtr, streamPtr->dictCtx, sizeof(LZ4_stream_t)); + result = LZ4_compress_generic(streamPtr, source, dest, inputSize, NULL, maxOutputSize, limitedOutput, tableType, usingExtDict, noDictIssue, acceleration); + } else { + result = LZ4_compress_generic(streamPtr, source, dest, inputSize, NULL, maxOutputSize, limitedOutput, tableType, usingDictCtx, noDictIssue, acceleration); + } + } else { + if ((streamPtr->dictSize < 64 KB) && (streamPtr->dictSize < streamPtr->currentOffset)) { + result = LZ4_compress_generic(streamPtr, source, dest, inputSize, NULL, maxOutputSize, limitedOutput, tableType, usingExtDict, dictSmall, acceleration); + } else { + result = LZ4_compress_generic(streamPtr, source, dest, inputSize, NULL, maxOutputSize, limitedOutput, tableType, usingExtDict, noDictIssue, acceleration); + } + } + streamPtr->dictionary = (const BYTE*)source; + streamPtr->dictSize = (U32)inputSize; + return result; + } +} + + +/* Hidden debug function, to force-test external dictionary mode */ +int LZ4_compress_forceExtDict (LZ4_stream_t* LZ4_dict, const char* source, char* dest, int srcSize) +{ + LZ4_stream_t_internal* streamPtr = &LZ4_dict->internal_donotuse; + int result; + + LZ4_renormDictT(streamPtr, srcSize); + + if ((streamPtr->dictSize < 64 KB) && (streamPtr->dictSize < streamPtr->currentOffset)) { + result = LZ4_compress_generic(streamPtr, source, dest, srcSize, NULL, 0, notLimited, byU32, usingExtDict, dictSmall, 1); + } else { + result = LZ4_compress_generic(streamPtr, source, dest, srcSize, NULL, 0, notLimited, byU32, usingExtDict, noDictIssue, 1); + } + + streamPtr->dictionary = (const BYTE*)source; + streamPtr->dictSize = (U32)srcSize; + + return result; +} + + +/*! LZ4_saveDict() : + * If previously compressed data block is not guaranteed to remain available at its memory location, + * save it into a safer place (char* safeBuffer). + * Note : you don't need to call LZ4_loadDict() afterwards, + * dictionary is immediately usable, you can therefore call LZ4_compress_fast_continue(). + * Return : saved dictionary size in bytes (necessarily <= dictSize), or 0 if error. + */ +int LZ4_saveDict (LZ4_stream_t* LZ4_dict, char* safeBuffer, int dictSize) +{ + LZ4_stream_t_internal* const dict = &LZ4_dict->internal_donotuse; + const BYTE* const previousDictEnd = dict->dictionary + dict->dictSize; + + if ((U32)dictSize > 64 KB) { dictSize = 64 KB; } /* useless to define a dictionary > 64 KB */ + if ((U32)dictSize > dict->dictSize) { dictSize = (int)dict->dictSize; } + + memmove(safeBuffer, previousDictEnd - dictSize, dictSize); + + dict->dictionary = (const BYTE*)safeBuffer; + dict->dictSize = (U32)dictSize; + + return dictSize; +} + + + +/*-******************************* + * Decompression functions + ********************************/ + +typedef enum { endOnOutputSize = 0, endOnInputSize = 1 } endCondition_directive; +typedef enum { decode_full_block = 0, partial_decode = 1 } earlyEnd_directive; + +#undef MIN +#define MIN(a,b) ( (a) < (b) ? (a) : (b) ) + +/* Read the variable-length literal or match length. + * + * ip - pointer to use as input. + * lencheck - end ip. Return an error if ip advances >= lencheck. + * loop_check - check ip >= lencheck in body of loop. Returns loop_error if so. + * initial_check - check ip >= lencheck before start of loop. Returns initial_error if so. + * error (output) - error code. Should be set to 0 before call. + */ +typedef enum { loop_error = -2, initial_error = -1, ok = 0 } variable_length_error; +LZ4_FORCE_INLINE unsigned +read_variable_length(const BYTE**ip, const BYTE* lencheck, int loop_check, int initial_check, variable_length_error* error) +{ + U32 length = 0; + U32 s; + if (initial_check && unlikely((*ip) >= lencheck)) { /* overflow detection */ + *error = initial_error; + return length; + } + do { + s = **ip; + (*ip)++; + length += s; + if (loop_check && unlikely((*ip) >= lencheck)) { /* overflow detection */ + *error = loop_error; + return length; + } + } while (s==255); + + return length; +} + +/*! LZ4_decompress_generic() : + * This generic decompression function covers all use cases. + * It shall be instantiated several times, using different sets of directives. + * Note that it is important for performance that this function really get inlined, + * in order to remove useless branches during compilation optimization. + */ +LZ4_FORCE_INLINE int +LZ4_decompress_generic( + const char* const src, + char* const dst, + int srcSize, + int outputSize, /* If endOnInput==endOnInputSize, this value is `dstCapacity` */ + + endCondition_directive endOnInput, /* endOnOutputSize, endOnInputSize */ + earlyEnd_directive partialDecoding, /* full, partial */ + dict_directive dict, /* noDict, withPrefix64k, usingExtDict */ + const BYTE* const lowPrefix, /* always <= dst, == dst when no prefix */ + const BYTE* const dictStart, /* only if dict==usingExtDict */ + const size_t dictSize /* note : = 0 if noDict */ + ) +{ + if (src == NULL) { return -1; } + + { const BYTE* ip = (const BYTE*) src; + const BYTE* const iend = ip + srcSize; + + BYTE* op = (BYTE*) dst; + BYTE* const oend = op + outputSize; + BYTE* cpy; + + const BYTE* const dictEnd = (dictStart == NULL) ? NULL : dictStart + dictSize; + + const int safeDecode = (endOnInput==endOnInputSize); + const int checkOffset = ((safeDecode) && (dictSize < (int)(64 KB))); + + + /* Set up the "end" pointers for the shortcut. */ + const BYTE* const shortiend = iend - (endOnInput ? 14 : 8) /*maxLL*/ - 2 /*offset*/; + const BYTE* const shortoend = oend - (endOnInput ? 14 : 8) /*maxLL*/ - 18 /*maxML*/; + + const BYTE* match; + size_t offset; + unsigned token; + size_t length; + + + DEBUGLOG(5, "LZ4_decompress_generic (srcSize:%i, dstSize:%i)", srcSize, outputSize); + + /* Special cases */ + assert(lowPrefix <= op); + if ((endOnInput) && (unlikely(outputSize==0))) { + /* Empty output buffer */ + if (partialDecoding) return 0; + return ((srcSize==1) && (*ip==0)) ? 0 : -1; + } + if ((!endOnInput) && (unlikely(outputSize==0))) { return (*ip==0 ? 1 : -1); } + if ((endOnInput) && unlikely(srcSize==0)) { return -1; } + + /* Currently the fast loop shows a regression on qualcomm arm chips. */ +#if LZ4_FAST_DEC_LOOP + if ((oend - op) < FASTLOOP_SAFE_DISTANCE) { + DEBUGLOG(6, "skip fast decode loop"); + goto safe_decode; + } + + /* Fast loop : decode sequences as long as output < iend-FASTLOOP_SAFE_DISTANCE */ + while (1) { + /* Main fastloop assertion: We can always wildcopy FASTLOOP_SAFE_DISTANCE */ + assert(oend - op >= FASTLOOP_SAFE_DISTANCE); + if (endOnInput) { assert(ip < iend); } + token = *ip++; + length = token >> ML_BITS; /* literal length */ + + assert(!endOnInput || ip <= iend); /* ip < iend before the increment */ + + /* decode literal length */ + if (length == RUN_MASK) { + variable_length_error error = ok; + length += read_variable_length(&ip, iend-RUN_MASK, endOnInput, endOnInput, &error); + if (error == initial_error) { goto _output_error; } + if ((safeDecode) && unlikely((uptrval)(op)+length<(uptrval)(op))) { goto _output_error; } /* overflow detection */ + if ((safeDecode) && unlikely((uptrval)(ip)+length<(uptrval)(ip))) { goto _output_error; } /* overflow detection */ + + /* copy literals */ + cpy = op+length; + LZ4_STATIC_ASSERT(MFLIMIT >= WILDCOPYLENGTH); + if (endOnInput) { /* LZ4_decompress_safe() */ + if ((cpy>oend-32) || (ip+length>iend-32)) { goto safe_literal_copy; } + LZ4_wildCopy32(op, ip, cpy); + } else { /* LZ4_decompress_fast() */ + if (cpy>oend-8) { goto safe_literal_copy; } + LZ4_wildCopy8(op, ip, cpy); /* LZ4_decompress_fast() cannot copy more than 8 bytes at a time : + * it doesn't know input length, and only relies on end-of-block properties */ + } + ip += length; op = cpy; + } else { + cpy = op+length; + if (endOnInput) { /* LZ4_decompress_safe() */ + DEBUGLOG(7, "copy %u bytes in a 16-bytes stripe", (unsigned)length); + /* We don't need to check oend, since we check it once for each loop below */ + if (ip > iend-(16 + 1/*max lit + offset + nextToken*/)) { goto safe_literal_copy; } + /* Literals can only be 14, but hope compilers optimize if we copy by a register size */ + memcpy(op, ip, 16); + } else { /* LZ4_decompress_fast() */ + /* LZ4_decompress_fast() cannot copy more than 8 bytes at a time : + * it doesn't know input length, and relies on end-of-block properties */ + memcpy(op, ip, 8); + if (length > 8) { memcpy(op+8, ip+8, 8); } + } + ip += length; op = cpy; + } + + /* get offset */ + offset = LZ4_readLE16(ip); ip+=2; + match = op - offset; + assert(match <= op); + + /* get matchlength */ + length = token & ML_MASK; + + if (length == ML_MASK) { + variable_length_error error = ok; + if ((checkOffset) && (unlikely(match + dictSize < lowPrefix))) { goto _output_error; } /* Error : offset outside buffers */ + length += read_variable_length(&ip, iend - LASTLITERALS + 1, endOnInput, 0, &error); + if (error != ok) { goto _output_error; } + if ((safeDecode) && unlikely((uptrval)(op)+length<(uptrval)op)) { goto _output_error; } /* overflow detection */ + length += MINMATCH; + if (op + length >= oend - FASTLOOP_SAFE_DISTANCE) { + goto safe_match_copy; + } + } else { + length += MINMATCH; + if (op + length >= oend - FASTLOOP_SAFE_DISTANCE) { + goto safe_match_copy; + } + + /* Fastpath check: Avoids a branch in LZ4_wildCopy32 if true */ + if ((dict == withPrefix64k) || (match >= lowPrefix)) { + if (offset >= 8) { + assert(match >= lowPrefix); + assert(match <= op); + assert(op + 18 <= oend); + + memcpy(op, match, 8); + memcpy(op+8, match+8, 8); + memcpy(op+16, match+16, 2); + op += length; + continue; + } } } + + if ((checkOffset) && (unlikely(match + dictSize < lowPrefix))) { goto _output_error; } /* Error : offset outside buffers */ + /* match starting within external dictionary */ + if ((dict==usingExtDict) && (match < lowPrefix)) { + if (unlikely(op+length > oend-LASTLITERALS)) { + if (partialDecoding) { + length = MIN(length, (size_t)(oend-op)); /* reach end of buffer */ + } else { + goto _output_error; /* end-of-block condition violated */ + } } + + if (length <= (size_t)(lowPrefix-match)) { + /* match fits entirely within external dictionary : just copy */ + memmove(op, dictEnd - (lowPrefix-match), length); + op += length; + } else { + /* match stretches into both external dictionary and current block */ + size_t const copySize = (size_t)(lowPrefix - match); + size_t const restSize = length - copySize; + memcpy(op, dictEnd - copySize, copySize); + op += copySize; + if (restSize > (size_t)(op - lowPrefix)) { /* overlap copy */ + BYTE* const endOfMatch = op + restSize; + const BYTE* copyFrom = lowPrefix; + while (op < endOfMatch) { *op++ = *copyFrom++; } + } else { + memcpy(op, lowPrefix, restSize); + op += restSize; + } } + continue; + } + + /* copy match within block */ + cpy = op + length; + + assert((op <= oend) && (oend-op >= 32)); + if (unlikely(offset<16)) { + LZ4_memcpy_using_offset(op, match, cpy, offset); + } else { + LZ4_wildCopy32(op, match, cpy); + } + + op = cpy; /* wildcopy correction */ + } + safe_decode: +#endif + + /* Main Loop : decode remaining sequences where output < FASTLOOP_SAFE_DISTANCE */ + while (1) { + token = *ip++; + length = token >> ML_BITS; /* literal length */ + + assert(!endOnInput || ip <= iend); /* ip < iend before the increment */ + + /* A two-stage shortcut for the most common case: + * 1) If the literal length is 0..14, and there is enough space, + * enter the shortcut and copy 16 bytes on behalf of the literals + * (in the fast mode, only 8 bytes can be safely copied this way). + * 2) Further if the match length is 4..18, copy 18 bytes in a similar + * manner; but we ensure that there's enough space in the output for + * those 18 bytes earlier, upon entering the shortcut (in other words, + * there is a combined check for both stages). + */ + if ( (endOnInput ? length != RUN_MASK : length <= 8) + /* strictly "less than" on input, to re-enter the loop with at least one byte */ + && likely((endOnInput ? ip < shortiend : 1) & (op <= shortoend)) ) { + /* Copy the literals */ + memcpy(op, ip, endOnInput ? 16 : 8); + op += length; ip += length; + + /* The second stage: prepare for match copying, decode full info. + * If it doesn't work out, the info won't be wasted. */ + length = token & ML_MASK; /* match length */ + offset = LZ4_readLE16(ip); ip += 2; + match = op - offset; + assert(match <= op); /* check overflow */ + + /* Do not deal with overlapping matches. */ + if ( (length != ML_MASK) + && (offset >= 8) + && (dict==withPrefix64k || match >= lowPrefix) ) { + /* Copy the match. */ + memcpy(op + 0, match + 0, 8); + memcpy(op + 8, match + 8, 8); + memcpy(op +16, match +16, 2); + op += length + MINMATCH; + /* Both stages worked, load the next token. */ + continue; + } + + /* The second stage didn't work out, but the info is ready. + * Propel it right to the point of match copying. */ + goto _copy_match; + } + + /* decode literal length */ + if (length == RUN_MASK) { + variable_length_error error = ok; + length += read_variable_length(&ip, iend-RUN_MASK, endOnInput, endOnInput, &error); + if (error == initial_error) { goto _output_error; } + if ((safeDecode) && unlikely((uptrval)(op)+length<(uptrval)(op))) { goto _output_error; } /* overflow detection */ + if ((safeDecode) && unlikely((uptrval)(ip)+length<(uptrval)(ip))) { goto _output_error; } /* overflow detection */ + } + + /* copy literals */ + cpy = op+length; +#if LZ4_FAST_DEC_LOOP + safe_literal_copy: +#endif + LZ4_STATIC_ASSERT(MFLIMIT >= WILDCOPYLENGTH); + if ( ((endOnInput) && ((cpy>oend-MFLIMIT) || (ip+length>iend-(2+1+LASTLITERALS))) ) + || ((!endOnInput) && (cpy>oend-WILDCOPYLENGTH)) ) + { + /* We've either hit the input parsing restriction or the output parsing restriction. + * If we've hit the input parsing condition then this must be the last sequence. + * If we've hit the output parsing condition then we are either using partialDecoding + * or we've hit the output parsing condition. + */ + if (partialDecoding) { + /* Since we are partial decoding we may be in this block because of the output parsing + * restriction, which is not valid since the output buffer is allowed to be undersized. + */ + assert(endOnInput); + /* If we're in this block because of the input parsing condition, then we must be on the + * last sequence (or invalid), so we must check that we exactly consume the input. + */ + if ((ip+length>iend-(2+1+LASTLITERALS)) && (ip+length != iend)) { goto _output_error; } + assert(ip+length <= iend); + /* We are finishing in the middle of a literals segment. + * Break after the copy. + */ + if (cpy > oend) { + cpy = oend; + assert(op<=oend); + length = (size_t)(oend-op); + } + assert(ip+length <= iend); + } else { + /* We must be on the last sequence because of the parsing limitations so check + * that we exactly regenerate the original size (must be exact when !endOnInput). + */ + if ((!endOnInput) && (cpy != oend)) { goto _output_error; } + /* We must be on the last sequence (or invalid) because of the parsing limitations + * so check that we exactly consume the input and don't overrun the output buffer. + */ + if ((endOnInput) && ((ip+length != iend) || (cpy > oend))) { goto _output_error; } + } + memmove(op, ip, length); /* supports overlapping memory regions, which only matters for in-place decompression scenarios */ + ip += length; + op += length; + /* Necessarily EOF when !partialDecoding. When partialDecoding + * it is EOF if we've either filled the output buffer or hit + * the input parsing restriction. + */ + if (!partialDecoding || (cpy == oend) || (ip == iend)) { + break; + } + } else { + LZ4_wildCopy8(op, ip, cpy); /* may overwrite up to WILDCOPYLENGTH beyond cpy */ + ip += length; op = cpy; + } + + /* get offset */ + offset = LZ4_readLE16(ip); ip+=2; + match = op - offset; + + /* get matchlength */ + length = token & ML_MASK; + + _copy_match: + if (length == ML_MASK) { + variable_length_error error = ok; + length += read_variable_length(&ip, iend - LASTLITERALS + 1, endOnInput, 0, &error); + if (error != ok) goto _output_error; + if ((safeDecode) && unlikely((uptrval)(op)+length<(uptrval)op)) goto _output_error; /* overflow detection */ + } + length += MINMATCH; + +#if LZ4_FAST_DEC_LOOP + safe_match_copy: +#endif + if ((checkOffset) && (unlikely(match + dictSize < lowPrefix))) goto _output_error; /* Error : offset outside buffers */ + /* match starting within external dictionary */ + if ((dict==usingExtDict) && (match < lowPrefix)) { + if (unlikely(op+length > oend-LASTLITERALS)) { + if (partialDecoding) length = MIN(length, (size_t)(oend-op)); + else goto _output_error; /* doesn't respect parsing restriction */ + } + + if (length <= (size_t)(lowPrefix-match)) { + /* match fits entirely within external dictionary : just copy */ + memmove(op, dictEnd - (lowPrefix-match), length); + op += length; + } else { + /* match stretches into both external dictionary and current block */ + size_t const copySize = (size_t)(lowPrefix - match); + size_t const restSize = length - copySize; + memcpy(op, dictEnd - copySize, copySize); + op += copySize; + if (restSize > (size_t)(op - lowPrefix)) { /* overlap copy */ + BYTE* const endOfMatch = op + restSize; + const BYTE* copyFrom = lowPrefix; + while (op < endOfMatch) *op++ = *copyFrom++; + } else { + memcpy(op, lowPrefix, restSize); + op += restSize; + } } + continue; + } + assert(match >= lowPrefix); + + /* copy match within block */ + cpy = op + length; + + /* partialDecoding : may end anywhere within the block */ + assert(op<=oend); + if (partialDecoding && (cpy > oend-MATCH_SAFEGUARD_DISTANCE)) { + size_t const mlen = MIN(length, (size_t)(oend-op)); + const BYTE* const matchEnd = match + mlen; + BYTE* const copyEnd = op + mlen; + if (matchEnd > op) { /* overlap copy */ + while (op < copyEnd) { *op++ = *match++; } + } else { + memcpy(op, match, mlen); + } + op = copyEnd; + if (op == oend) { break; } + continue; + } + + if (unlikely(offset<8)) { + LZ4_write32(op, 0); /* silence msan warning when offset==0 */ + op[0] = match[0]; + op[1] = match[1]; + op[2] = match[2]; + op[3] = match[3]; + match += inc32table[offset]; + memcpy(op+4, match, 4); + match -= dec64table[offset]; + } else { + memcpy(op, match, 8); + match += 8; + } + op += 8; + + if (unlikely(cpy > oend-MATCH_SAFEGUARD_DISTANCE)) { + BYTE* const oCopyLimit = oend - (WILDCOPYLENGTH-1); + if (cpy > oend-LASTLITERALS) { goto _output_error; } /* Error : last LASTLITERALS bytes must be literals (uncompressed) */ + if (op < oCopyLimit) { + LZ4_wildCopy8(op, match, oCopyLimit); + match += oCopyLimit - op; + op = oCopyLimit; + } + while (op < cpy) { *op++ = *match++; } + } else { + memcpy(op, match, 8); + if (length > 16) { LZ4_wildCopy8(op+8, match+8, cpy); } + } + op = cpy; /* wildcopy correction */ + } + + /* end of decoding */ + if (endOnInput) { + return (int) (((char*)op)-dst); /* Nb of output bytes decoded */ + } else { + return (int) (((const char*)ip)-src); /* Nb of input bytes read */ + } + + /* Overflow error detected */ + _output_error: + return (int) (-(((const char*)ip)-src))-1; + } +} + + +/*===== Instantiate the API decoding functions. =====*/ + +LZ4_FORCE_O2_GCC_PPC64LE +int LZ4_decompress_safe(const char* source, char* dest, int compressedSize, int maxDecompressedSize) +{ + return LZ4_decompress_generic(source, dest, compressedSize, maxDecompressedSize, + endOnInputSize, decode_full_block, noDict, + (BYTE*)dest, NULL, 0); +} + +LZ4_FORCE_O2_GCC_PPC64LE +int LZ4_decompress_safe_partial(const char* src, char* dst, int compressedSize, int targetOutputSize, int dstCapacity) +{ + dstCapacity = MIN(targetOutputSize, dstCapacity); + return LZ4_decompress_generic(src, dst, compressedSize, dstCapacity, + endOnInputSize, partial_decode, + noDict, (BYTE*)dst, NULL, 0); +} + +LZ4_FORCE_O2_GCC_PPC64LE +int LZ4_decompress_fast(const char* source, char* dest, int originalSize) +{ + return LZ4_decompress_generic(source, dest, 0, originalSize, + endOnOutputSize, decode_full_block, withPrefix64k, + (BYTE*)dest - 64 KB, NULL, 0); +} + +/*===== Instantiate a few more decoding cases, used more than once. =====*/ + +LZ4_FORCE_O2_GCC_PPC64LE /* Exported, an obsolete API function. */ +int LZ4_decompress_safe_withPrefix64k(const char* source, char* dest, int compressedSize, int maxOutputSize) +{ + return LZ4_decompress_generic(source, dest, compressedSize, maxOutputSize, + endOnInputSize, decode_full_block, withPrefix64k, + (BYTE*)dest - 64 KB, NULL, 0); +} + +/* Another obsolete API function, paired with the previous one. */ +int LZ4_decompress_fast_withPrefix64k(const char* source, char* dest, int originalSize) +{ + /* LZ4_decompress_fast doesn't validate match offsets, + * and thus serves well with any prefixed dictionary. */ + return LZ4_decompress_fast(source, dest, originalSize); +} + +LZ4_FORCE_O2_GCC_PPC64LE +static int LZ4_decompress_safe_withSmallPrefix(const char* source, char* dest, int compressedSize, int maxOutputSize, + size_t prefixSize) +{ + return LZ4_decompress_generic(source, dest, compressedSize, maxOutputSize, + endOnInputSize, decode_full_block, noDict, + (BYTE*)dest-prefixSize, NULL, 0); +} + +LZ4_FORCE_O2_GCC_PPC64LE +int LZ4_decompress_safe_forceExtDict(const char* source, char* dest, + int compressedSize, int maxOutputSize, + const void* dictStart, size_t dictSize) +{ + return LZ4_decompress_generic(source, dest, compressedSize, maxOutputSize, + endOnInputSize, decode_full_block, usingExtDict, + (BYTE*)dest, (const BYTE*)dictStart, dictSize); +} + +LZ4_FORCE_O2_GCC_PPC64LE +static int LZ4_decompress_fast_extDict(const char* source, char* dest, int originalSize, + const void* dictStart, size_t dictSize) +{ + return LZ4_decompress_generic(source, dest, 0, originalSize, + endOnOutputSize, decode_full_block, usingExtDict, + (BYTE*)dest, (const BYTE*)dictStart, dictSize); +} + +/* The "double dictionary" mode, for use with e.g. ring buffers: the first part + * of the dictionary is passed as prefix, and the second via dictStart + dictSize. + * These routines are used only once, in LZ4_decompress_*_continue(). + */ +LZ4_FORCE_INLINE +int LZ4_decompress_safe_doubleDict(const char* source, char* dest, int compressedSize, int maxOutputSize, + size_t prefixSize, const void* dictStart, size_t dictSize) +{ + return LZ4_decompress_generic(source, dest, compressedSize, maxOutputSize, + endOnInputSize, decode_full_block, usingExtDict, + (BYTE*)dest-prefixSize, (const BYTE*)dictStart, dictSize); +} + +LZ4_FORCE_INLINE +int LZ4_decompress_fast_doubleDict(const char* source, char* dest, int originalSize, + size_t prefixSize, const void* dictStart, size_t dictSize) +{ + return LZ4_decompress_generic(source, dest, 0, originalSize, + endOnOutputSize, decode_full_block, usingExtDict, + (BYTE*)dest-prefixSize, (const BYTE*)dictStart, dictSize); +} + +/*===== streaming decompression functions =====*/ + +LZ4_streamDecode_t* LZ4_createStreamDecode(void) +{ + LZ4_streamDecode_t* lz4s = (LZ4_streamDecode_t*) ALLOC_AND_ZERO(sizeof(LZ4_streamDecode_t)); + LZ4_STATIC_ASSERT(LZ4_STREAMDECODESIZE >= sizeof(LZ4_streamDecode_t_internal)); /* A compilation error here means LZ4_STREAMDECODESIZE is not large enough */ + return lz4s; +} + +int LZ4_freeStreamDecode (LZ4_streamDecode_t* LZ4_stream) +{ + if (LZ4_stream == NULL) { return 0; } /* support free on NULL */ + FREEMEM(LZ4_stream); + return 0; +} + +/*! LZ4_setStreamDecode() : + * Use this function to instruct where to find the dictionary. + * This function is not necessary if previous data is still available where it was decoded. + * Loading a size of 0 is allowed (same effect as no dictionary). + * @return : 1 if OK, 0 if error + */ +int LZ4_setStreamDecode (LZ4_streamDecode_t* LZ4_streamDecode, const char* dictionary, int dictSize) +{ + LZ4_streamDecode_t_internal* lz4sd = &LZ4_streamDecode->internal_donotuse; + lz4sd->prefixSize = (size_t) dictSize; + lz4sd->prefixEnd = (const BYTE*) dictionary + dictSize; + lz4sd->externalDict = NULL; + lz4sd->extDictSize = 0; + return 1; +} + +/*! LZ4_decoderRingBufferSize() : + * when setting a ring buffer for streaming decompression (optional scenario), + * provides the minimum size of this ring buffer + * to be compatible with any source respecting maxBlockSize condition. + * Note : in a ring buffer scenario, + * blocks are presumed decompressed next to each other. + * When not enough space remains for next block (remainingSize < maxBlockSize), + * decoding resumes from beginning of ring buffer. + * @return : minimum ring buffer size, + * or 0 if there is an error (invalid maxBlockSize). + */ +int LZ4_decoderRingBufferSize(int maxBlockSize) +{ + if (maxBlockSize < 0) return 0; + if (maxBlockSize > LZ4_MAX_INPUT_SIZE) return 0; + if (maxBlockSize < 16) maxBlockSize = 16; + return LZ4_DECODER_RING_BUFFER_SIZE(maxBlockSize); +} + +/* +*_continue() : + These decoding functions allow decompression of multiple blocks in "streaming" mode. + Previously decoded blocks must still be available at the memory position where they were decoded. + If it's not possible, save the relevant part of decoded data into a safe buffer, + and indicate where it stands using LZ4_setStreamDecode() +*/ +LZ4_FORCE_O2_GCC_PPC64LE +int LZ4_decompress_safe_continue (LZ4_streamDecode_t* LZ4_streamDecode, const char* source, char* dest, int compressedSize, int maxOutputSize) +{ + LZ4_streamDecode_t_internal* lz4sd = &LZ4_streamDecode->internal_donotuse; + int result; + + if (lz4sd->prefixSize == 0) { + /* The first call, no dictionary yet. */ + assert(lz4sd->extDictSize == 0); + result = LZ4_decompress_safe(source, dest, compressedSize, maxOutputSize); + if (result <= 0) return result; + lz4sd->prefixSize = (size_t)result; + lz4sd->prefixEnd = (BYTE*)dest + result; + } else if (lz4sd->prefixEnd == (BYTE*)dest) { + /* They're rolling the current segment. */ + if (lz4sd->prefixSize >= 64 KB - 1) + result = LZ4_decompress_safe_withPrefix64k(source, dest, compressedSize, maxOutputSize); + else if (lz4sd->extDictSize == 0) + result = LZ4_decompress_safe_withSmallPrefix(source, dest, compressedSize, maxOutputSize, + lz4sd->prefixSize); + else + result = LZ4_decompress_safe_doubleDict(source, dest, compressedSize, maxOutputSize, + lz4sd->prefixSize, lz4sd->externalDict, lz4sd->extDictSize); + if (result <= 0) return result; + lz4sd->prefixSize += (size_t)result; + lz4sd->prefixEnd += result; + } else { + /* The buffer wraps around, or they're switching to another buffer. */ + lz4sd->extDictSize = lz4sd->prefixSize; + lz4sd->externalDict = lz4sd->prefixEnd - lz4sd->extDictSize; + result = LZ4_decompress_safe_forceExtDict(source, dest, compressedSize, maxOutputSize, + lz4sd->externalDict, lz4sd->extDictSize); + if (result <= 0) return result; + lz4sd->prefixSize = (size_t)result; + lz4sd->prefixEnd = (BYTE*)dest + result; + } + + return result; +} + +LZ4_FORCE_O2_GCC_PPC64LE +int LZ4_decompress_fast_continue (LZ4_streamDecode_t* LZ4_streamDecode, const char* source, char* dest, int originalSize) +{ + LZ4_streamDecode_t_internal* lz4sd = &LZ4_streamDecode->internal_donotuse; + int result; + assert(originalSize >= 0); + + if (lz4sd->prefixSize == 0) { + assert(lz4sd->extDictSize == 0); + result = LZ4_decompress_fast(source, dest, originalSize); + if (result <= 0) return result; + lz4sd->prefixSize = (size_t)originalSize; + lz4sd->prefixEnd = (BYTE*)dest + originalSize; + } else if (lz4sd->prefixEnd == (BYTE*)dest) { + if (lz4sd->prefixSize >= 64 KB - 1 || lz4sd->extDictSize == 0) + result = LZ4_decompress_fast(source, dest, originalSize); + else + result = LZ4_decompress_fast_doubleDict(source, dest, originalSize, + lz4sd->prefixSize, lz4sd->externalDict, lz4sd->extDictSize); + if (result <= 0) return result; + lz4sd->prefixSize += (size_t)originalSize; + lz4sd->prefixEnd += originalSize; + } else { + lz4sd->extDictSize = lz4sd->prefixSize; + lz4sd->externalDict = lz4sd->prefixEnd - lz4sd->extDictSize; + result = LZ4_decompress_fast_extDict(source, dest, originalSize, + lz4sd->externalDict, lz4sd->extDictSize); + if (result <= 0) return result; + lz4sd->prefixSize = (size_t)originalSize; + lz4sd->prefixEnd = (BYTE*)dest + originalSize; + } + + return result; +} + + +/* +Advanced decoding functions : +*_usingDict() : + These decoding functions work the same as "_continue" ones, + the dictionary must be explicitly provided within parameters +*/ + +int LZ4_decompress_safe_usingDict(const char* source, char* dest, int compressedSize, int maxOutputSize, const char* dictStart, int dictSize) +{ + if (dictSize==0) + return LZ4_decompress_safe(source, dest, compressedSize, maxOutputSize); + if (dictStart+dictSize == dest) { + if (dictSize >= 64 KB - 1) { + return LZ4_decompress_safe_withPrefix64k(source, dest, compressedSize, maxOutputSize); + } + assert(dictSize >= 0); + return LZ4_decompress_safe_withSmallPrefix(source, dest, compressedSize, maxOutputSize, (size_t)dictSize); + } + assert(dictSize >= 0); + return LZ4_decompress_safe_forceExtDict(source, dest, compressedSize, maxOutputSize, dictStart, (size_t)dictSize); +} + +int LZ4_decompress_fast_usingDict(const char* source, char* dest, int originalSize, const char* dictStart, int dictSize) +{ + if (dictSize==0 || dictStart+dictSize == dest) + return LZ4_decompress_fast(source, dest, originalSize); + assert(dictSize >= 0); + return LZ4_decompress_fast_extDict(source, dest, originalSize, dictStart, (size_t)dictSize); +} + + +/*=************************************************* +* Obsolete Functions +***************************************************/ +/* obsolete compression functions */ +int LZ4_compress_limitedOutput(const char* source, char* dest, int inputSize, int maxOutputSize) +{ + return LZ4_compress_default(source, dest, inputSize, maxOutputSize); +} +int LZ4_compress(const char* src, char* dest, int srcSize) +{ + return LZ4_compress_default(src, dest, srcSize, LZ4_compressBound(srcSize)); +} +int LZ4_compress_limitedOutput_withState (void* state, const char* src, char* dst, int srcSize, int dstSize) +{ + return LZ4_compress_fast_extState(state, src, dst, srcSize, dstSize, 1); +} +int LZ4_compress_withState (void* state, const char* src, char* dst, int srcSize) +{ + return LZ4_compress_fast_extState(state, src, dst, srcSize, LZ4_compressBound(srcSize), 1); +} +int LZ4_compress_limitedOutput_continue (LZ4_stream_t* LZ4_stream, const char* src, char* dst, int srcSize, int dstCapacity) +{ + return LZ4_compress_fast_continue(LZ4_stream, src, dst, srcSize, dstCapacity, 1); +} +int LZ4_compress_continue (LZ4_stream_t* LZ4_stream, const char* source, char* dest, int inputSize) +{ + return LZ4_compress_fast_continue(LZ4_stream, source, dest, inputSize, LZ4_compressBound(inputSize), 1); +} + +/* +These decompression functions are deprecated and should no longer be used. +They are only provided here for compatibility with older user programs. +- LZ4_uncompress is totally equivalent to LZ4_decompress_fast +- LZ4_uncompress_unknownOutputSize is totally equivalent to LZ4_decompress_safe +*/ +int LZ4_uncompress (const char* source, char* dest, int outputSize) +{ + return LZ4_decompress_fast(source, dest, outputSize); +} +int LZ4_uncompress_unknownOutputSize (const char* source, char* dest, int isize, int maxOutputSize) +{ + return LZ4_decompress_safe(source, dest, isize, maxOutputSize); +} + +/* Obsolete Streaming functions */ + +int LZ4_sizeofStreamState() { return LZ4_STREAMSIZE; } + +int LZ4_resetStreamState(void* state, char* inputBuffer) +{ + (void)inputBuffer; + LZ4_resetStream((LZ4_stream_t*)state); + return 0; +} + +void* LZ4_create (char* inputBuffer) +{ + (void)inputBuffer; + return LZ4_createStream(); +} + +char* LZ4_slideInputBuffer (void* state) +{ + /* avoid const char * -> char * conversion warning */ + return (char *)(uptrval)((LZ4_stream_t*)state)->internal_donotuse.dictionary; +} + +#endif /* LZ4_COMMONDEFS_ONLY */ \ No newline at end of file diff --git a/src/dataset_creation/dataset_manipulation/bytes_to_file.cpp b/src/dataset_creation/dataset_manipulation/bytes_to_file.cpp new file mode 100644 index 0000000000000000000000000000000000000000..691e9f98800ae9326f084505eb8bf43ad198ed7b --- /dev/null +++ b/src/dataset_creation/dataset_manipulation/bytes_to_file.cpp @@ -0,0 +1,76 @@ +#include +#include +#include +#include +#include "../../../libraries/protobuf/build/midi.pb.h" +#include "../../../include/dataset_creation/dataset_manipulation/bytes_to_file.h" +#include "../../../include/dataset_creation/compression/lz4.h" +#include "../../../include/dataset_creation/dataset_manipulation/bytes_to_file.h" + +namespace dataset_manipulation { + + BytesToFile::BytesToFile(std::string user_filepath_) { + filepath_ = user_filepath_; + header_filepath_ = user_filepath_ + ".header"; + flush_count_ = 0; + can_write = false; + } + + void BytesToFile::enableWrite() { + if (can_write) { return; } + // check that the current file is empty unless force flag is present ? + file_stream_.open(filepath_, std::ios::out | std::ios::binary); + can_write = true; + } + + void BytesToFile::appendBytesToFileStream(std::string& bytes_as_string, size_t split_id) { + //file_stream_.open(filepath_, std::ios::out | std::ios::binary); + enableWrite(); + + //Start compression ============================== + size_t stream_position_start = file_stream_.tellp(); + size_t source_size = sizeof(char) * bytes_as_string.size(); + size_t destination_capacity = LZ4_compressBound(source_size); + char* destination = new char[destination_capacity]; + size_t destination_size = LZ4_compress_default( + (char*)bytes_as_string.c_str(), destination, source_size, destination_capacity); + file_stream_.write(destination, destination_size); + delete[] destination; + size_t stream_position_end = file_stream_.tellp(); + // end compression =============================== + + midi::Item* item; + switch (split_id) { + case 0: item = dataset_split_protobuf_.add_train(); break; + case 1: item = dataset_split_protobuf_.add_valid(); break; + case 2: item = dataset_split_protobuf_.add_test(); break; + } + item->set_start(stream_position_start); + item->set_end(stream_position_end); + item->set_src_size(source_size); + flush_count_++; + + if (flush_count_ >= 1000) { + writeFile(); + flush_count_ = 0; + }; + } + + void BytesToFile::writeFile() { + file_stream_.flush(); + //TODO: Check if the header stuff actually makes sense... we might not be using the header ever. + header_file_stream_.open(header_filepath_, std::ios::out | std::ios::binary); + if (!dataset_split_protobuf_.SerializeToOstream(&header_file_stream_)) { + std::cerr << "ERROR : Failed to write header file" << std::endl; + } + header_file_stream_.close(); + } + + void BytesToFile::close() + { + writeFile(); + file_stream_.close(); + header_file_stream_.close(); + } +} + diff --git a/src/inference/dataset/jagged.h b/src/inference/dataset/jagged.h new file mode 100644 index 0000000000000000000000000000000000000000..c45c6e9753c87d95180d84570fc63d8dd5144b3a --- /dev/null +++ b/src/inference/dataset/jagged.h @@ -0,0 +1,367 @@ +#include +#include +#include +#include +#include + +#include + +#include "lz4.h" +#include "../../../libraries/protobuf/build/midi.pb.h" +#include "../../common/encoder/encoder_all.h" +#include "../../common/midi_parsing/midi_io.h" +#include "../enum/encoder_types.h" +#include "../../common/data_structures/train_config.h" +#include "../random.h" + +// START OF NAMESPACE +namespace compression { + +template +using matrix = std::vector>; + +template +using tensor = std::vector>>; + +template +class Batcher { +public: + Batcher( int mmaxlen, std::mt19937 *e) { + maxlen = mmaxlen; + batch_maxlen = 0; + batch_size = 0; + engine = e; + } + void add( std::vector &seq ) { + std::vector item; + if (seq.size() > maxlen) { + int off = random_on_range((int)seq.size() - maxlen + 1, engine); + copy(seq.begin() + off, seq.begin() + off + maxlen, back_inserter(item)); + } + else { + copy(seq.begin(), seq.end(), back_inserter(item)); + } + batch.push_back( item ); + batch_maxlen = std::max(item.size(), batch_maxlen); + batch_size++; + } + void pad( T value ) { + for (size_t i=0; i> batch; +}; + + +class Jagged { +public: + Jagged(std::string filepath_) { + filepath = filepath_; + header_filepath = filepath_ + ".header"; + can_write = false; + can_read = false; + flush_count = 0; + num_bars = 4; + min_tracks = 2; + max_tracks = 12; + max_seq_len = 2048; + + engine.seed(time(NULL)); + + encoder = NULL; + } + + void set_seed(int seed) { + srand(seed); // set the seed + engine.seed(seed); + } + + void set_num_bars(int x) { + num_bars = x; + } + + void set_min_tracks(int x) { + min_tracks = x; + } + + void set_max_tracks(int x) { + max_tracks = x; + } + + void set_max_seq_len(int x) { + max_seq_len = x; + } + + void enable_write() { + assert(can_read == false); + if (can_write) { return; } + // check that the current file is empty unless force flag is present ? + fs.open(filepath, std::ios::out | std::ios::binary); + can_write = true; + } + + void enable_read() { + assert(can_write == false); + if (can_read) { return; } + fs.open(filepath, std::ios::in | std::ios::binary); + if (!fs.is_open()) { + throw std::runtime_error("COULD NOT OPEN FILE!"); + } + header_fs.open(header_filepath, std::ios::in | std::ios::binary); + header.ParseFromIstream(&header_fs); + can_read = true; + } + + void append(std::string &s, size_t split_id) { + enable_write(); + + size_t start = fs.tellp(); + // begin compress =============================== + size_t src_size = sizeof(char)*s.size(); + size_t dst_capacity = LZ4_compressBound(src_size); + char* dst = new char[dst_capacity]; + size_t dst_size = LZ4_compress_default( + (char*)s.c_str(), dst, src_size, dst_capacity); + fs.write(dst, dst_size); + delete[] dst; + // end compress ================================= + size_t end = fs.tellp(); + midi::Item *item; + switch (split_id) { + case 0: item = header.add_train(); break; + case 1: item = header.add_valid(); break; + case 2: item = header.add_test(); break; + } + item->set_start(start); + item->set_end(end); + item->set_src_size(src_size); + flush_count++; + + if (flush_count >= 1000) { + flush(); + flush_count = 0; + } + } + + std::string read(size_t index, size_t split_id) { + enable_read(); + + midi::Item item; + switch (split_id) { + case 0: item = header.train(index); break; + case 1: item = header.valid(index); break; + case 2: item = header.test(index); break; + } + size_t csize = item.end() - item.start(); + char* src = new char[csize/sizeof(char)]; + fs.seekg(item.start()); + fs.read(src, csize); + std::string x(item.src_size(), ' '); + LZ4_decompress_safe(src,(char*)x.c_str(),csize,item.src_size()); + delete[] src; + return x; + } + + py::bytes read_bytes(size_t index, size_t split_id) { + return py::bytes(read(index, split_id)); + } + + std::string read_json(size_t index, size_t split_id) { + midi::Piece p; + std::string serialized_data = read(index, split_id); + p.ParseFromString(serialized_data); + std::string json_string; + google::protobuf::util::MessageToJsonString(p, &json_string); + return json_string; + } + + // below is functions for dataset + int select_random_transpose(midi::Piece *p) { + std::tuple pitch_ext = util_protobuf::get_pitch_extents(p); + std::vector choices; + for (int tr=-6; tr<6; tr++) { + if ((std::get<0>(pitch_ext)+tr >= 0) && (std::get<1>(pitch_ext)+tr < 128)) { + choices.push_back( tr ); + } + } + return choices[random_on_range(choices.size(),&engine)]; + } + + void load_random_piece(midi::Piece *p, size_t split_id) { + int nitems = get_split_size(split_id); + int index = random_on_range(nitems, &engine); + std::string serialized_data = read(index, split_id); + p->ParseFromString(serialized_data); + } + + std::string load_random_piece_string(size_t split_id) { + int nitems = get_split_size(split_id); + int index = random_on_range(nitems, &engine); + std::string serialized_data=read(index, split_id); + midi::Piece p; + p.ParseFromString(serialized_data); + std::string json_string; + google::protobuf::util::MessageToJsonString(p, &json_string); + return json_string; + } + + void load_random_segment(midi::Piece *p, size_t split_id, encoder::ENCODER *enc, data_structures::TrainConfig *tc) { + + load_random_piece(p, split_id); + + if (tc->use_microtiming) { + //enc->config->use_microtiming = random_on_unit(&engine) < tc->microtiming; + if (p->internal_metadata_labels().nomml() == 12) { + enc->config->use_microtiming = random_on_unit(&engine) < tc->microtiming; + } + } + + compute_piece_level_attribute_controls(enc->rep,p); + + util_protobuf::select_random_segment( + p, tc->num_bars, tc->min_tracks, tc->max_tracks, &engine); + enc->config->transpose = select_random_transpose(p); + + // 75 % of the time we do bar infill + if (enc->config->both_in_one) { + enc->config->do_multi_fill = random_on_unit(&engine) < .75; + } + + // pick bars for infilling if needed + if (enc->config->do_multi_fill) { + enc->config->multi_fill = util_protobuf::make_bar_mask( + p, tc->max_mask_percentage, &engine); + } + } + + std::string load_random_piece_py(size_t split_id) { + midi::Piece p; + load_random_piece(&p, split_id); + std::string json_string; + google::protobuf::util::MessageToJsonString(p, &json_string); + return json_string; + } + + std::vector load_piece(size_t split_id, enums::ENCODER_TYPE et, data_structures::TrainConfig *tc) { + midi::Piece p; + load_random_piece(&p, split_id); + + std::unique_ptr enc = getEncoder(et); + if (!enc) { + throw std::runtime_error("ENCODER TYPE DOES NOT EXIST"); + } + + util_protobuf::select_random_segment(&p, tc->num_bars, tc->min_tracks, tc->max_tracks, &engine); + + enc->config->transpose = select_random_transpose(&p); + return enc->encode(&p); + } + + std::tuple,matrix> read_batch(int batch_size, size_t split_id, enums::ENCODER_TYPE et, data_structures::TrainConfig *tc) { + enable_read(); + std::unique_ptr enc = getEncoder(et); + if (!enc) { + throw std::runtime_error("ENCODER TYPE DOES NOT EXIST"); + } + + Batcher batch(max_seq_len, &engine); + Batcher att_mask(max_seq_len, &engine); + + // switch number of bars + std::vector num_bar_choices; + if (enc->rep->has_token_type(midi::TOKEN_NUM_BARS)) { + auto it = enc->rep->token_domains.find(midi::TOKEN_NUM_BARS); + for (const auto &v : it->second.input_domain) { + num_bar_choices.push_back( std::get(v) ); + } + } + + while((int)batch.batch_size < batch_size) { + + // pick random number of bars from domain + if (enc->rep->has_token_type(midi::TOKEN_NUM_BARS)) { + int index = random_on_range(num_bar_choices.size(), &engine); + tc->num_bars = num_bar_choices[index]; + } + + try { + midi::Piece p; + load_random_segment(&p, split_id, enc.get(), tc); + std::vector tokens = enc->encode(&p); + std::vector mask(tokens.size(),1); + batch.add( tokens ); + att_mask.add( mask ); + } + catch (const std::exception &exc) + { + std::cerr << exc.what() << std::endl; + } + } + batch.pad(0); + att_mask.pad(0); + return make_tuple(batch.batch, att_mask.batch); + } + + int get_size() { + enable_read(); + return header.train_size() + header.valid_size() + header.test_size(); + } + + int get_split_size(int split_id) { + enable_read(); + switch (split_id) { + case 0 : return header.train_size(); + case 1 : return header.valid_size(); + case 2 : return header.test_size(); + } + return 0; // invalid split id + } + + void flush() { + fs.flush(); + header_fs.open(header_filepath, std::ios::out | std::ios::binary); + if (!header.SerializeToOstream(&header_fs)) { + std::cerr << "ERROR : Failed to write header file" << std::endl; + } + header_fs.close(); + } + + void close() { + flush(); + fs.close(); + header_fs.close(); + can_read = false; + can_write = false; + } + +private: + std::string filepath; + std::string header_filepath; + std::fstream fs; + std::fstream header_fs; + bool can_write; + bool can_read; + midi::Dataset header; + int flush_count; + + int num_bars; + int min_tracks; + int max_tracks; + int max_seq_len; + + std::mt19937 engine; + + std::vector> bstore; + encoder::ENCODER *encoder; +}; + +} +// END OF NAMESPACE + + + diff --git a/src/inference/dataset/lz4.c b/src/inference/dataset/lz4.c new file mode 100644 index 0000000000000000000000000000000000000000..82ab49081152f92a6c3495154b8ada54bae8df07 --- /dev/null +++ b/src/inference/dataset/lz4.c @@ -0,0 +1,2402 @@ +/* + LZ4 - Fast LZ compression algorithm + Copyright (C) 2011-present, Yann Collet. + + BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php) + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following disclaimer + in the documentation and/or other materials provided with the + distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + You can contact the author at : + - LZ4 homepage : http://www.lz4.org + - LZ4 source repository : https://github.com/lz4/lz4 +*/ + +/*-************************************ +* Tuning parameters +**************************************/ +/* + * LZ4_HEAPMODE : + * Select how default compression functions will allocate memory for their hash table, + * in memory stack (0:default, fastest), or in memory heap (1:requires malloc()). + */ +#ifndef LZ4_HEAPMODE +# define LZ4_HEAPMODE 0 +#endif + +/* + * ACCELERATION_DEFAULT : + * Select "acceleration" for LZ4_compress_fast() when parameter value <= 0 + */ +#define ACCELERATION_DEFAULT 1 + + +/*-************************************ +* CPU Feature Detection +**************************************/ +/* LZ4_FORCE_MEMORY_ACCESS + * By default, access to unaligned memory is controlled by `memcpy()`, which is safe and portable. + * Unfortunately, on some target/compiler combinations, the generated assembly is sub-optimal. + * The below switch allow to select different access method for improved performance. + * Method 0 (default) : use `memcpy()`. Safe and portable. + * Method 1 : `__packed` statement. It depends on compiler extension (ie, not portable). + * This method is safe if your compiler supports it, and *generally* as fast or faster than `memcpy`. + * Method 2 : direct access. This method is portable but violate C standard. + * It can generate buggy code on targets which assembly generation depends on alignment. + * But in some circumstances, it's the only known way to get the most performance (ie GCC + ARMv6) + * See https://fastcompression.blogspot.fr/2015/08/accessing-unaligned-memory.html for details. + * Prefer these methods in priority order (0 > 1 > 2) + */ +#ifndef LZ4_FORCE_MEMORY_ACCESS /* can be defined externally */ +# if defined(__GNUC__) && \ + ( defined(__ARM_ARCH_6__) || defined(__ARM_ARCH_6J__) || defined(__ARM_ARCH_6K__) \ + || defined(__ARM_ARCH_6Z__) || defined(__ARM_ARCH_6ZK__) || defined(__ARM_ARCH_6T2__) ) +# define LZ4_FORCE_MEMORY_ACCESS 2 +# elif (defined(__INTEL_COMPILER) && !defined(_WIN32)) || defined(__GNUC__) +# define LZ4_FORCE_MEMORY_ACCESS 1 +# endif +#endif + +/* + * LZ4_FORCE_SW_BITCOUNT + * Define this parameter if your target system or compiler does not support hardware bit count + */ +#if defined(_MSC_VER) && defined(_WIN32_WCE) /* Visual Studio for WinCE doesn't support Hardware bit count */ +# define LZ4_FORCE_SW_BITCOUNT +#endif + + + +/*-************************************ +* Dependency +**************************************/ +/* + * LZ4_SRC_INCLUDED: + * Amalgamation flag, whether lz4.c is included + */ +#ifndef LZ4_SRC_INCLUDED +# define LZ4_SRC_INCLUDED 1 +#endif + +#ifndef LZ4_STATIC_LINKING_ONLY +#define LZ4_STATIC_LINKING_ONLY +#endif + +#ifndef LZ4_DISABLE_DEPRECATE_WARNINGS +#define LZ4_DISABLE_DEPRECATE_WARNINGS /* due to LZ4_decompress_safe_withPrefix64k */ +#endif + +#define LZ4_STATIC_LINKING_ONLY /* LZ4_DISTANCE_MAX */ +#include "lz4.h" +/* see also "memory routines" below */ + + +/*-************************************ +* Compiler Options +**************************************/ +#ifdef _MSC_VER /* Visual Studio */ +# include +# pragma warning(disable : 4127) /* disable: C4127: conditional expression is constant */ +# pragma warning(disable : 4293) /* disable: C4293: too large shift (32-bits) */ +#endif /* _MSC_VER */ + +#ifndef LZ4_FORCE_INLINE +# ifdef _MSC_VER /* Visual Studio */ +# define LZ4_FORCE_INLINE static __forceinline +# else +# if defined (__cplusplus) || defined (__STDC_VERSION__) && __STDC_VERSION__ >= 199901L /* C99 */ +# ifdef __GNUC__ +# define LZ4_FORCE_INLINE static inline __attribute__((always_inline)) +# else +# define LZ4_FORCE_INLINE static inline +# endif +# else +# define LZ4_FORCE_INLINE static +# endif /* __STDC_VERSION__ */ +# endif /* _MSC_VER */ +#endif /* LZ4_FORCE_INLINE */ + +/* LZ4_FORCE_O2_GCC_PPC64LE and LZ4_FORCE_O2_INLINE_GCC_PPC64LE + * gcc on ppc64le generates an unrolled SIMDized loop for LZ4_wildCopy8, + * together with a simple 8-byte copy loop as a fall-back path. + * However, this optimization hurts the decompression speed by >30%, + * because the execution does not go to the optimized loop + * for typical compressible data, and all of the preamble checks + * before going to the fall-back path become useless overhead. + * This optimization happens only with the -O3 flag, and -O2 generates + * a simple 8-byte copy loop. + * With gcc on ppc64le, all of the LZ4_decompress_* and LZ4_wildCopy8 + * functions are annotated with __attribute__((optimize("O2"))), + * and also LZ4_wildCopy8 is forcibly inlined, so that the O2 attribute + * of LZ4_wildCopy8 does not affect the compression speed. + */ +#if defined(__PPC64__) && defined(__LITTLE_ENDIAN__) && defined(__GNUC__) && !defined(__clang__) +# define LZ4_FORCE_O2_GCC_PPC64LE __attribute__((optimize("O2"))) +# define LZ4_FORCE_O2_INLINE_GCC_PPC64LE __attribute__((optimize("O2"))) LZ4_FORCE_INLINE +#else +# define LZ4_FORCE_O2_GCC_PPC64LE +# define LZ4_FORCE_O2_INLINE_GCC_PPC64LE static +#endif + +#if (defined(__GNUC__) && (__GNUC__ >= 3)) || (defined(__INTEL_COMPILER) && (__INTEL_COMPILER >= 800)) || defined(__clang__) +# define expect(expr,value) (__builtin_expect ((expr),(value)) ) +#else +# define expect(expr,value) (expr) +#endif + +#ifndef likely +#define likely(expr) expect((expr) != 0, 1) +#endif +#ifndef unlikely +#define unlikely(expr) expect((expr) != 0, 0) +#endif + + +/*-************************************ +* Memory routines +**************************************/ +#include /* malloc, calloc, free */ +#define ALLOC(s) malloc(s) +#define ALLOC_AND_ZERO(s) calloc(1,s) +#define FREEMEM(p) free(p) +#include /* memset, memcpy */ +#define MEM_INIT(p,v,s) memset((p),(v),(s)) + + +/*-************************************ +* Common Constants +**************************************/ +#define MINMATCH 4 + +#define WILDCOPYLENGTH 8 +#define LASTLITERALS 5 /* see ../doc/lz4_Block_format.md#parsing-restrictions */ +#define MFLIMIT 12 /* see ../doc/lz4_Block_format.md#parsing-restrictions */ +#define MATCH_SAFEGUARD_DISTANCE ((2*WILDCOPYLENGTH) - MINMATCH) /* ensure it's possible to write 2 x wildcopyLength without overflowing output buffer */ +#define FASTLOOP_SAFE_DISTANCE 64 +static const int LZ4_minLength = (MFLIMIT+1); + +#define KB *(1 <<10) +#define MB *(1 <<20) +#define GB *(1U<<30) + +#define LZ4_DISTANCE_ABSOLUTE_MAX 65535 +#if (LZ4_DISTANCE_MAX > LZ4_DISTANCE_ABSOLUTE_MAX) /* max supported by LZ4 format */ +# error "LZ4_DISTANCE_MAX is too big : must be <= 65535" +#endif + +#define ML_BITS 4 +#define ML_MASK ((1U<=1) +# include +#else +# ifndef assert +# define assert(condition) ((void)0) +# endif +#endif + +#define LZ4_STATIC_ASSERT(c) { enum { LZ4_static_assert = 1/(int)(!!(c)) }; } /* use after variable declarations */ + +#if defined(LZ4_DEBUG) && (LZ4_DEBUG>=2) +# include + static int g_debuglog_enable = 1; +# define DEBUGLOG(l, ...) { \ + if ((g_debuglog_enable) && (l<=LZ4_DEBUG)) { \ + fprintf(stderr, __FILE__ ": "); \ + fprintf(stderr, __VA_ARGS__); \ + fprintf(stderr, " \n"); \ + } } +#else +# define DEBUGLOG(l, ...) {} /* disabled */ +#endif + + +/*-************************************ +* Types +**************************************/ +#if defined(__cplusplus) || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */) +# include + typedef uint8_t BYTE; + typedef uint16_t U16; + typedef uint32_t U32; + typedef int32_t S32; + typedef uint64_t U64; + typedef uintptr_t uptrval; +#else +# include +# if UINT_MAX != 4294967295UL +# error "LZ4 code (when not C++ or C99) assumes that sizeof(int) == 4" +# endif + typedef unsigned char BYTE; + typedef unsigned short U16; + typedef unsigned int U32; + typedef signed int S32; + typedef unsigned long long U64; + typedef size_t uptrval; /* generally true, except OpenVMS-64 */ +#endif + +#if defined(__x86_64__) + typedef U64 reg_t; /* 64-bits in x32 mode */ +#else + typedef size_t reg_t; /* 32-bits in x32 mode */ +#endif + +typedef enum { + notLimited = 0, + limitedOutput = 1, + fillOutput = 2 +} limitedOutput_directive; + + +/*-************************************ +* Reading and writing into memory +**************************************/ +static unsigned LZ4_isLittleEndian(void) +{ + const union { U32 u; BYTE c[4]; } one = { 1 }; /* don't use static : performance detrimental */ + return one.c[0]; +} + + +#if defined(LZ4_FORCE_MEMORY_ACCESS) && (LZ4_FORCE_MEMORY_ACCESS==2) +/* lie to the compiler about data alignment; use with caution */ + +static U16 LZ4_read16(const void* memPtr) { return *(const U16*) memPtr; } +static U32 LZ4_read32(const void* memPtr) { return *(const U32*) memPtr; } +static reg_t LZ4_read_ARCH(const void* memPtr) { return *(const reg_t*) memPtr; } + +static void LZ4_write16(void* memPtr, U16 value) { *(U16*)memPtr = value; } +static void LZ4_write32(void* memPtr, U32 value) { *(U32*)memPtr = value; } + +#elif defined(LZ4_FORCE_MEMORY_ACCESS) && (LZ4_FORCE_MEMORY_ACCESS==1) + +/* __pack instructions are safer, but compiler specific, hence potentially problematic for some compilers */ +/* currently only defined for gcc and icc */ +typedef union { U16 u16; U32 u32; reg_t uArch; } __attribute__((packed)) unalign; + +static U16 LZ4_read16(const void* ptr) { return ((const unalign*)ptr)->u16; } +static U32 LZ4_read32(const void* ptr) { return ((const unalign*)ptr)->u32; } +static reg_t LZ4_read_ARCH(const void* ptr) { return ((const unalign*)ptr)->uArch; } + +static void LZ4_write16(void* memPtr, U16 value) { ((unalign*)memPtr)->u16 = value; } +static void LZ4_write32(void* memPtr, U32 value) { ((unalign*)memPtr)->u32 = value; } + +#else /* safe and portable access using memcpy() */ + +static U16 LZ4_read16(const void* memPtr) +{ + U16 val; memcpy(&val, memPtr, sizeof(val)); return val; +} + +static U32 LZ4_read32(const void* memPtr) +{ + U32 val; memcpy(&val, memPtr, sizeof(val)); return val; +} + +static reg_t LZ4_read_ARCH(const void* memPtr) +{ + reg_t val; memcpy(&val, memPtr, sizeof(val)); return val; +} + +static void LZ4_write16(void* memPtr, U16 value) +{ + memcpy(memPtr, &value, sizeof(value)); +} + +static void LZ4_write32(void* memPtr, U32 value) +{ + memcpy(memPtr, &value, sizeof(value)); +} + +#endif /* LZ4_FORCE_MEMORY_ACCESS */ + + +static U16 LZ4_readLE16(const void* memPtr) +{ + if (LZ4_isLittleEndian()) { + return LZ4_read16(memPtr); + } else { + const BYTE* p = (const BYTE*)memPtr; + return (U16)((U16)p[0] + (p[1]<<8)); + } +} + +static void LZ4_writeLE16(void* memPtr, U16 value) +{ + if (LZ4_isLittleEndian()) { + LZ4_write16(memPtr, value); + } else { + BYTE* p = (BYTE*)memPtr; + p[0] = (BYTE) value; + p[1] = (BYTE)(value>>8); + } +} + +/* customized variant of memcpy, which can overwrite up to 8 bytes beyond dstEnd */ +LZ4_FORCE_O2_INLINE_GCC_PPC64LE +void LZ4_wildCopy8(void* dstPtr, const void* srcPtr, void* dstEnd) +{ + BYTE* d = (BYTE*)dstPtr; + const BYTE* s = (const BYTE*)srcPtr; + BYTE* const e = (BYTE*)dstEnd; + + do { memcpy(d,s,8); d+=8; s+=8; } while (d= 16. */ +LZ4_FORCE_O2_INLINE_GCC_PPC64LE void +LZ4_wildCopy32(void* dstPtr, const void* srcPtr, void* dstEnd) +{ + BYTE* d = (BYTE*)dstPtr; + const BYTE* s = (const BYTE*)srcPtr; + BYTE* const e = (BYTE*)dstEnd; + + do { memcpy(d,s,16); memcpy(d+16,s+16,16); d+=32; s+=32; } while (d= dstPtr + MINMATCH + * - there is at least 8 bytes available to write after dstEnd */ +LZ4_FORCE_O2_INLINE_GCC_PPC64LE void +LZ4_memcpy_using_offset(BYTE* dstPtr, const BYTE* srcPtr, BYTE* dstEnd, const size_t offset) +{ + BYTE v[8]; + + assert(dstEnd >= dstPtr + MINMATCH); + LZ4_write32(dstPtr, 0); /* silence an msan warning when offset==0 */ + + switch(offset) { + case 1: + memset(v, *srcPtr, 8); + break; + case 2: + memcpy(v, srcPtr, 2); + memcpy(&v[2], srcPtr, 2); + memcpy(&v[4], &v[0], 4); + break; + case 4: + memcpy(v, srcPtr, 4); + memcpy(&v[4], srcPtr, 4); + break; + default: + LZ4_memcpy_using_offset_base(dstPtr, srcPtr, dstEnd, offset); + return; + } + + memcpy(dstPtr, v, 8); + dstPtr += 8; + while (dstPtr < dstEnd) { + memcpy(dstPtr, v, 8); + dstPtr += 8; + } +} +#endif + + +/*-************************************ +* Common functions +**************************************/ +static unsigned LZ4_NbCommonBytes (reg_t val) +{ + if (LZ4_isLittleEndian()) { + if (sizeof(val)==8) { +# if defined(_MSC_VER) && defined(_WIN64) && !defined(LZ4_FORCE_SW_BITCOUNT) + unsigned long r = 0; + _BitScanForward64( &r, (U64)val ); + return (int)(r>>3); +# elif (defined(__clang__) || (defined(__GNUC__) && (__GNUC__>=3))) && !defined(LZ4_FORCE_SW_BITCOUNT) + return (unsigned)__builtin_ctzll((U64)val) >> 3; +# else + static const int DeBruijnBytePos[64] = { 0, 0, 0, 0, 0, 1, 1, 2, + 0, 3, 1, 3, 1, 4, 2, 7, + 0, 2, 3, 6, 1, 5, 3, 5, + 1, 3, 4, 4, 2, 5, 6, 7, + 7, 0, 1, 2, 3, 3, 4, 6, + 2, 6, 5, 5, 3, 4, 5, 6, + 7, 1, 2, 4, 6, 4, 4, 5, + 7, 2, 6, 5, 7, 6, 7, 7 }; + return DeBruijnBytePos[((U64)((val & -(long long)val) * 0x0218A392CDABBD3FULL)) >> 58]; +# endif + } else /* 32 bits */ { +# if defined(_MSC_VER) && !defined(LZ4_FORCE_SW_BITCOUNT) + unsigned long r; + _BitScanForward( &r, (U32)val ); + return (int)(r>>3); +# elif (defined(__clang__) || (defined(__GNUC__) && (__GNUC__>=3))) && !defined(LZ4_FORCE_SW_BITCOUNT) + return (unsigned)__builtin_ctz((U32)val) >> 3; +# else + static const int DeBruijnBytePos[32] = { 0, 0, 3, 0, 3, 1, 3, 0, + 3, 2, 2, 1, 3, 2, 0, 1, + 3, 3, 1, 2, 2, 2, 2, 0, + 3, 1, 2, 0, 1, 0, 1, 1 }; + return DeBruijnBytePos[((U32)((val & -(S32)val) * 0x077CB531U)) >> 27]; +# endif + } + } else /* Big Endian CPU */ { + if (sizeof(val)==8) { /* 64-bits */ +# if defined(_MSC_VER) && defined(_WIN64) && !defined(LZ4_FORCE_SW_BITCOUNT) + unsigned long r = 0; + _BitScanReverse64( &r, val ); + return (unsigned)(r>>3); +# elif (defined(__clang__) || (defined(__GNUC__) && (__GNUC__>=3))) && !defined(LZ4_FORCE_SW_BITCOUNT) + return (unsigned)__builtin_clzll((U64)val) >> 3; +# else + static const U32 by32 = sizeof(val)*4; /* 32 on 64 bits (goal), 16 on 32 bits. + Just to avoid some static analyzer complaining about shift by 32 on 32-bits target. + Note that this code path is never triggered in 32-bits mode. */ + unsigned r; + if (!(val>>by32)) { r=4; } else { r=0; val>>=by32; } + if (!(val>>16)) { r+=2; val>>=8; } else { val>>=24; } + r += (!val); + return r; +# endif + } else /* 32 bits */ { +# if defined(_MSC_VER) && !defined(LZ4_FORCE_SW_BITCOUNT) + unsigned long r = 0; + _BitScanReverse( &r, (unsigned long)val ); + return (unsigned)(r>>3); +# elif (defined(__clang__) || (defined(__GNUC__) && (__GNUC__>=3))) && !defined(LZ4_FORCE_SW_BITCOUNT) + return (unsigned)__builtin_clz((U32)val) >> 3; +# else + unsigned r; + if (!(val>>16)) { r=2; val>>=8; } else { r=0; val>>=24; } + r += (!val); + return r; +# endif + } + } +} + +#define STEPSIZE sizeof(reg_t) +LZ4_FORCE_INLINE +unsigned LZ4_count(const BYTE* pIn, const BYTE* pMatch, const BYTE* pInLimit) +{ + const BYTE* const pStart = pIn; + + if (likely(pIn < pInLimit-(STEPSIZE-1))) { + reg_t const diff = LZ4_read_ARCH(pMatch) ^ LZ4_read_ARCH(pIn); + if (!diff) { + pIn+=STEPSIZE; pMatch+=STEPSIZE; + } else { + return LZ4_NbCommonBytes(diff); + } } + + while (likely(pIn < pInLimit-(STEPSIZE-1))) { + reg_t const diff = LZ4_read_ARCH(pMatch) ^ LZ4_read_ARCH(pIn); + if (!diff) { pIn+=STEPSIZE; pMatch+=STEPSIZE; continue; } + pIn += LZ4_NbCommonBytes(diff); + return (unsigned)(pIn - pStart); + } + + if ((STEPSIZE==8) && (pIn<(pInLimit-3)) && (LZ4_read32(pMatch) == LZ4_read32(pIn))) { pIn+=4; pMatch+=4; } + if ((pIn<(pInLimit-1)) && (LZ4_read16(pMatch) == LZ4_read16(pIn))) { pIn+=2; pMatch+=2; } + if ((pIn compression run slower on incompressible data */ + + +/*-************************************ +* Local Structures and types +**************************************/ +typedef enum { clearedTable = 0, byPtr, byU32, byU16 } tableType_t; + +/** + * This enum distinguishes several different modes of accessing previous + * content in the stream. + * + * - noDict : There is no preceding content. + * - withPrefix64k : Table entries up to ctx->dictSize before the current blob + * blob being compressed are valid and refer to the preceding + * content (of length ctx->dictSize), which is available + * contiguously preceding in memory the content currently + * being compressed. + * - usingExtDict : Like withPrefix64k, but the preceding content is somewhere + * else in memory, starting at ctx->dictionary with length + * ctx->dictSize. + * - usingDictCtx : Like usingExtDict, but everything concerning the preceding + * content is in a separate context, pointed to by + * ctx->dictCtx. ctx->dictionary, ctx->dictSize, and table + * entries in the current context that refer to positions + * preceding the beginning of the current compression are + * ignored. Instead, ctx->dictCtx->dictionary and ctx->dictCtx + * ->dictSize describe the location and size of the preceding + * content, and matches are found by looking in the ctx + * ->dictCtx->hashTable. + */ +typedef enum { noDict = 0, withPrefix64k, usingExtDict, usingDictCtx } dict_directive; +typedef enum { noDictIssue = 0, dictSmall } dictIssue_directive; + + +/*-************************************ +* Local Utils +**************************************/ +int LZ4_versionNumber (void) { return LZ4_VERSION_NUMBER; } +const char* LZ4_versionString(void) { return LZ4_VERSION_STRING; } +int LZ4_compressBound(int isize) { return LZ4_COMPRESSBOUND(isize); } +int LZ4_sizeofState() { return LZ4_STREAMSIZE; } + + +/*-************************************ +* Internal Definitions used in Tests +**************************************/ +#if defined (__cplusplus) +extern "C" { +#endif + +int LZ4_compress_forceExtDict (LZ4_stream_t* LZ4_dict, const char* source, char* dest, int srcSize); + +int LZ4_decompress_safe_forceExtDict(const char* source, char* dest, + int compressedSize, int maxOutputSize, + const void* dictStart, size_t dictSize); + +#if defined (__cplusplus) +} +#endif + +/*-****************************** +* Compression functions +********************************/ +LZ4_FORCE_INLINE U32 LZ4_hash4(U32 sequence, tableType_t const tableType) +{ + if (tableType == byU16) + return ((sequence * 2654435761U) >> ((MINMATCH*8)-(LZ4_HASHLOG+1))); + else + return ((sequence * 2654435761U) >> ((MINMATCH*8)-LZ4_HASHLOG)); +} + +LZ4_FORCE_INLINE U32 LZ4_hash5(U64 sequence, tableType_t const tableType) +{ + const U32 hashLog = (tableType == byU16) ? LZ4_HASHLOG+1 : LZ4_HASHLOG; + if (LZ4_isLittleEndian()) { + const U64 prime5bytes = 889523592379ULL; + return (U32)(((sequence << 24) * prime5bytes) >> (64 - hashLog)); + } else { + const U64 prime8bytes = 11400714785074694791ULL; + return (U32)(((sequence >> 24) * prime8bytes) >> (64 - hashLog)); + } +} + +LZ4_FORCE_INLINE U32 LZ4_hashPosition(const void* const p, tableType_t const tableType) +{ + if ((sizeof(reg_t)==8) && (tableType != byU16)) return LZ4_hash5(LZ4_read_ARCH(p), tableType); + return LZ4_hash4(LZ4_read32(p), tableType); +} + +LZ4_FORCE_INLINE void LZ4_clearHash(U32 h, void* tableBase, tableType_t const tableType) +{ + switch (tableType) + { + default: /* fallthrough */ + case clearedTable: { /* illegal! */ assert(0); return; } + case byPtr: { const BYTE** hashTable = (const BYTE**)tableBase; hashTable[h] = NULL; return; } + case byU32: { U32* hashTable = (U32*) tableBase; hashTable[h] = 0; return; } + case byU16: { U16* hashTable = (U16*) tableBase; hashTable[h] = 0; return; } + } +} + +LZ4_FORCE_INLINE void LZ4_putIndexOnHash(U32 idx, U32 h, void* tableBase, tableType_t const tableType) +{ + switch (tableType) + { + default: /* fallthrough */ + case clearedTable: /* fallthrough */ + case byPtr: { /* illegal! */ assert(0); return; } + case byU32: { U32* hashTable = (U32*) tableBase; hashTable[h] = idx; return; } + case byU16: { U16* hashTable = (U16*) tableBase; assert(idx < 65536); hashTable[h] = (U16)idx; return; } + } +} + +LZ4_FORCE_INLINE void LZ4_putPositionOnHash(const BYTE* p, U32 h, + void* tableBase, tableType_t const tableType, + const BYTE* srcBase) +{ + switch (tableType) + { + case clearedTable: { /* illegal! */ assert(0); return; } + case byPtr: { const BYTE** hashTable = (const BYTE**)tableBase; hashTable[h] = p; return; } + case byU32: { U32* hashTable = (U32*) tableBase; hashTable[h] = (U32)(p-srcBase); return; } + case byU16: { U16* hashTable = (U16*) tableBase; hashTable[h] = (U16)(p-srcBase); return; } + } +} + +LZ4_FORCE_INLINE void LZ4_putPosition(const BYTE* p, void* tableBase, tableType_t tableType, const BYTE* srcBase) +{ + U32 const h = LZ4_hashPosition(p, tableType); + LZ4_putPositionOnHash(p, h, tableBase, tableType, srcBase); +} + +/* LZ4_getIndexOnHash() : + * Index of match position registered in hash table. + * hash position must be calculated by using base+index, or dictBase+index. + * Assumption 1 : only valid if tableType == byU32 or byU16. + * Assumption 2 : h is presumed valid (within limits of hash table) + */ +LZ4_FORCE_INLINE U32 LZ4_getIndexOnHash(U32 h, const void* tableBase, tableType_t tableType) +{ + LZ4_STATIC_ASSERT(LZ4_MEMORY_USAGE > 2); + if (tableType == byU32) { + const U32* const hashTable = (const U32*) tableBase; + assert(h < (1U << (LZ4_MEMORY_USAGE-2))); + return hashTable[h]; + } + if (tableType == byU16) { + const U16* const hashTable = (const U16*) tableBase; + assert(h < (1U << (LZ4_MEMORY_USAGE-1))); + return hashTable[h]; + } + assert(0); return 0; /* forbidden case */ +} + +static const BYTE* LZ4_getPositionOnHash(U32 h, const void* tableBase, tableType_t tableType, const BYTE* srcBase) +{ + if (tableType == byPtr) { const BYTE* const* hashTable = (const BYTE* const*) tableBase; return hashTable[h]; } + if (tableType == byU32) { const U32* const hashTable = (const U32*) tableBase; return hashTable[h] + srcBase; } + { const U16* const hashTable = (const U16*) tableBase; return hashTable[h] + srcBase; } /* default, to ensure a return */ +} + +LZ4_FORCE_INLINE const BYTE* +LZ4_getPosition(const BYTE* p, + const void* tableBase, tableType_t tableType, + const BYTE* srcBase) +{ + U32 const h = LZ4_hashPosition(p, tableType); + return LZ4_getPositionOnHash(h, tableBase, tableType, srcBase); +} + +LZ4_FORCE_INLINE void +LZ4_prepareTable(LZ4_stream_t_internal* const cctx, + const int inputSize, + const tableType_t tableType) { + /* If compression failed during the previous step, then the context + * is marked as dirty, therefore, it has to be fully reset. + */ + if (cctx->dirty) { + DEBUGLOG(5, "LZ4_prepareTable: Full reset for %p", cctx); + MEM_INIT(cctx, 0, sizeof(LZ4_stream_t_internal)); + return; + } + + /* If the table hasn't been used, it's guaranteed to be zeroed out, and is + * therefore safe to use no matter what mode we're in. Otherwise, we figure + * out if it's safe to leave as is or whether it needs to be reset. + */ + if (cctx->tableType != clearedTable) { + assert(inputSize >= 0); + if (cctx->tableType != tableType + || ((tableType == byU16) && cctx->currentOffset + (unsigned)inputSize >= 0xFFFFU) + || ((tableType == byU32) && cctx->currentOffset > 1 GB) + || tableType == byPtr + || inputSize >= 4 KB) + { + DEBUGLOG(4, "LZ4_prepareTable: Resetting table in %p", cctx); + MEM_INIT(cctx->hashTable, 0, LZ4_HASHTABLESIZE); + cctx->currentOffset = 0; + cctx->tableType = clearedTable; + } else { + DEBUGLOG(4, "LZ4_prepareTable: Re-use hash table (no reset)"); + } + } + + /* Adding a gap, so all previous entries are > LZ4_DISTANCE_MAX back, is faster + * than compressing without a gap. However, compressing with + * currentOffset == 0 is faster still, so we preserve that case. + */ + if (cctx->currentOffset != 0 && tableType == byU32) { + DEBUGLOG(5, "LZ4_prepareTable: adding 64KB to currentOffset"); + cctx->currentOffset += 64 KB; + } + + /* Finally, clear history */ + cctx->dictCtx = NULL; + cctx->dictionary = NULL; + cctx->dictSize = 0; +} + +/** LZ4_compress_generic() : + inlined, to ensure branches are decided at compilation time */ +LZ4_FORCE_INLINE int LZ4_compress_generic( + LZ4_stream_t_internal* const cctx, + const char* const source, + char* const dest, + const int inputSize, + int *inputConsumed, /* only written when outputDirective == fillOutput */ + const int maxOutputSize, + const limitedOutput_directive outputDirective, + const tableType_t tableType, + const dict_directive dictDirective, + const dictIssue_directive dictIssue, + const int acceleration) +{ + int result; + const BYTE* ip = (const BYTE*) source; + + U32 const startIndex = cctx->currentOffset; + const BYTE* base = (const BYTE*) source - startIndex; + const BYTE* lowLimit; + + const LZ4_stream_t_internal* dictCtx = (const LZ4_stream_t_internal*) cctx->dictCtx; + const BYTE* const dictionary = + dictDirective == usingDictCtx ? dictCtx->dictionary : cctx->dictionary; + const U32 dictSize = + dictDirective == usingDictCtx ? dictCtx->dictSize : cctx->dictSize; + const U32 dictDelta = (dictDirective == usingDictCtx) ? startIndex - dictCtx->currentOffset : 0; /* make indexes in dictCtx comparable with index in current context */ + + int const maybe_extMem = (dictDirective == usingExtDict) || (dictDirective == usingDictCtx); + U32 const prefixIdxLimit = startIndex - dictSize; /* used when dictDirective == dictSmall */ + const BYTE* const dictEnd = dictionary + dictSize; + const BYTE* anchor = (const BYTE*) source; + const BYTE* const iend = ip + inputSize; + const BYTE* const mflimitPlusOne = iend - MFLIMIT + 1; + const BYTE* const matchlimit = iend - LASTLITERALS; + + /* the dictCtx currentOffset is indexed on the start of the dictionary, + * while a dictionary in the current context precedes the currentOffset */ + const BYTE* dictBase = (dictDirective == usingDictCtx) ? + dictionary + dictSize - dictCtx->currentOffset : + dictionary + dictSize - startIndex; + + BYTE* op = (BYTE*) dest; + BYTE* const olimit = op + maxOutputSize; + + U32 offset = 0; + U32 forwardH; + + DEBUGLOG(5, "LZ4_compress_generic: srcSize=%i, tableType=%u", inputSize, tableType); + /* If init conditions are not met, we don't have to mark stream + * as having dirty context, since no action was taken yet */ + if (outputDirective == fillOutput && maxOutputSize < 1) { return 0; } /* Impossible to store anything */ + if ((U32)inputSize > (U32)LZ4_MAX_INPUT_SIZE) { return 0; } /* Unsupported inputSize, too large (or negative) */ + if ((tableType == byU16) && (inputSize>=LZ4_64Klimit)) { return 0; } /* Size too large (not within 64K limit) */ + if (tableType==byPtr) assert(dictDirective==noDict); /* only supported use case with byPtr */ + assert(acceleration >= 1); + + lowLimit = (const BYTE*)source - (dictDirective == withPrefix64k ? dictSize : 0); + + /* Update context state */ + if (dictDirective == usingDictCtx) { + /* Subsequent linked blocks can't use the dictionary. */ + /* Instead, they use the block we just compressed. */ + cctx->dictCtx = NULL; + cctx->dictSize = (U32)inputSize; + } else { + cctx->dictSize += (U32)inputSize; + } + cctx->currentOffset += (U32)inputSize; + cctx->tableType = (U16)tableType; + + if (inputSizehashTable, tableType, base); + ip++; forwardH = LZ4_hashPosition(ip, tableType); + + /* Main Loop */ + for ( ; ; ) { + const BYTE* match; + BYTE* token; + const BYTE* filledIp; + + /* Find a match */ + if (tableType == byPtr) { + const BYTE* forwardIp = ip; + int step = 1; + int searchMatchNb = acceleration << LZ4_skipTrigger; + do { + U32 const h = forwardH; + ip = forwardIp; + forwardIp += step; + step = (searchMatchNb++ >> LZ4_skipTrigger); + + if (unlikely(forwardIp > mflimitPlusOne)) goto _last_literals; + assert(ip < mflimitPlusOne); + + match = LZ4_getPositionOnHash(h, cctx->hashTable, tableType, base); + forwardH = LZ4_hashPosition(forwardIp, tableType); + LZ4_putPositionOnHash(ip, h, cctx->hashTable, tableType, base); + + } while ( (match+LZ4_DISTANCE_MAX < ip) + || (LZ4_read32(match) != LZ4_read32(ip)) ); + + } else { /* byU32, byU16 */ + + const BYTE* forwardIp = ip; + int step = 1; + int searchMatchNb = acceleration << LZ4_skipTrigger; + do { + U32 const h = forwardH; + U32 const current = (U32)(forwardIp - base); + U32 matchIndex = LZ4_getIndexOnHash(h, cctx->hashTable, tableType); + assert(matchIndex <= current); + assert(forwardIp - base < (ptrdiff_t)(2 GB - 1)); + ip = forwardIp; + forwardIp += step; + step = (searchMatchNb++ >> LZ4_skipTrigger); + + if (unlikely(forwardIp > mflimitPlusOne)) goto _last_literals; + assert(ip < mflimitPlusOne); + + if (dictDirective == usingDictCtx) { + if (matchIndex < startIndex) { + /* there was no match, try the dictionary */ + assert(tableType == byU32); + matchIndex = LZ4_getIndexOnHash(h, dictCtx->hashTable, byU32); + match = dictBase + matchIndex; + matchIndex += dictDelta; /* make dictCtx index comparable with current context */ + lowLimit = dictionary; + } else { + match = base + matchIndex; + lowLimit = (const BYTE*)source; + } + } else if (dictDirective==usingExtDict) { + if (matchIndex < startIndex) { + DEBUGLOG(7, "extDict candidate: matchIndex=%5u < startIndex=%5u", matchIndex, startIndex); + assert(startIndex - matchIndex >= MINMATCH); + match = dictBase + matchIndex; + lowLimit = dictionary; + } else { + match = base + matchIndex; + lowLimit = (const BYTE*)source; + } + } else { /* single continuous memory segment */ + match = base + matchIndex; + } + forwardH = LZ4_hashPosition(forwardIp, tableType); + LZ4_putIndexOnHash(current, h, cctx->hashTable, tableType); + + DEBUGLOG(7, "candidate at pos=%u (offset=%u \n", matchIndex, current - matchIndex); + if ((dictIssue == dictSmall) && (matchIndex < prefixIdxLimit)) { continue; } /* match outside of valid area */ + assert(matchIndex < current); + if ( ((tableType != byU16) || (LZ4_DISTANCE_MAX < LZ4_DISTANCE_ABSOLUTE_MAX)) + && (matchIndex+LZ4_DISTANCE_MAX < current)) { + continue; + } /* too far */ + assert((current - matchIndex) <= LZ4_DISTANCE_MAX); /* match now expected within distance */ + + if (LZ4_read32(match) == LZ4_read32(ip)) { + if (maybe_extMem) offset = current - matchIndex; + break; /* match found */ + } + + } while(1); + } + + /* Catch up */ + filledIp = ip; + while (((ip>anchor) & (match > lowLimit)) && (unlikely(ip[-1]==match[-1]))) { ip--; match--; } + + /* Encode Literals */ + { unsigned const litLength = (unsigned)(ip - anchor); + token = op++; + if ((outputDirective == limitedOutput) && /* Check output buffer overflow */ + (unlikely(op + litLength + (2 + 1 + LASTLITERALS) + (litLength/255) > olimit)) ) { + return 0; /* cannot compress within `dst` budget. Stored indexes in hash table are nonetheless fine */ + } + if ((outputDirective == fillOutput) && + (unlikely(op + (litLength+240)/255 /* litlen */ + litLength /* literals */ + 2 /* offset */ + 1 /* token */ + MFLIMIT - MINMATCH /* min last literals so last match is <= end - MFLIMIT */ > olimit))) { + op--; + goto _last_literals; + } + if (litLength >= RUN_MASK) { + int len = (int)(litLength - RUN_MASK); + *token = (RUN_MASK<= 255 ; len-=255) *op++ = 255; + *op++ = (BYTE)len; + } + else *token = (BYTE)(litLength< olimit)) { + /* the match was too close to the end, rewind and go to last literals */ + op = token; + goto _last_literals; + } + + /* Encode Offset */ + if (maybe_extMem) { /* static test */ + DEBUGLOG(6, " with offset=%u (ext if > %i)", offset, (int)(ip - (const BYTE*)source)); + assert(offset <= LZ4_DISTANCE_MAX && offset > 0); + LZ4_writeLE16(op, (U16)offset); op+=2; + } else { + DEBUGLOG(6, " with offset=%u (same segment)", (U32)(ip - match)); + assert(ip-match <= LZ4_DISTANCE_MAX); + LZ4_writeLE16(op, (U16)(ip - match)); op+=2; + } + + /* Encode MatchLength */ + { unsigned matchCode; + + if ( (dictDirective==usingExtDict || dictDirective==usingDictCtx) + && (lowLimit==dictionary) /* match within extDict */ ) { + const BYTE* limit = ip + (dictEnd-match); + assert(dictEnd > match); + if (limit > matchlimit) limit = matchlimit; + matchCode = LZ4_count(ip+MINMATCH, match+MINMATCH, limit); + ip += (size_t)matchCode + MINMATCH; + if (ip==limit) { + unsigned const more = LZ4_count(limit, (const BYTE*)source, matchlimit); + matchCode += more; + ip += more; + } + DEBUGLOG(6, " with matchLength=%u starting in extDict", matchCode+MINMATCH); + } else { + matchCode = LZ4_count(ip+MINMATCH, match+MINMATCH, matchlimit); + ip += (size_t)matchCode + MINMATCH; + DEBUGLOG(6, " with matchLength=%u", matchCode+MINMATCH); + } + + if ((outputDirective) && /* Check output buffer overflow */ + (unlikely(op + (1 + LASTLITERALS) + (matchCode+240)/255 > olimit)) ) { + if (outputDirective == fillOutput) { + /* Match description too long : reduce it */ + U32 newMatchCode = 15 /* in token */ - 1 /* to avoid needing a zero byte */ + ((U32)(olimit - op) - 1 - LASTLITERALS) * 255; + ip -= matchCode - newMatchCode; + assert(newMatchCode < matchCode); + matchCode = newMatchCode; + if (unlikely(ip <= filledIp)) { + /* We have already filled up to filledIp so if ip ends up less than filledIp + * we have positions in the hash table beyond the current position. This is + * a problem if we reuse the hash table. So we have to remove these positions + * from the hash table. + */ + const BYTE* ptr; + DEBUGLOG(5, "Clearing %u positions", (U32)(filledIp - ip)); + for (ptr = ip; ptr <= filledIp; ++ptr) { + U32 const h = LZ4_hashPosition(ptr, tableType); + LZ4_clearHash(h, cctx->hashTable, tableType); + } + } + } else { + assert(outputDirective == limitedOutput); + return 0; /* cannot compress within `dst` budget. Stored indexes in hash table are nonetheless fine */ + } + } + if (matchCode >= ML_MASK) { + *token += ML_MASK; + matchCode -= ML_MASK; + LZ4_write32(op, 0xFFFFFFFF); + while (matchCode >= 4*255) { + op+=4; + LZ4_write32(op, 0xFFFFFFFF); + matchCode -= 4*255; + } + op += matchCode / 255; + *op++ = (BYTE)(matchCode % 255); + } else + *token += (BYTE)(matchCode); + } + /* Ensure we have enough space for the last literals. */ + assert(!(outputDirective == fillOutput && op + 1 + LASTLITERALS > olimit)); + + anchor = ip; + + /* Test end of chunk */ + if (ip >= mflimitPlusOne) break; + + /* Fill table */ + LZ4_putPosition(ip-2, cctx->hashTable, tableType, base); + + /* Test next position */ + if (tableType == byPtr) { + + match = LZ4_getPosition(ip, cctx->hashTable, tableType, base); + LZ4_putPosition(ip, cctx->hashTable, tableType, base); + if ( (match+LZ4_DISTANCE_MAX >= ip) + && (LZ4_read32(match) == LZ4_read32(ip)) ) + { token=op++; *token=0; goto _next_match; } + + } else { /* byU32, byU16 */ + + U32 const h = LZ4_hashPosition(ip, tableType); + U32 const current = (U32)(ip-base); + U32 matchIndex = LZ4_getIndexOnHash(h, cctx->hashTable, tableType); + assert(matchIndex < current); + if (dictDirective == usingDictCtx) { + if (matchIndex < startIndex) { + /* there was no match, try the dictionary */ + matchIndex = LZ4_getIndexOnHash(h, dictCtx->hashTable, byU32); + match = dictBase + matchIndex; + lowLimit = dictionary; /* required for match length counter */ + matchIndex += dictDelta; + } else { + match = base + matchIndex; + lowLimit = (const BYTE*)source; /* required for match length counter */ + } + } else if (dictDirective==usingExtDict) { + if (matchIndex < startIndex) { + match = dictBase + matchIndex; + lowLimit = dictionary; /* required for match length counter */ + } else { + match = base + matchIndex; + lowLimit = (const BYTE*)source; /* required for match length counter */ + } + } else { /* single memory segment */ + match = base + matchIndex; + } + LZ4_putIndexOnHash(current, h, cctx->hashTable, tableType); + assert(matchIndex < current); + if ( ((dictIssue==dictSmall) ? (matchIndex >= prefixIdxLimit) : 1) + && (((tableType==byU16) && (LZ4_DISTANCE_MAX == LZ4_DISTANCE_ABSOLUTE_MAX)) ? 1 : (matchIndex+LZ4_DISTANCE_MAX >= current)) + && (LZ4_read32(match) == LZ4_read32(ip)) ) { + token=op++; + *token=0; + if (maybe_extMem) offset = current - matchIndex; + DEBUGLOG(6, "seq.start:%i, literals=%u, match.start:%i", + (int)(anchor-(const BYTE*)source), 0, (int)(ip-(const BYTE*)source)); + goto _next_match; + } + } + + /* Prepare next loop */ + forwardH = LZ4_hashPosition(++ip, tableType); + + } + +_last_literals: + /* Encode Last Literals */ + { size_t lastRun = (size_t)(iend - anchor); + if ( (outputDirective) && /* Check output buffer overflow */ + (op + lastRun + 1 + ((lastRun+255-RUN_MASK)/255) > olimit)) { + if (outputDirective == fillOutput) { + /* adapt lastRun to fill 'dst' */ + assert(olimit >= op); + lastRun = (size_t)(olimit-op) - 1; + lastRun -= (lastRun+240)/255; + } else { + assert(outputDirective == limitedOutput); + return 0; /* cannot compress within `dst` budget. Stored indexes in hash table are nonetheless fine */ + } + } + if (lastRun >= RUN_MASK) { + size_t accumulator = lastRun - RUN_MASK; + *op++ = RUN_MASK << ML_BITS; + for(; accumulator >= 255 ; accumulator-=255) *op++ = 255; + *op++ = (BYTE) accumulator; + } else { + *op++ = (BYTE)(lastRun< 0); + return result; +} + + +int LZ4_compress_fast_extState(void* state, const char* source, char* dest, int inputSize, int maxOutputSize, int acceleration) +{ + LZ4_stream_t_internal* const ctx = & LZ4_initStream(state, sizeof(LZ4_stream_t)) -> internal_donotuse; + assert(ctx != NULL); + if (acceleration < 1) acceleration = ACCELERATION_DEFAULT; + if (maxOutputSize >= LZ4_compressBound(inputSize)) { + if (inputSize < LZ4_64Klimit) { + return LZ4_compress_generic(ctx, source, dest, inputSize, NULL, 0, notLimited, byU16, noDict, noDictIssue, acceleration); + } else { + const tableType_t tableType = ((sizeof(void*)==4) && ((uptrval)source > LZ4_DISTANCE_MAX)) ? byPtr : byU32; + return LZ4_compress_generic(ctx, source, dest, inputSize, NULL, 0, notLimited, tableType, noDict, noDictIssue, acceleration); + } + } else { + if (inputSize < LZ4_64Klimit) { + return LZ4_compress_generic(ctx, source, dest, inputSize, NULL, maxOutputSize, limitedOutput, byU16, noDict, noDictIssue, acceleration); + } else { + const tableType_t tableType = ((sizeof(void*)==4) && ((uptrval)source > LZ4_DISTANCE_MAX)) ? byPtr : byU32; + return LZ4_compress_generic(ctx, source, dest, inputSize, NULL, maxOutputSize, limitedOutput, tableType, noDict, noDictIssue, acceleration); + } + } +} + +/** + * LZ4_compress_fast_extState_fastReset() : + * A variant of LZ4_compress_fast_extState(). + * + * Using this variant avoids an expensive initialization step. It is only safe + * to call if the state buffer is known to be correctly initialized already + * (see comment in lz4.h on LZ4_resetStream_fast() for a definition of + * "correctly initialized"). + */ +int LZ4_compress_fast_extState_fastReset(void* state, const char* src, char* dst, int srcSize, int dstCapacity, int acceleration) +{ + LZ4_stream_t_internal* ctx = &((LZ4_stream_t*)state)->internal_donotuse; + if (acceleration < 1) acceleration = ACCELERATION_DEFAULT; + + if (dstCapacity >= LZ4_compressBound(srcSize)) { + if (srcSize < LZ4_64Klimit) { + const tableType_t tableType = byU16; + LZ4_prepareTable(ctx, srcSize, tableType); + if (ctx->currentOffset) { + return LZ4_compress_generic(ctx, src, dst, srcSize, NULL, 0, notLimited, tableType, noDict, dictSmall, acceleration); + } else { + return LZ4_compress_generic(ctx, src, dst, srcSize, NULL, 0, notLimited, tableType, noDict, noDictIssue, acceleration); + } + } else { + const tableType_t tableType = ((sizeof(void*)==4) && ((uptrval)src > LZ4_DISTANCE_MAX)) ? byPtr : byU32; + LZ4_prepareTable(ctx, srcSize, tableType); + return LZ4_compress_generic(ctx, src, dst, srcSize, NULL, 0, notLimited, tableType, noDict, noDictIssue, acceleration); + } + } else { + if (srcSize < LZ4_64Klimit) { + const tableType_t tableType = byU16; + LZ4_prepareTable(ctx, srcSize, tableType); + if (ctx->currentOffset) { + return LZ4_compress_generic(ctx, src, dst, srcSize, NULL, dstCapacity, limitedOutput, tableType, noDict, dictSmall, acceleration); + } else { + return LZ4_compress_generic(ctx, src, dst, srcSize, NULL, dstCapacity, limitedOutput, tableType, noDict, noDictIssue, acceleration); + } + } else { + const tableType_t tableType = ((sizeof(void*)==4) && ((uptrval)src > LZ4_DISTANCE_MAX)) ? byPtr : byU32; + LZ4_prepareTable(ctx, srcSize, tableType); + return LZ4_compress_generic(ctx, src, dst, srcSize, NULL, dstCapacity, limitedOutput, tableType, noDict, noDictIssue, acceleration); + } + } +} + + +int LZ4_compress_fast(const char* source, char* dest, int inputSize, int maxOutputSize, int acceleration) +{ + int result; +#if (LZ4_HEAPMODE) + LZ4_stream_t* ctxPtr = ALLOC(sizeof(LZ4_stream_t)); /* malloc-calloc always properly aligned */ + if (ctxPtr == NULL) return 0; +#else + LZ4_stream_t ctx; + LZ4_stream_t* const ctxPtr = &ctx; +#endif + result = LZ4_compress_fast_extState(ctxPtr, source, dest, inputSize, maxOutputSize, acceleration); + +#if (LZ4_HEAPMODE) + FREEMEM(ctxPtr); +#endif + return result; +} + + +int LZ4_compress_default(const char* src, char* dst, int srcSize, int maxOutputSize) +{ + return LZ4_compress_fast(src, dst, srcSize, maxOutputSize, 1); +} + + +/* hidden debug function */ +/* strangely enough, gcc generates faster code when this function is uncommented, even if unused */ +int LZ4_compress_fast_force(const char* src, char* dst, int srcSize, int dstCapacity, int acceleration) +{ + LZ4_stream_t ctx; + LZ4_initStream(&ctx, sizeof(ctx)); + + if (srcSize < LZ4_64Klimit) { + return LZ4_compress_generic(&ctx.internal_donotuse, src, dst, srcSize, NULL, dstCapacity, limitedOutput, byU16, noDict, noDictIssue, acceleration); + } else { + tableType_t const addrMode = (sizeof(void*) > 4) ? byU32 : byPtr; + return LZ4_compress_generic(&ctx.internal_donotuse, src, dst, srcSize, NULL, dstCapacity, limitedOutput, addrMode, noDict, noDictIssue, acceleration); + } +} + + +/* Note!: This function leaves the stream in an unclean/broken state! + * It is not safe to subsequently use the same state with a _fastReset() or + * _continue() call without resetting it. */ +static int LZ4_compress_destSize_extState (LZ4_stream_t* state, const char* src, char* dst, int* srcSizePtr, int targetDstSize) +{ + void* const s = LZ4_initStream(state, sizeof (*state)); + assert(s != NULL); (void)s; + + if (targetDstSize >= LZ4_compressBound(*srcSizePtr)) { /* compression success is guaranteed */ + return LZ4_compress_fast_extState(state, src, dst, *srcSizePtr, targetDstSize, 1); + } else { + if (*srcSizePtr < LZ4_64Klimit) { + return LZ4_compress_generic(&state->internal_donotuse, src, dst, *srcSizePtr, srcSizePtr, targetDstSize, fillOutput, byU16, noDict, noDictIssue, 1); + } else { + tableType_t const addrMode = ((sizeof(void*)==4) && ((uptrval)src > LZ4_DISTANCE_MAX)) ? byPtr : byU32; + return LZ4_compress_generic(&state->internal_donotuse, src, dst, *srcSizePtr, srcSizePtr, targetDstSize, fillOutput, addrMode, noDict, noDictIssue, 1); + } } +} + + +int LZ4_compress_destSize(const char* src, char* dst, int* srcSizePtr, int targetDstSize) +{ +#if (LZ4_HEAPMODE) + LZ4_stream_t* ctx = (LZ4_stream_t*)ALLOC(sizeof(LZ4_stream_t)); /* malloc-calloc always properly aligned */ + if (ctx == NULL) return 0; +#else + LZ4_stream_t ctxBody; + LZ4_stream_t* ctx = &ctxBody; +#endif + + int result = LZ4_compress_destSize_extState(ctx, src, dst, srcSizePtr, targetDstSize); + +#if (LZ4_HEAPMODE) + FREEMEM(ctx); +#endif + return result; +} + + + +/*-****************************** +* Streaming functions +********************************/ + +LZ4_stream_t* LZ4_createStream(void) +{ + LZ4_stream_t* const lz4s = (LZ4_stream_t*)ALLOC(sizeof(LZ4_stream_t)); + LZ4_STATIC_ASSERT(LZ4_STREAMSIZE >= sizeof(LZ4_stream_t_internal)); /* A compilation error here means LZ4_STREAMSIZE is not large enough */ + DEBUGLOG(4, "LZ4_createStream %p", lz4s); + if (lz4s == NULL) return NULL; + LZ4_initStream(lz4s, sizeof(*lz4s)); + return lz4s; +} + +#ifndef _MSC_VER /* for some reason, Visual fails the aligment test on 32-bit x86 : + it reports an aligment of 8-bytes, + while actually aligning LZ4_stream_t on 4 bytes. */ +static size_t LZ4_stream_t_alignment(void) +{ + struct { char c; LZ4_stream_t t; } t_a; + return sizeof(t_a) - sizeof(t_a.t); +} +#endif + +LZ4_stream_t* LZ4_initStream (void* buffer, size_t size) +{ + DEBUGLOG(5, "LZ4_initStream"); + if (buffer == NULL) { return NULL; } + if (size < sizeof(LZ4_stream_t)) { return NULL; } +#ifndef _MSC_VER /* for some reason, Visual fails the aligment test on 32-bit x86 : + it reports an aligment of 8-bytes, + while actually aligning LZ4_stream_t on 4 bytes. */ + if (((size_t)buffer) & (LZ4_stream_t_alignment() - 1)) { return NULL; } /* alignment check */ +#endif + MEM_INIT(buffer, 0, sizeof(LZ4_stream_t)); + return (LZ4_stream_t*)buffer; +} + +/* resetStream is now deprecated, + * prefer initStream() which is more general */ +void LZ4_resetStream (LZ4_stream_t* LZ4_stream) +{ + DEBUGLOG(5, "LZ4_resetStream (ctx:%p)", LZ4_stream); + MEM_INIT(LZ4_stream, 0, sizeof(LZ4_stream_t)); +} + +void LZ4_resetStream_fast(LZ4_stream_t* ctx) { + LZ4_prepareTable(&(ctx->internal_donotuse), 0, byU32); +} + +int LZ4_freeStream (LZ4_stream_t* LZ4_stream) +{ + if (!LZ4_stream) return 0; /* support free on NULL */ + DEBUGLOG(5, "LZ4_freeStream %p", LZ4_stream); + FREEMEM(LZ4_stream); + return (0); +} + + +#define HASH_UNIT sizeof(reg_t) +int LZ4_loadDict (LZ4_stream_t* LZ4_dict, const char* dictionary, int dictSize) +{ + LZ4_stream_t_internal* dict = &LZ4_dict->internal_donotuse; + const tableType_t tableType = byU32; + const BYTE* p = (const BYTE*)dictionary; + const BYTE* const dictEnd = p + dictSize; + const BYTE* base; + + DEBUGLOG(4, "LZ4_loadDict (%i bytes from %p into %p)", dictSize, dictionary, LZ4_dict); + + /* It's necessary to reset the context, + * and not just continue it with prepareTable() + * to avoid any risk of generating overflowing matchIndex + * when compressing using this dictionary */ + LZ4_resetStream(LZ4_dict); + + /* We always increment the offset by 64 KB, since, if the dict is longer, + * we truncate it to the last 64k, and if it's shorter, we still want to + * advance by a whole window length so we can provide the guarantee that + * there are only valid offsets in the window, which allows an optimization + * in LZ4_compress_fast_continue() where it uses noDictIssue even when the + * dictionary isn't a full 64k. */ + dict->currentOffset += 64 KB; + + if (dictSize < (int)HASH_UNIT) { + return 0; + } + + if ((dictEnd - p) > 64 KB) p = dictEnd - 64 KB; + base = dictEnd - dict->currentOffset; + dict->dictionary = p; + dict->dictSize = (U32)(dictEnd - p); + dict->tableType = tableType; + + while (p <= dictEnd-HASH_UNIT) { + LZ4_putPosition(p, dict->hashTable, tableType, base); + p+=3; + } + + return (int)dict->dictSize; +} + +void LZ4_attach_dictionary(LZ4_stream_t* workingStream, const LZ4_stream_t* dictionaryStream) { + const LZ4_stream_t_internal* dictCtx = dictionaryStream == NULL ? NULL : + &(dictionaryStream->internal_donotuse); + + DEBUGLOG(4, "LZ4_attach_dictionary (%p, %p, size %u)", + workingStream, dictionaryStream, + dictCtx != NULL ? dictCtx->dictSize : 0); + + /* Calling LZ4_resetStream_fast() here makes sure that changes will not be + * erased by subsequent calls to LZ4_resetStream_fast() in case stream was + * marked as having dirty context, e.g. requiring full reset. + */ + LZ4_resetStream_fast(workingStream); + + if (dictCtx != NULL) { + /* If the current offset is zero, we will never look in the + * external dictionary context, since there is no value a table + * entry can take that indicate a miss. In that case, we need + * to bump the offset to something non-zero. + */ + if (workingStream->internal_donotuse.currentOffset == 0) { + workingStream->internal_donotuse.currentOffset = 64 KB; + } + + /* Don't actually attach an empty dictionary. + */ + if (dictCtx->dictSize == 0) { + dictCtx = NULL; + } + } + workingStream->internal_donotuse.dictCtx = dictCtx; +} + + +static void LZ4_renormDictT(LZ4_stream_t_internal* LZ4_dict, int nextSize) +{ + assert(nextSize >= 0); + if (LZ4_dict->currentOffset + (unsigned)nextSize > 0x80000000) { /* potential ptrdiff_t overflow (32-bits mode) */ + /* rescale hash table */ + U32 const delta = LZ4_dict->currentOffset - 64 KB; + const BYTE* dictEnd = LZ4_dict->dictionary + LZ4_dict->dictSize; + int i; + DEBUGLOG(4, "LZ4_renormDictT"); + for (i=0; ihashTable[i] < delta) LZ4_dict->hashTable[i]=0; + else LZ4_dict->hashTable[i] -= delta; + } + LZ4_dict->currentOffset = 64 KB; + if (LZ4_dict->dictSize > 64 KB) LZ4_dict->dictSize = 64 KB; + LZ4_dict->dictionary = dictEnd - LZ4_dict->dictSize; + } +} + + +int LZ4_compress_fast_continue (LZ4_stream_t* LZ4_stream, + const char* source, char* dest, + int inputSize, int maxOutputSize, + int acceleration) +{ + const tableType_t tableType = byU32; + LZ4_stream_t_internal* streamPtr = &LZ4_stream->internal_donotuse; + const BYTE* dictEnd = streamPtr->dictionary + streamPtr->dictSize; + + DEBUGLOG(5, "LZ4_compress_fast_continue (inputSize=%i)", inputSize); + + if (streamPtr->dirty) { return 0; } /* Uninitialized structure detected */ + LZ4_renormDictT(streamPtr, inputSize); /* avoid index overflow */ + if (acceleration < 1) acceleration = ACCELERATION_DEFAULT; + + /* invalidate tiny dictionaries */ + if ( (streamPtr->dictSize-1 < 4-1) /* intentional underflow */ + && (dictEnd != (const BYTE*)source) ) { + DEBUGLOG(5, "LZ4_compress_fast_continue: dictSize(%u) at addr:%p is too small", streamPtr->dictSize, streamPtr->dictionary); + streamPtr->dictSize = 0; + streamPtr->dictionary = (const BYTE*)source; + dictEnd = (const BYTE*)source; + } + + /* Check overlapping input/dictionary space */ + { const BYTE* sourceEnd = (const BYTE*) source + inputSize; + if ((sourceEnd > streamPtr->dictionary) && (sourceEnd < dictEnd)) { + streamPtr->dictSize = (U32)(dictEnd - sourceEnd); + if (streamPtr->dictSize > 64 KB) streamPtr->dictSize = 64 KB; + if (streamPtr->dictSize < 4) streamPtr->dictSize = 0; + streamPtr->dictionary = dictEnd - streamPtr->dictSize; + } + } + + /* prefix mode : source data follows dictionary */ + if (dictEnd == (const BYTE*)source) { + if ((streamPtr->dictSize < 64 KB) && (streamPtr->dictSize < streamPtr->currentOffset)) + return LZ4_compress_generic(streamPtr, source, dest, inputSize, NULL, maxOutputSize, limitedOutput, tableType, withPrefix64k, dictSmall, acceleration); + else + return LZ4_compress_generic(streamPtr, source, dest, inputSize, NULL, maxOutputSize, limitedOutput, tableType, withPrefix64k, noDictIssue, acceleration); + } + + /* external dictionary mode */ + { int result; + if (streamPtr->dictCtx) { + /* We depend here on the fact that dictCtx'es (produced by + * LZ4_loadDict) guarantee that their tables contain no references + * to offsets between dictCtx->currentOffset - 64 KB and + * dictCtx->currentOffset - dictCtx->dictSize. This makes it safe + * to use noDictIssue even when the dict isn't a full 64 KB. + */ + if (inputSize > 4 KB) { + /* For compressing large blobs, it is faster to pay the setup + * cost to copy the dictionary's tables into the active context, + * so that the compression loop is only looking into one table. + */ + memcpy(streamPtr, streamPtr->dictCtx, sizeof(LZ4_stream_t)); + result = LZ4_compress_generic(streamPtr, source, dest, inputSize, NULL, maxOutputSize, limitedOutput, tableType, usingExtDict, noDictIssue, acceleration); + } else { + result = LZ4_compress_generic(streamPtr, source, dest, inputSize, NULL, maxOutputSize, limitedOutput, tableType, usingDictCtx, noDictIssue, acceleration); + } + } else { + if ((streamPtr->dictSize < 64 KB) && (streamPtr->dictSize < streamPtr->currentOffset)) { + result = LZ4_compress_generic(streamPtr, source, dest, inputSize, NULL, maxOutputSize, limitedOutput, tableType, usingExtDict, dictSmall, acceleration); + } else { + result = LZ4_compress_generic(streamPtr, source, dest, inputSize, NULL, maxOutputSize, limitedOutput, tableType, usingExtDict, noDictIssue, acceleration); + } + } + streamPtr->dictionary = (const BYTE*)source; + streamPtr->dictSize = (U32)inputSize; + return result; + } +} + + +/* Hidden debug function, to force-test external dictionary mode */ +int LZ4_compress_forceExtDict (LZ4_stream_t* LZ4_dict, const char* source, char* dest, int srcSize) +{ + LZ4_stream_t_internal* streamPtr = &LZ4_dict->internal_donotuse; + int result; + + LZ4_renormDictT(streamPtr, srcSize); + + if ((streamPtr->dictSize < 64 KB) && (streamPtr->dictSize < streamPtr->currentOffset)) { + result = LZ4_compress_generic(streamPtr, source, dest, srcSize, NULL, 0, notLimited, byU32, usingExtDict, dictSmall, 1); + } else { + result = LZ4_compress_generic(streamPtr, source, dest, srcSize, NULL, 0, notLimited, byU32, usingExtDict, noDictIssue, 1); + } + + streamPtr->dictionary = (const BYTE*)source; + streamPtr->dictSize = (U32)srcSize; + + return result; +} + + +/*! LZ4_saveDict() : + * If previously compressed data block is not guaranteed to remain available at its memory location, + * save it into a safer place (char* safeBuffer). + * Note : you don't need to call LZ4_loadDict() afterwards, + * dictionary is immediately usable, you can therefore call LZ4_compress_fast_continue(). + * Return : saved dictionary size in bytes (necessarily <= dictSize), or 0 if error. + */ +int LZ4_saveDict (LZ4_stream_t* LZ4_dict, char* safeBuffer, int dictSize) +{ + LZ4_stream_t_internal* const dict = &LZ4_dict->internal_donotuse; + const BYTE* const previousDictEnd = dict->dictionary + dict->dictSize; + + if ((U32)dictSize > 64 KB) { dictSize = 64 KB; } /* useless to define a dictionary > 64 KB */ + if ((U32)dictSize > dict->dictSize) { dictSize = (int)dict->dictSize; } + + memmove(safeBuffer, previousDictEnd - dictSize, dictSize); + + dict->dictionary = (const BYTE*)safeBuffer; + dict->dictSize = (U32)dictSize; + + return dictSize; +} + + + +/*-******************************* + * Decompression functions + ********************************/ + +typedef enum { endOnOutputSize = 0, endOnInputSize = 1 } endCondition_directive; +typedef enum { decode_full_block = 0, partial_decode = 1 } earlyEnd_directive; + +#undef MIN +#define MIN(a,b) ( (a) < (b) ? (a) : (b) ) + +/* Read the variable-length literal or match length. + * + * ip - pointer to use as input. + * lencheck - end ip. Return an error if ip advances >= lencheck. + * loop_check - check ip >= lencheck in body of loop. Returns loop_error if so. + * initial_check - check ip >= lencheck before start of loop. Returns initial_error if so. + * error (output) - error code. Should be set to 0 before call. + */ +typedef enum { loop_error = -2, initial_error = -1, ok = 0 } variable_length_error; +LZ4_FORCE_INLINE unsigned +read_variable_length(const BYTE**ip, const BYTE* lencheck, int loop_check, int initial_check, variable_length_error* error) +{ + U32 length = 0; + U32 s; + if (initial_check && unlikely((*ip) >= lencheck)) { /* overflow detection */ + *error = initial_error; + return length; + } + do { + s = **ip; + (*ip)++; + length += s; + if (loop_check && unlikely((*ip) >= lencheck)) { /* overflow detection */ + *error = loop_error; + return length; + } + } while (s==255); + + return length; +} + +/*! LZ4_decompress_generic() : + * This generic decompression function covers all use cases. + * It shall be instantiated several times, using different sets of directives. + * Note that it is important for performance that this function really get inlined, + * in order to remove useless branches during compilation optimization. + */ +LZ4_FORCE_INLINE int +LZ4_decompress_generic( + const char* const src, + char* const dst, + int srcSize, + int outputSize, /* If endOnInput==endOnInputSize, this value is `dstCapacity` */ + + endCondition_directive endOnInput, /* endOnOutputSize, endOnInputSize */ + earlyEnd_directive partialDecoding, /* full, partial */ + dict_directive dict, /* noDict, withPrefix64k, usingExtDict */ + const BYTE* const lowPrefix, /* always <= dst, == dst when no prefix */ + const BYTE* const dictStart, /* only if dict==usingExtDict */ + const size_t dictSize /* note : = 0 if noDict */ + ) +{ + if (src == NULL) { return -1; } + + { const BYTE* ip = (const BYTE*) src; + const BYTE* const iend = ip + srcSize; + + BYTE* op = (BYTE*) dst; + BYTE* const oend = op + outputSize; + BYTE* cpy; + + const BYTE* const dictEnd = (dictStart == NULL) ? NULL : dictStart + dictSize; + + const int safeDecode = (endOnInput==endOnInputSize); + const int checkOffset = ((safeDecode) && (dictSize < (int)(64 KB))); + + + /* Set up the "end" pointers for the shortcut. */ + const BYTE* const shortiend = iend - (endOnInput ? 14 : 8) /*maxLL*/ - 2 /*offset*/; + const BYTE* const shortoend = oend - (endOnInput ? 14 : 8) /*maxLL*/ - 18 /*maxML*/; + + const BYTE* match; + size_t offset; + unsigned token; + size_t length; + + + DEBUGLOG(5, "LZ4_decompress_generic (srcSize:%i, dstSize:%i)", srcSize, outputSize); + + /* Special cases */ + assert(lowPrefix <= op); + if ((endOnInput) && (unlikely(outputSize==0))) { + /* Empty output buffer */ + if (partialDecoding) return 0; + return ((srcSize==1) && (*ip==0)) ? 0 : -1; + } + if ((!endOnInput) && (unlikely(outputSize==0))) { return (*ip==0 ? 1 : -1); } + if ((endOnInput) && unlikely(srcSize==0)) { return -1; } + + /* Currently the fast loop shows a regression on qualcomm arm chips. */ +#if LZ4_FAST_DEC_LOOP + if ((oend - op) < FASTLOOP_SAFE_DISTANCE) { + DEBUGLOG(6, "skip fast decode loop"); + goto safe_decode; + } + + /* Fast loop : decode sequences as long as output < iend-FASTLOOP_SAFE_DISTANCE */ + while (1) { + /* Main fastloop assertion: We can always wildcopy FASTLOOP_SAFE_DISTANCE */ + assert(oend - op >= FASTLOOP_SAFE_DISTANCE); + if (endOnInput) { assert(ip < iend); } + token = *ip++; + length = token >> ML_BITS; /* literal length */ + + assert(!endOnInput || ip <= iend); /* ip < iend before the increment */ + + /* decode literal length */ + if (length == RUN_MASK) { + variable_length_error error = ok; + length += read_variable_length(&ip, iend-RUN_MASK, endOnInput, endOnInput, &error); + if (error == initial_error) { goto _output_error; } + if ((safeDecode) && unlikely((uptrval)(op)+length<(uptrval)(op))) { goto _output_error; } /* overflow detection */ + if ((safeDecode) && unlikely((uptrval)(ip)+length<(uptrval)(ip))) { goto _output_error; } /* overflow detection */ + + /* copy literals */ + cpy = op+length; + LZ4_STATIC_ASSERT(MFLIMIT >= WILDCOPYLENGTH); + if (endOnInput) { /* LZ4_decompress_safe() */ + if ((cpy>oend-32) || (ip+length>iend-32)) { goto safe_literal_copy; } + LZ4_wildCopy32(op, ip, cpy); + } else { /* LZ4_decompress_fast() */ + if (cpy>oend-8) { goto safe_literal_copy; } + LZ4_wildCopy8(op, ip, cpy); /* LZ4_decompress_fast() cannot copy more than 8 bytes at a time : + * it doesn't know input length, and only relies on end-of-block properties */ + } + ip += length; op = cpy; + } else { + cpy = op+length; + if (endOnInput) { /* LZ4_decompress_safe() */ + DEBUGLOG(7, "copy %u bytes in a 16-bytes stripe", (unsigned)length); + /* We don't need to check oend, since we check it once for each loop below */ + if (ip > iend-(16 + 1/*max lit + offset + nextToken*/)) { goto safe_literal_copy; } + /* Literals can only be 14, but hope compilers optimize if we copy by a register size */ + memcpy(op, ip, 16); + } else { /* LZ4_decompress_fast() */ + /* LZ4_decompress_fast() cannot copy more than 8 bytes at a time : + * it doesn't know input length, and relies on end-of-block properties */ + memcpy(op, ip, 8); + if (length > 8) { memcpy(op+8, ip+8, 8); } + } + ip += length; op = cpy; + } + + /* get offset */ + offset = LZ4_readLE16(ip); ip+=2; + match = op - offset; + assert(match <= op); + + /* get matchlength */ + length = token & ML_MASK; + + if (length == ML_MASK) { + variable_length_error error = ok; + if ((checkOffset) && (unlikely(match + dictSize < lowPrefix))) { goto _output_error; } /* Error : offset outside buffers */ + length += read_variable_length(&ip, iend - LASTLITERALS + 1, endOnInput, 0, &error); + if (error != ok) { goto _output_error; } + if ((safeDecode) && unlikely((uptrval)(op)+length<(uptrval)op)) { goto _output_error; } /* overflow detection */ + length += MINMATCH; + if (op + length >= oend - FASTLOOP_SAFE_DISTANCE) { + goto safe_match_copy; + } + } else { + length += MINMATCH; + if (op + length >= oend - FASTLOOP_SAFE_DISTANCE) { + goto safe_match_copy; + } + + /* Fastpath check: Avoids a branch in LZ4_wildCopy32 if true */ + if ((dict == withPrefix64k) || (match >= lowPrefix)) { + if (offset >= 8) { + assert(match >= lowPrefix); + assert(match <= op); + assert(op + 18 <= oend); + + memcpy(op, match, 8); + memcpy(op+8, match+8, 8); + memcpy(op+16, match+16, 2); + op += length; + continue; + } } } + + if ((checkOffset) && (unlikely(match + dictSize < lowPrefix))) { goto _output_error; } /* Error : offset outside buffers */ + /* match starting within external dictionary */ + if ((dict==usingExtDict) && (match < lowPrefix)) { + if (unlikely(op+length > oend-LASTLITERALS)) { + if (partialDecoding) { + length = MIN(length, (size_t)(oend-op)); /* reach end of buffer */ + } else { + goto _output_error; /* end-of-block condition violated */ + } } + + if (length <= (size_t)(lowPrefix-match)) { + /* match fits entirely within external dictionary : just copy */ + memmove(op, dictEnd - (lowPrefix-match), length); + op += length; + } else { + /* match stretches into both external dictionary and current block */ + size_t const copySize = (size_t)(lowPrefix - match); + size_t const restSize = length - copySize; + memcpy(op, dictEnd - copySize, copySize); + op += copySize; + if (restSize > (size_t)(op - lowPrefix)) { /* overlap copy */ + BYTE* const endOfMatch = op + restSize; + const BYTE* copyFrom = lowPrefix; + while (op < endOfMatch) { *op++ = *copyFrom++; } + } else { + memcpy(op, lowPrefix, restSize); + op += restSize; + } } + continue; + } + + /* copy match within block */ + cpy = op + length; + + assert((op <= oend) && (oend-op >= 32)); + if (unlikely(offset<16)) { + LZ4_memcpy_using_offset(op, match, cpy, offset); + } else { + LZ4_wildCopy32(op, match, cpy); + } + + op = cpy; /* wildcopy correction */ + } + safe_decode: +#endif + + /* Main Loop : decode remaining sequences where output < FASTLOOP_SAFE_DISTANCE */ + while (1) { + token = *ip++; + length = token >> ML_BITS; /* literal length */ + + assert(!endOnInput || ip <= iend); /* ip < iend before the increment */ + + /* A two-stage shortcut for the most common case: + * 1) If the literal length is 0..14, and there is enough space, + * enter the shortcut and copy 16 bytes on behalf of the literals + * (in the fast mode, only 8 bytes can be safely copied this way). + * 2) Further if the match length is 4..18, copy 18 bytes in a similar + * manner; but we ensure that there's enough space in the output for + * those 18 bytes earlier, upon entering the shortcut (in other words, + * there is a combined check for both stages). + */ + if ( (endOnInput ? length != RUN_MASK : length <= 8) + /* strictly "less than" on input, to re-enter the loop with at least one byte */ + && likely((endOnInput ? ip < shortiend : 1) & (op <= shortoend)) ) { + /* Copy the literals */ + memcpy(op, ip, endOnInput ? 16 : 8); + op += length; ip += length; + + /* The second stage: prepare for match copying, decode full info. + * If it doesn't work out, the info won't be wasted. */ + length = token & ML_MASK; /* match length */ + offset = LZ4_readLE16(ip); ip += 2; + match = op - offset; + assert(match <= op); /* check overflow */ + + /* Do not deal with overlapping matches. */ + if ( (length != ML_MASK) + && (offset >= 8) + && (dict==withPrefix64k || match >= lowPrefix) ) { + /* Copy the match. */ + memcpy(op + 0, match + 0, 8); + memcpy(op + 8, match + 8, 8); + memcpy(op +16, match +16, 2); + op += length + MINMATCH; + /* Both stages worked, load the next token. */ + continue; + } + + /* The second stage didn't work out, but the info is ready. + * Propel it right to the point of match copying. */ + goto _copy_match; + } + + /* decode literal length */ + if (length == RUN_MASK) { + variable_length_error error = ok; + length += read_variable_length(&ip, iend-RUN_MASK, endOnInput, endOnInput, &error); + if (error == initial_error) { goto _output_error; } + if ((safeDecode) && unlikely((uptrval)(op)+length<(uptrval)(op))) { goto _output_error; } /* overflow detection */ + if ((safeDecode) && unlikely((uptrval)(ip)+length<(uptrval)(ip))) { goto _output_error; } /* overflow detection */ + } + + /* copy literals */ + cpy = op+length; +#if LZ4_FAST_DEC_LOOP + safe_literal_copy: +#endif + LZ4_STATIC_ASSERT(MFLIMIT >= WILDCOPYLENGTH); + if ( ((endOnInput) && ((cpy>oend-MFLIMIT) || (ip+length>iend-(2+1+LASTLITERALS))) ) + || ((!endOnInput) && (cpy>oend-WILDCOPYLENGTH)) ) + { + /* We've either hit the input parsing restriction or the output parsing restriction. + * If we've hit the input parsing condition then this must be the last sequence. + * If we've hit the output parsing condition then we are either using partialDecoding + * or we've hit the output parsing condition. + */ + if (partialDecoding) { + /* Since we are partial decoding we may be in this block because of the output parsing + * restriction, which is not valid since the output buffer is allowed to be undersized. + */ + assert(endOnInput); + /* If we're in this block because of the input parsing condition, then we must be on the + * last sequence (or invalid), so we must check that we exactly consume the input. + */ + if ((ip+length>iend-(2+1+LASTLITERALS)) && (ip+length != iend)) { goto _output_error; } + assert(ip+length <= iend); + /* We are finishing in the middle of a literals segment. + * Break after the copy. + */ + if (cpy > oend) { + cpy = oend; + assert(op<=oend); + length = (size_t)(oend-op); + } + assert(ip+length <= iend); + } else { + /* We must be on the last sequence because of the parsing limitations so check + * that we exactly regenerate the original size (must be exact when !endOnInput). + */ + if ((!endOnInput) && (cpy != oend)) { goto _output_error; } + /* We must be on the last sequence (or invalid) because of the parsing limitations + * so check that we exactly consume the input and don't overrun the output buffer. + */ + if ((endOnInput) && ((ip+length != iend) || (cpy > oend))) { goto _output_error; } + } + memmove(op, ip, length); /* supports overlapping memory regions, which only matters for in-place decompression scenarios */ + ip += length; + op += length; + /* Necessarily EOF when !partialDecoding. When partialDecoding + * it is EOF if we've either filled the output buffer or hit + * the input parsing restriction. + */ + if (!partialDecoding || (cpy == oend) || (ip == iend)) { + break; + } + } else { + LZ4_wildCopy8(op, ip, cpy); /* may overwrite up to WILDCOPYLENGTH beyond cpy */ + ip += length; op = cpy; + } + + /* get offset */ + offset = LZ4_readLE16(ip); ip+=2; + match = op - offset; + + /* get matchlength */ + length = token & ML_MASK; + + _copy_match: + if (length == ML_MASK) { + variable_length_error error = ok; + length += read_variable_length(&ip, iend - LASTLITERALS + 1, endOnInput, 0, &error); + if (error != ok) goto _output_error; + if ((safeDecode) && unlikely((uptrval)(op)+length<(uptrval)op)) goto _output_error; /* overflow detection */ + } + length += MINMATCH; + +#if LZ4_FAST_DEC_LOOP + safe_match_copy: +#endif + if ((checkOffset) && (unlikely(match + dictSize < lowPrefix))) goto _output_error; /* Error : offset outside buffers */ + /* match starting within external dictionary */ + if ((dict==usingExtDict) && (match < lowPrefix)) { + if (unlikely(op+length > oend-LASTLITERALS)) { + if (partialDecoding) length = MIN(length, (size_t)(oend-op)); + else goto _output_error; /* doesn't respect parsing restriction */ + } + + if (length <= (size_t)(lowPrefix-match)) { + /* match fits entirely within external dictionary : just copy */ + memmove(op, dictEnd - (lowPrefix-match), length); + op += length; + } else { + /* match stretches into both external dictionary and current block */ + size_t const copySize = (size_t)(lowPrefix - match); + size_t const restSize = length - copySize; + memcpy(op, dictEnd - copySize, copySize); + op += copySize; + if (restSize > (size_t)(op - lowPrefix)) { /* overlap copy */ + BYTE* const endOfMatch = op + restSize; + const BYTE* copyFrom = lowPrefix; + while (op < endOfMatch) *op++ = *copyFrom++; + } else { + memcpy(op, lowPrefix, restSize); + op += restSize; + } } + continue; + } + assert(match >= lowPrefix); + + /* copy match within block */ + cpy = op + length; + + /* partialDecoding : may end anywhere within the block */ + assert(op<=oend); + if (partialDecoding && (cpy > oend-MATCH_SAFEGUARD_DISTANCE)) { + size_t const mlen = MIN(length, (size_t)(oend-op)); + const BYTE* const matchEnd = match + mlen; + BYTE* const copyEnd = op + mlen; + if (matchEnd > op) { /* overlap copy */ + while (op < copyEnd) { *op++ = *match++; } + } else { + memcpy(op, match, mlen); + } + op = copyEnd; + if (op == oend) { break; } + continue; + } + + if (unlikely(offset<8)) { + LZ4_write32(op, 0); /* silence msan warning when offset==0 */ + op[0] = match[0]; + op[1] = match[1]; + op[2] = match[2]; + op[3] = match[3]; + match += inc32table[offset]; + memcpy(op+4, match, 4); + match -= dec64table[offset]; + } else { + memcpy(op, match, 8); + match += 8; + } + op += 8; + + if (unlikely(cpy > oend-MATCH_SAFEGUARD_DISTANCE)) { + BYTE* const oCopyLimit = oend - (WILDCOPYLENGTH-1); + if (cpy > oend-LASTLITERALS) { goto _output_error; } /* Error : last LASTLITERALS bytes must be literals (uncompressed) */ + if (op < oCopyLimit) { + LZ4_wildCopy8(op, match, oCopyLimit); + match += oCopyLimit - op; + op = oCopyLimit; + } + while (op < cpy) { *op++ = *match++; } + } else { + memcpy(op, match, 8); + if (length > 16) { LZ4_wildCopy8(op+8, match+8, cpy); } + } + op = cpy; /* wildcopy correction */ + } + + /* end of decoding */ + if (endOnInput) { + return (int) (((char*)op)-dst); /* Nb of output bytes decoded */ + } else { + return (int) (((const char*)ip)-src); /* Nb of input bytes read */ + } + + /* Overflow error detected */ + _output_error: + return (int) (-(((const char*)ip)-src))-1; + } +} + + +/*===== Instantiate the API decoding functions. =====*/ + +LZ4_FORCE_O2_GCC_PPC64LE +int LZ4_decompress_safe(const char* source, char* dest, int compressedSize, int maxDecompressedSize) +{ + return LZ4_decompress_generic(source, dest, compressedSize, maxDecompressedSize, + endOnInputSize, decode_full_block, noDict, + (BYTE*)dest, NULL, 0); +} + +LZ4_FORCE_O2_GCC_PPC64LE +int LZ4_decompress_safe_partial(const char* src, char* dst, int compressedSize, int targetOutputSize, int dstCapacity) +{ + dstCapacity = MIN(targetOutputSize, dstCapacity); + return LZ4_decompress_generic(src, dst, compressedSize, dstCapacity, + endOnInputSize, partial_decode, + noDict, (BYTE*)dst, NULL, 0); +} + +LZ4_FORCE_O2_GCC_PPC64LE +int LZ4_decompress_fast(const char* source, char* dest, int originalSize) +{ + return LZ4_decompress_generic(source, dest, 0, originalSize, + endOnOutputSize, decode_full_block, withPrefix64k, + (BYTE*)dest - 64 KB, NULL, 0); +} + +/*===== Instantiate a few more decoding cases, used more than once. =====*/ + +LZ4_FORCE_O2_GCC_PPC64LE /* Exported, an obsolete API function. */ +int LZ4_decompress_safe_withPrefix64k(const char* source, char* dest, int compressedSize, int maxOutputSize) +{ + return LZ4_decompress_generic(source, dest, compressedSize, maxOutputSize, + endOnInputSize, decode_full_block, withPrefix64k, + (BYTE*)dest - 64 KB, NULL, 0); +} + +/* Another obsolete API function, paired with the previous one. */ +int LZ4_decompress_fast_withPrefix64k(const char* source, char* dest, int originalSize) +{ + /* LZ4_decompress_fast doesn't validate match offsets, + * and thus serves well with any prefixed dictionary. */ + return LZ4_decompress_fast(source, dest, originalSize); +} + +LZ4_FORCE_O2_GCC_PPC64LE +static int LZ4_decompress_safe_withSmallPrefix(const char* source, char* dest, int compressedSize, int maxOutputSize, + size_t prefixSize) +{ + return LZ4_decompress_generic(source, dest, compressedSize, maxOutputSize, + endOnInputSize, decode_full_block, noDict, + (BYTE*)dest-prefixSize, NULL, 0); +} + +LZ4_FORCE_O2_GCC_PPC64LE +int LZ4_decompress_safe_forceExtDict(const char* source, char* dest, + int compressedSize, int maxOutputSize, + const void* dictStart, size_t dictSize) +{ + return LZ4_decompress_generic(source, dest, compressedSize, maxOutputSize, + endOnInputSize, decode_full_block, usingExtDict, + (BYTE*)dest, (const BYTE*)dictStart, dictSize); +} + +LZ4_FORCE_O2_GCC_PPC64LE +static int LZ4_decompress_fast_extDict(const char* source, char* dest, int originalSize, + const void* dictStart, size_t dictSize) +{ + return LZ4_decompress_generic(source, dest, 0, originalSize, + endOnOutputSize, decode_full_block, usingExtDict, + (BYTE*)dest, (const BYTE*)dictStart, dictSize); +} + +/* The "double dictionary" mode, for use with e.g. ring buffers: the first part + * of the dictionary is passed as prefix, and the second via dictStart + dictSize. + * These routines are used only once, in LZ4_decompress_*_continue(). + */ +LZ4_FORCE_INLINE +int LZ4_decompress_safe_doubleDict(const char* source, char* dest, int compressedSize, int maxOutputSize, + size_t prefixSize, const void* dictStart, size_t dictSize) +{ + return LZ4_decompress_generic(source, dest, compressedSize, maxOutputSize, + endOnInputSize, decode_full_block, usingExtDict, + (BYTE*)dest-prefixSize, (const BYTE*)dictStart, dictSize); +} + +LZ4_FORCE_INLINE +int LZ4_decompress_fast_doubleDict(const char* source, char* dest, int originalSize, + size_t prefixSize, const void* dictStart, size_t dictSize) +{ + return LZ4_decompress_generic(source, dest, 0, originalSize, + endOnOutputSize, decode_full_block, usingExtDict, + (BYTE*)dest-prefixSize, (const BYTE*)dictStart, dictSize); +} + +/*===== streaming decompression functions =====*/ + +LZ4_streamDecode_t* LZ4_createStreamDecode(void) +{ + LZ4_streamDecode_t* lz4s = (LZ4_streamDecode_t*) ALLOC_AND_ZERO(sizeof(LZ4_streamDecode_t)); + LZ4_STATIC_ASSERT(LZ4_STREAMDECODESIZE >= sizeof(LZ4_streamDecode_t_internal)); /* A compilation error here means LZ4_STREAMDECODESIZE is not large enough */ + return lz4s; +} + +int LZ4_freeStreamDecode (LZ4_streamDecode_t* LZ4_stream) +{ + if (LZ4_stream == NULL) { return 0; } /* support free on NULL */ + FREEMEM(LZ4_stream); + return 0; +} + +/*! LZ4_setStreamDecode() : + * Use this function to instruct where to find the dictionary. + * This function is not necessary if previous data is still available where it was decoded. + * Loading a size of 0 is allowed (same effect as no dictionary). + * @return : 1 if OK, 0 if error + */ +int LZ4_setStreamDecode (LZ4_streamDecode_t* LZ4_streamDecode, const char* dictionary, int dictSize) +{ + LZ4_streamDecode_t_internal* lz4sd = &LZ4_streamDecode->internal_donotuse; + lz4sd->prefixSize = (size_t) dictSize; + lz4sd->prefixEnd = (const BYTE*) dictionary + dictSize; + lz4sd->externalDict = NULL; + lz4sd->extDictSize = 0; + return 1; +} + +/*! LZ4_decoderRingBufferSize() : + * when setting a ring buffer for streaming decompression (optional scenario), + * provides the minimum size of this ring buffer + * to be compatible with any source respecting maxBlockSize condition. + * Note : in a ring buffer scenario, + * blocks are presumed decompressed next to each other. + * When not enough space remains for next block (remainingSize < maxBlockSize), + * decoding resumes from beginning of ring buffer. + * @return : minimum ring buffer size, + * or 0 if there is an error (invalid maxBlockSize). + */ +int LZ4_decoderRingBufferSize(int maxBlockSize) +{ + if (maxBlockSize < 0) return 0; + if (maxBlockSize > LZ4_MAX_INPUT_SIZE) return 0; + if (maxBlockSize < 16) maxBlockSize = 16; + return LZ4_DECODER_RING_BUFFER_SIZE(maxBlockSize); +} + +/* +*_continue() : + These decoding functions allow decompression of multiple blocks in "streaming" mode. + Previously decoded blocks must still be available at the memory position where they were decoded. + If it's not possible, save the relevant part of decoded data into a safe buffer, + and indicate where it stands using LZ4_setStreamDecode() +*/ +LZ4_FORCE_O2_GCC_PPC64LE +int LZ4_decompress_safe_continue (LZ4_streamDecode_t* LZ4_streamDecode, const char* source, char* dest, int compressedSize, int maxOutputSize) +{ + LZ4_streamDecode_t_internal* lz4sd = &LZ4_streamDecode->internal_donotuse; + int result; + + if (lz4sd->prefixSize == 0) { + /* The first call, no dictionary yet. */ + assert(lz4sd->extDictSize == 0); + result = LZ4_decompress_safe(source, dest, compressedSize, maxOutputSize); + if (result <= 0) return result; + lz4sd->prefixSize = (size_t)result; + lz4sd->prefixEnd = (BYTE*)dest + result; + } else if (lz4sd->prefixEnd == (BYTE*)dest) { + /* They're rolling the current segment. */ + if (lz4sd->prefixSize >= 64 KB - 1) + result = LZ4_decompress_safe_withPrefix64k(source, dest, compressedSize, maxOutputSize); + else if (lz4sd->extDictSize == 0) + result = LZ4_decompress_safe_withSmallPrefix(source, dest, compressedSize, maxOutputSize, + lz4sd->prefixSize); + else + result = LZ4_decompress_safe_doubleDict(source, dest, compressedSize, maxOutputSize, + lz4sd->prefixSize, lz4sd->externalDict, lz4sd->extDictSize); + if (result <= 0) return result; + lz4sd->prefixSize += (size_t)result; + lz4sd->prefixEnd += result; + } else { + /* The buffer wraps around, or they're switching to another buffer. */ + lz4sd->extDictSize = lz4sd->prefixSize; + lz4sd->externalDict = lz4sd->prefixEnd - lz4sd->extDictSize; + result = LZ4_decompress_safe_forceExtDict(source, dest, compressedSize, maxOutputSize, + lz4sd->externalDict, lz4sd->extDictSize); + if (result <= 0) return result; + lz4sd->prefixSize = (size_t)result; + lz4sd->prefixEnd = (BYTE*)dest + result; + } + + return result; +} + +LZ4_FORCE_O2_GCC_PPC64LE +int LZ4_decompress_fast_continue (LZ4_streamDecode_t* LZ4_streamDecode, const char* source, char* dest, int originalSize) +{ + LZ4_streamDecode_t_internal* lz4sd = &LZ4_streamDecode->internal_donotuse; + int result; + assert(originalSize >= 0); + + if (lz4sd->prefixSize == 0) { + assert(lz4sd->extDictSize == 0); + result = LZ4_decompress_fast(source, dest, originalSize); + if (result <= 0) return result; + lz4sd->prefixSize = (size_t)originalSize; + lz4sd->prefixEnd = (BYTE*)dest + originalSize; + } else if (lz4sd->prefixEnd == (BYTE*)dest) { + if (lz4sd->prefixSize >= 64 KB - 1 || lz4sd->extDictSize == 0) + result = LZ4_decompress_fast(source, dest, originalSize); + else + result = LZ4_decompress_fast_doubleDict(source, dest, originalSize, + lz4sd->prefixSize, lz4sd->externalDict, lz4sd->extDictSize); + if (result <= 0) return result; + lz4sd->prefixSize += (size_t)originalSize; + lz4sd->prefixEnd += originalSize; + } else { + lz4sd->extDictSize = lz4sd->prefixSize; + lz4sd->externalDict = lz4sd->prefixEnd - lz4sd->extDictSize; + result = LZ4_decompress_fast_extDict(source, dest, originalSize, + lz4sd->externalDict, lz4sd->extDictSize); + if (result <= 0) return result; + lz4sd->prefixSize = (size_t)originalSize; + lz4sd->prefixEnd = (BYTE*)dest + originalSize; + } + + return result; +} + + +/* +Advanced decoding functions : +*_usingDict() : + These decoding functions work the same as "_continue" ones, + the dictionary must be explicitly provided within parameters +*/ + +int LZ4_decompress_safe_usingDict(const char* source, char* dest, int compressedSize, int maxOutputSize, const char* dictStart, int dictSize) +{ + if (dictSize==0) + return LZ4_decompress_safe(source, dest, compressedSize, maxOutputSize); + if (dictStart+dictSize == dest) { + if (dictSize >= 64 KB - 1) { + return LZ4_decompress_safe_withPrefix64k(source, dest, compressedSize, maxOutputSize); + } + assert(dictSize >= 0); + return LZ4_decompress_safe_withSmallPrefix(source, dest, compressedSize, maxOutputSize, (size_t)dictSize); + } + assert(dictSize >= 0); + return LZ4_decompress_safe_forceExtDict(source, dest, compressedSize, maxOutputSize, dictStart, (size_t)dictSize); +} + +int LZ4_decompress_fast_usingDict(const char* source, char* dest, int originalSize, const char* dictStart, int dictSize) +{ + if (dictSize==0 || dictStart+dictSize == dest) + return LZ4_decompress_fast(source, dest, originalSize); + assert(dictSize >= 0); + return LZ4_decompress_fast_extDict(source, dest, originalSize, dictStart, (size_t)dictSize); +} + + +/*=************************************************* +* Obsolete Functions +***************************************************/ +/* obsolete compression functions */ +int LZ4_compress_limitedOutput(const char* source, char* dest, int inputSize, int maxOutputSize) +{ + return LZ4_compress_default(source, dest, inputSize, maxOutputSize); +} +int LZ4_compress(const char* src, char* dest, int srcSize) +{ + return LZ4_compress_default(src, dest, srcSize, LZ4_compressBound(srcSize)); +} +int LZ4_compress_limitedOutput_withState (void* state, const char* src, char* dst, int srcSize, int dstSize) +{ + return LZ4_compress_fast_extState(state, src, dst, srcSize, dstSize, 1); +} +int LZ4_compress_withState (void* state, const char* src, char* dst, int srcSize) +{ + return LZ4_compress_fast_extState(state, src, dst, srcSize, LZ4_compressBound(srcSize), 1); +} +int LZ4_compress_limitedOutput_continue (LZ4_stream_t* LZ4_stream, const char* src, char* dst, int srcSize, int dstCapacity) +{ + return LZ4_compress_fast_continue(LZ4_stream, src, dst, srcSize, dstCapacity, 1); +} +int LZ4_compress_continue (LZ4_stream_t* LZ4_stream, const char* source, char* dest, int inputSize) +{ + return LZ4_compress_fast_continue(LZ4_stream, source, dest, inputSize, LZ4_compressBound(inputSize), 1); +} + +/* +These decompression functions are deprecated and should no longer be used. +They are only provided here for compatibility with older user programs. +- LZ4_uncompress is totally equivalent to LZ4_decompress_fast +- LZ4_uncompress_unknownOutputSize is totally equivalent to LZ4_decompress_safe +*/ +int LZ4_uncompress (const char* source, char* dest, int outputSize) +{ + return LZ4_decompress_fast(source, dest, outputSize); +} +int LZ4_uncompress_unknownOutputSize (const char* source, char* dest, int isize, int maxOutputSize) +{ + return LZ4_decompress_safe(source, dest, isize, maxOutputSize); +} + +/* Obsolete Streaming functions */ + +int LZ4_sizeofStreamState() { return LZ4_STREAMSIZE; } + +int LZ4_resetStreamState(void* state, char* inputBuffer) +{ + (void)inputBuffer; + LZ4_resetStream((LZ4_stream_t*)state); + return 0; +} + +void* LZ4_create (char* inputBuffer) +{ + (void)inputBuffer; + return LZ4_createStream(); +} + +char* LZ4_slideInputBuffer (void* state) +{ + /* avoid const char * -> char * conversion warning */ + return (char *)(uptrval)((LZ4_stream_t*)state)->internal_donotuse.dictionary; +} + +#endif /* LZ4_COMMONDEFS_ONLY */ diff --git a/src/inference/dataset/lz4.h b/src/inference/dataset/lz4.h new file mode 100644 index 0000000000000000000000000000000000000000..1e06bc65f13abdf6e7fdd097663aa52b2584bcb1 --- /dev/null +++ b/src/inference/dataset/lz4.h @@ -0,0 +1,764 @@ +/* + * LZ4 - Fast LZ compression algorithm + * Header File + * Copyright (C) 2011-present, Yann Collet. + + BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php) + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following disclaimer + in the documentation and/or other materials provided with the + distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + You can contact the author at : + - LZ4 homepage : http://www.lz4.org + - LZ4 source repository : https://github.com/lz4/lz4 +*/ +#if defined (__cplusplus) +extern "C" { +#endif + +#ifndef LZ4_H_2983827168210 +#define LZ4_H_2983827168210 + +/* --- Dependency --- */ +#include /* size_t */ + + +/** + Introduction + + LZ4 is lossless compression algorithm, providing compression speed >500 MB/s per core, + scalable with multi-cores CPU. It features an extremely fast decoder, with speed in + multiple GB/s per core, typically reaching RAM speed limits on multi-core systems. + + The LZ4 compression library provides in-memory compression and decompression functions. + It gives full buffer control to user. + Compression can be done in: + - a single step (described as Simple Functions) + - a single step, reusing a context (described in Advanced Functions) + - unbounded multiple steps (described as Streaming compression) + + lz4.h generates and decodes LZ4-compressed blocks (doc/lz4_Block_format.md). + Decompressing such a compressed block requires additional metadata. + Exact metadata depends on exact decompression function. + For the typical case of LZ4_decompress_safe(), + metadata includes block's compressed size, and maximum bound of decompressed size. + Each application is free to encode and pass such metadata in whichever way it wants. + + lz4.h only handle blocks, it can not generate Frames. + + Blocks are different from Frames (doc/lz4_Frame_format.md). + Frames bundle both blocks and metadata in a specified manner. + Embedding metadata is required for compressed data to be self-contained and portable. + Frame format is delivered through a companion API, declared in lz4frame.h. + The `lz4` CLI can only manage frames. +*/ + +/*^*************************************************************** +* Export parameters +*****************************************************************/ +/* +* LZ4_DLL_EXPORT : +* Enable exporting of functions when building a Windows DLL +* LZ4LIB_VISIBILITY : +* Control library symbols visibility. +*/ +#ifndef LZ4LIB_VISIBILITY +# if defined(__GNUC__) && (__GNUC__ >= 4) +# define LZ4LIB_VISIBILITY __attribute__ ((visibility ("default"))) +# else +# define LZ4LIB_VISIBILITY +# endif +#endif +#if defined(LZ4_DLL_EXPORT) && (LZ4_DLL_EXPORT==1) +# define LZ4LIB_API __declspec(dllexport) LZ4LIB_VISIBILITY +#elif defined(LZ4_DLL_IMPORT) && (LZ4_DLL_IMPORT==1) +# define LZ4LIB_API __declspec(dllimport) LZ4LIB_VISIBILITY /* It isn't required but allows to generate better code, saving a function pointer load from the IAT and an indirect jump.*/ +#else +# define LZ4LIB_API LZ4LIB_VISIBILITY +#endif + +/*------ Version ------*/ +#define LZ4_VERSION_MAJOR 1 /* for breaking interface changes */ +#define LZ4_VERSION_MINOR 9 /* for new (non-breaking) interface capabilities */ +#define LZ4_VERSION_RELEASE 2 /* for tweaks, bug-fixes, or development */ + +#define LZ4_VERSION_NUMBER (LZ4_VERSION_MAJOR *100*100 + LZ4_VERSION_MINOR *100 + LZ4_VERSION_RELEASE) + +#define LZ4_LIB_VERSION LZ4_VERSION_MAJOR.LZ4_VERSION_MINOR.LZ4_VERSION_RELEASE +#define LZ4_QUOTE(str) #str +#define LZ4_EXPAND_AND_QUOTE(str) LZ4_QUOTE(str) +#define LZ4_VERSION_STRING LZ4_EXPAND_AND_QUOTE(LZ4_LIB_VERSION) + +LZ4LIB_API int LZ4_versionNumber (void); /**< library version number; useful to check dll version */ +LZ4LIB_API const char* LZ4_versionString (void); /**< library version std::string; useful to check dll version */ + + +/*-************************************ +* Tuning parameter +**************************************/ +/*! + * LZ4_MEMORY_USAGE : + * Memory usage formula : N->2^N Bytes (examples : 10 -> 1KB; 12 -> 4KB ; 16 -> 64KB; 20 -> 1MB; etc.) + * Increasing memory usage improves compression ratio. + * Reduced memory usage may improve speed, thanks to better cache locality. + * Default value is 14, for 16KB, which nicely fits into Intel x86 L1 cache + */ +#ifndef LZ4_MEMORY_USAGE +# define LZ4_MEMORY_USAGE 14 +#endif + + +/*-************************************ +* Simple Functions +**************************************/ +/*! LZ4_compress_default() : + * Compresses 'srcSize' bytes from buffer 'src' + * into already allocated 'dst' buffer of size 'dstCapacity'. + * Compression is guaranteed to succeed if 'dstCapacity' >= LZ4_compressBound(srcSize). + * It also runs faster, so it's a recommended setting. + * If the function cannot compress 'src' into a more limited 'dst' budget, + * compression stops *immediately*, and the function result is zero. + * In which case, 'dst' content is undefined (invalid). + * srcSize : max supported value is LZ4_MAX_INPUT_SIZE. + * dstCapacity : size of buffer 'dst' (which must be already allocated) + * @return : the number of bytes written into buffer 'dst' (necessarily <= dstCapacity) + * or 0 if compression fails + * Note : This function is protected against buffer overflow scenarios (never writes outside 'dst' buffer, nor read outside 'source' buffer). + */ +LZ4LIB_API int LZ4_compress_default(const char* src, char* dst, int srcSize, int dstCapacity); + +/*! LZ4_decompress_safe() : + * compressedSize : is the exact complete size of the compressed block. + * dstCapacity : is the size of destination buffer (which must be already allocated), presumed an upper bound of decompressed size. + * @return : the number of bytes decompressed into destination buffer (necessarily <= dstCapacity) + * If destination buffer is not large enough, decoding will stop and output an error code (negative value). + * If the source stream is detected malformed, the function will stop decoding and return a negative result. + * Note 1 : This function is protected against malicious data packets : + * it will never writes outside 'dst' buffer, nor read outside 'source' buffer, + * even if the compressed block is maliciously modified to order the decoder to do these actions. + * In such case, the decoder stops immediately, and considers the compressed block malformed. + * Note 2 : compressedSize and dstCapacity must be provided to the function, the compressed block does not contain them. + * The implementation is free to send / store / derive this information in whichever way is most beneficial. + * If there is a need for a different format which bundles together both compressed data and its metadata, consider looking at lz4frame.h instead. + */ +LZ4LIB_API int LZ4_decompress_safe (const char* src, char* dst, int compressedSize, int dstCapacity); + + +/*-************************************ +* Advanced Functions +**************************************/ +#define LZ4_MAX_INPUT_SIZE 0x7E000000 /* 2 113 929 216 bytes */ +#define LZ4_COMPRESSBOUND(isize) ((unsigned)(isize) > (unsigned)LZ4_MAX_INPUT_SIZE ? 0 : (isize) + ((isize)/255) + 16) + +/*! LZ4_compressBound() : + Provides the maximum size that LZ4 compression may output in a "worst case" scenario (input data not compressible) + This function is primarily useful for memory allocation purposes (destination buffer size). + Macro LZ4_COMPRESSBOUND() is also provided for compilation-time evaluation (stack memory allocation for example). + Note that LZ4_compress_default() compresses faster when dstCapacity is >= LZ4_compressBound(srcSize) + inputSize : max supported value is LZ4_MAX_INPUT_SIZE + return : maximum output size in a "worst case" scenario + or 0, if input size is incorrect (too large or negative) +*/ +LZ4LIB_API int LZ4_compressBound(int inputSize); + +/*! LZ4_compress_fast() : + Same as LZ4_compress_default(), but allows selection of "acceleration" factor. + The larger the acceleration value, the faster the algorithm, but also the lesser the compression. + It's a trade-off. It can be fine tuned, with each successive value providing roughly +~3% to speed. + An acceleration value of "1" is the same as regular LZ4_compress_default() + Values <= 0 will be replaced by ACCELERATION_DEFAULT (currently == 1, see lz4.c). +*/ +LZ4LIB_API int LZ4_compress_fast (const char* src, char* dst, int srcSize, int dstCapacity, int acceleration); + + +/*! LZ4_compress_fast_extState() : + * Same as LZ4_compress_fast(), using an externally allocated memory space for its state. + * Use LZ4_sizeofState() to know how much memory must be allocated, + * and allocate it on 8-bytes boundaries (using `malloc()` typically). + * Then, provide this buffer as `void* state` to compression function. + */ +LZ4LIB_API int LZ4_sizeofState(void); +LZ4LIB_API int LZ4_compress_fast_extState (void* state, const char* src, char* dst, int srcSize, int dstCapacity, int acceleration); + + +/*! LZ4_compress_destSize() : + * Reverse the logic : compresses as much data as possible from 'src' buffer + * into already allocated buffer 'dst', of size >= 'targetDestSize'. + * This function either compresses the entire 'src' content into 'dst' if it's large enough, + * or fill 'dst' buffer completely with as much data as possible from 'src'. + * note: acceleration parameter is fixed to "default". + * + * *srcSizePtr : will be modified to indicate how many bytes where read from 'src' to fill 'dst'. + * New value is necessarily <= input value. + * @return : Nb bytes written into 'dst' (necessarily <= targetDestSize) + * or 0 if compression fails. +*/ +LZ4LIB_API int LZ4_compress_destSize (const char* src, char* dst, int* srcSizePtr, int targetDstSize); + + +/*! LZ4_decompress_safe_partial() : + * Decompress an LZ4 compressed block, of size 'srcSize' at position 'src', + * into destination buffer 'dst' of size 'dstCapacity'. + * Up to 'targetOutputSize' bytes will be decoded. + * The function stops decoding on reaching this objective, + * which can boost performance when only the beginning of a block is required. + * + * @return : the number of bytes decoded in `dst` (necessarily <= dstCapacity) + * If source stream is detected malformed, function returns a negative result. + * + * Note : @return can be < targetOutputSize, if compressed block contains less data. + * + * Note 2 : this function features 2 parameters, targetOutputSize and dstCapacity, + * and expects targetOutputSize <= dstCapacity. + * It effectively stops decoding on reaching targetOutputSize, + * so dstCapacity is kind of redundant. + * This is because in a previous version of this function, + * decoding operation would not "break" a sequence in the middle. + * As a consequence, there was no guarantee that decoding would stop at exactly targetOutputSize, + * it could write more bytes, though only up to dstCapacity. + * Some "margin" used to be required for this operation to work properly. + * This is no longer necessary. + * The function nonetheless keeps its signature, in an effort to not break API. + */ +LZ4LIB_API int LZ4_decompress_safe_partial (const char* src, char* dst, int srcSize, int targetOutputSize, int dstCapacity); + + +/*-********************************************* +* Streaming Compression Functions +***********************************************/ +typedef union LZ4_stream_u LZ4_stream_t; /* incomplete type (defined later) */ + +LZ4LIB_API LZ4_stream_t* LZ4_createStream(void); +LZ4LIB_API int LZ4_freeStream (LZ4_stream_t* streamPtr); + +/*! LZ4_resetStream_fast() : v1.9.0+ + * Use this to prepare an LZ4_stream_t for a new chain of dependent blocks + * (e.g., LZ4_compress_fast_continue()). + * + * An LZ4_stream_t must be initialized once before usage. + * This is automatically done when created by LZ4_createStream(). + * However, should the LZ4_stream_t be simply declared on stack (for example), + * it's necessary to initialize it first, using LZ4_initStream(). + * + * After init, start any new stream with LZ4_resetStream_fast(). + * A same LZ4_stream_t can be re-used multiple times consecutively + * and compress multiple streams, + * provided that it starts each new stream with LZ4_resetStream_fast(). + * + * LZ4_resetStream_fast() is much faster than LZ4_initStream(), + * but is not compatible with memory regions containing garbage data. + * + * Note: it's only useful to call LZ4_resetStream_fast() + * in the context of streaming compression. + * The *extState* functions perform their own resets. + * Invoking LZ4_resetStream_fast() before is redundant, and even counterproductive. + */ +LZ4LIB_API void LZ4_resetStream_fast (LZ4_stream_t* streamPtr); + +/*! LZ4_loadDict() : + * Use this function to reference a static dictionary into LZ4_stream_t. + * The dictionary must remain available during compression. + * LZ4_loadDict() triggers a reset, so any previous data will be forgotten. + * The same dictionary will have to be loaded on decompression side for successful decoding. + * Dictionary are useful for better compression of small data (KB range). + * While LZ4 accept any input as dictionary, + * results are generally better when using Zstandard's Dictionary Builder. + * Loading a size of 0 is allowed, and is the same as reset. + * @return : loaded dictionary size, in bytes (necessarily <= 64 KB) + */ +LZ4LIB_API int LZ4_loadDict (LZ4_stream_t* streamPtr, const char* dictionary, int dictSize); + +/*! LZ4_compress_fast_continue() : + * Compress 'src' content using data from previously compressed blocks, for better compression ratio. + * 'dst' buffer must be already allocated. + * If dstCapacity >= LZ4_compressBound(srcSize), compression is guaranteed to succeed, and runs faster. + * + * @return : size of compressed block + * or 0 if there is an error (typically, cannot fit into 'dst'). + * + * Note 1 : Each invocation to LZ4_compress_fast_continue() generates a new block. + * Each block has precise boundaries. + * Each block must be decompressed separately, calling LZ4_decompress_*() with relevant metadata. + * It's not possible to append blocks together and expect a single invocation of LZ4_decompress_*() to decompress them together. + * + * Note 2 : The previous 64KB of source data is __assumed__ to remain present, unmodified, at same address in memory ! + * + * Note 3 : When input is structured as a double-buffer, each buffer can have any size, including < 64 KB. + * Make sure that buffers are separated, by at least one byte. + * This construction ensures that each block only depends on previous block. + * + * Note 4 : If input buffer is a ring-buffer, it can have any size, including < 64 KB. + * + * Note 5 : After an error, the stream status is undefined (invalid), it can only be reset or freed. + */ +LZ4LIB_API int LZ4_compress_fast_continue (LZ4_stream_t* streamPtr, const char* src, char* dst, int srcSize, int dstCapacity, int acceleration); + +/*! LZ4_saveDict() : + * If last 64KB data cannot be guaranteed to remain available at its current memory location, + * save it into a safer place (char* safeBuffer). + * This is schematically equivalent to a memcpy() followed by LZ4_loadDict(), + * but is much faster, because LZ4_saveDict() doesn't need to rebuild tables. + * @return : saved dictionary size in bytes (necessarily <= maxDictSize), or 0 if error. + */ +LZ4LIB_API int LZ4_saveDict (LZ4_stream_t* streamPtr, char* safeBuffer, int maxDictSize); + + +/*-********************************************** +* Streaming Decompression Functions +* Bufferless synchronous API +************************************************/ +typedef union LZ4_streamDecode_u LZ4_streamDecode_t; /* tracking context */ + +/*! LZ4_createStreamDecode() and LZ4_freeStreamDecode() : + * creation / destruction of streaming decompression tracking context. + * A tracking context can be re-used multiple times. + */ +LZ4LIB_API LZ4_streamDecode_t* LZ4_createStreamDecode(void); +LZ4LIB_API int LZ4_freeStreamDecode (LZ4_streamDecode_t* LZ4_stream); + +/*! LZ4_setStreamDecode() : + * An LZ4_streamDecode_t context can be allocated once and re-used multiple times. + * Use this function to start decompression of a new stream of blocks. + * A dictionary can optionally be set. Use NULL or size 0 for a reset order. + * Dictionary is presumed stable : it must remain accessible and unmodified during next decompression. + * @return : 1 if OK, 0 if error + */ +LZ4LIB_API int LZ4_setStreamDecode (LZ4_streamDecode_t* LZ4_streamDecode, const char* dictionary, int dictSize); + +/*! LZ4_decoderRingBufferSize() : v1.8.2+ + * Note : in a ring buffer scenario (optional), + * blocks are presumed decompressed next to each other + * up to the moment there is not enough remaining space for next block (remainingSize < maxBlockSize), + * at which stage it resumes from beginning of ring buffer. + * When setting such a ring buffer for streaming decompression, + * provides the minimum size of this ring buffer + * to be compatible with any source respecting maxBlockSize condition. + * @return : minimum ring buffer size, + * or 0 if there is an error (invalid maxBlockSize). + */ +LZ4LIB_API int LZ4_decoderRingBufferSize(int maxBlockSize); +#define LZ4_DECODER_RING_BUFFER_SIZE(maxBlockSize) (65536 + 14 + (maxBlockSize)) /* for static allocation; maxBlockSize presumed valid */ + +/*! LZ4_decompress_*_continue() : + * These decoding functions allow decompression of consecutive blocks in "streaming" mode. + * A block is an unsplittable entity, it must be presented entirely to a decompression function. + * Decompression functions only accepts one block at a time. + * The last 64KB of previously decoded data *must* remain available and unmodified at the memory position where they were decoded. + * If less than 64KB of data has been decoded, all the data must be present. + * + * Special : if decompression side sets a ring buffer, it must respect one of the following conditions : + * - Decompression buffer size is _at least_ LZ4_decoderRingBufferSize(maxBlockSize). + * maxBlockSize is the maximum size of any single block. It can have any value > 16 bytes. + * In which case, encoding and decoding buffers do not need to be synchronized. + * Actually, data can be produced by any source compliant with LZ4 format specification, and respecting maxBlockSize. + * - Synchronized mode : + * Decompression buffer size is _exactly_ the same as compression buffer size, + * and follows exactly same update rule (block boundaries at same positions), + * and decoding function is provided with exact decompressed size of each block (exception for last block of the stream), + * _then_ decoding & encoding ring buffer can have any size, including small ones ( < 64 KB). + * - Decompression buffer is larger than encoding buffer, by a minimum of maxBlockSize more bytes. + * In which case, encoding and decoding buffers do not need to be synchronized, + * and encoding ring buffer can have any size, including small ones ( < 64 KB). + * + * Whenever these conditions are not possible, + * save the last 64KB of decoded data into a safe buffer where it can't be modified during decompression, + * then indicate where this data is saved using LZ4_setStreamDecode(), before decompressing next block. +*/ +LZ4LIB_API int LZ4_decompress_safe_continue (LZ4_streamDecode_t* LZ4_streamDecode, const char* src, char* dst, int srcSize, int dstCapacity); + + +/*! LZ4_decompress_*_usingDict() : + * These decoding functions work the same as + * a combination of LZ4_setStreamDecode() followed by LZ4_decompress_*_continue() + * They are stand-alone, and don't need an LZ4_streamDecode_t structure. + * Dictionary is presumed stable : it must remain accessible and unmodified during decompression. + * Performance tip : Decompression speed can be substantially increased + * when dst == dictStart + dictSize. + */ +LZ4LIB_API int LZ4_decompress_safe_usingDict (const char* src, char* dst, int srcSize, int dstCapcity, const char* dictStart, int dictSize); + +#endif /* LZ4_H_2983827168210 */ + + +/*^************************************* + * !!!!!! STATIC LINKING ONLY !!!!!! + ***************************************/ + +/*-**************************************************************************** + * Experimental section + * + * Symbols declared in this section must be considered unstable. Their + * signatures or semantics may change, or they may be removed altogether in the + * future. They are therefore only safe to depend on when the caller is + * statically linked against the library. + * + * To protect against unsafe usage, not only are the declarations guarded, + * the definitions are hidden by default + * when building LZ4 as a shared/dynamic library. + * + * In order to access these declarations, + * define LZ4_STATIC_LINKING_ONLY in your application + * before including LZ4's headers. + * + * In order to make their implementations accessible dynamically, you must + * define LZ4_PUBLISH_STATIC_FUNCTIONS when building the LZ4 library. + ******************************************************************************/ + +#ifdef LZ4_STATIC_LINKING_ONLY + +#ifndef LZ4_STATIC_3504398509 +#define LZ4_STATIC_3504398509 + +#ifdef LZ4_PUBLISH_STATIC_FUNCTIONS +#define LZ4LIB_STATIC_API LZ4LIB_API +#else +#define LZ4LIB_STATIC_API +#endif + + +/*! LZ4_compress_fast_extState_fastReset() : + * A variant of LZ4_compress_fast_extState(). + * + * Using this variant avoids an expensive initialization step. + * It is only safe to call if the state buffer is known to be correctly initialized already + * (see above comment on LZ4_resetStream_fast() for a definition of "correctly initialized"). + * From a high level, the difference is that + * this function initializes the provided state with a call to something like LZ4_resetStream_fast() + * while LZ4_compress_fast_extState() starts with a call to LZ4_resetStream(). + */ +LZ4LIB_STATIC_API int LZ4_compress_fast_extState_fastReset (void* state, const char* src, char* dst, int srcSize, int dstCapacity, int acceleration); + +/*! LZ4_attach_dictionary() : + * This is an experimental API that allows + * efficient use of a static dictionary many times. + * + * Rather than re-loading the dictionary buffer into a working context before + * each compression, or copying a pre-loaded dictionary's LZ4_stream_t into a + * working LZ4_stream_t, this function introduces a no-copy setup mechanism, + * in which the working stream references the dictionary stream in-place. + * + * Several assumptions are made about the state of the dictionary stream. + * Currently, only streams which have been prepared by LZ4_loadDict() should + * be expected to work. + * + * Alternatively, the provided dictionaryStream may be NULL, + * in which case any existing dictionary stream is unset. + * + * If a dictionary is provided, it replaces any pre-existing stream history. + * The dictionary contents are the only history that can be referenced and + * logically immediately precede the data compressed in the first subsequent + * compression call. + * + * The dictionary will only remain attached to the working stream through the + * first compression call, at the end of which it is cleared. The dictionary + * stream (and source buffer) must remain in-place / accessible / unchanged + * through the completion of the first compression call on the stream. + */ +LZ4LIB_STATIC_API void LZ4_attach_dictionary(LZ4_stream_t* workingStream, const LZ4_stream_t* dictionaryStream); + + +/*! In-place compression and decompression + * + * It's possible to have input and output sharing the same buffer, + * for highly contrained memory environments. + * In both cases, it requires input to lay at the end of the buffer, + * and decompression to start at beginning of the buffer. + * Buffer size must feature some margin, hence be larger than final size. + * + * |<------------------------buffer--------------------------------->| + * |<-----------compressed data--------->| + * |<-----------decompressed size------------------>| + * |<----margin---->| + * + * This technique is more useful for decompression, + * since decompressed size is typically larger, + * and margin is short. + * + * In-place decompression will work inside any buffer + * which size is >= LZ4_DECOMPRESS_INPLACE_BUFFER_SIZE(decompressedSize). + * This presumes that decompressedSize > compressedSize. + * Otherwise, it means compression actually expanded data, + * and it would be more efficient to store such data with a flag indicating it's not compressed. + * This can happen when data is not compressible (already compressed, or encrypted). + * + * For in-place compression, margin is larger, as it must be able to cope with both + * history preservation, requiring input data to remain unmodified up to LZ4_DISTANCE_MAX, + * and data expansion, which can happen when input is not compressible. + * As a consequence, buffer size requirements are much higher, + * and memory savings offered by in-place compression are more limited. + * + * There are ways to limit this cost for compression : + * - Reduce history size, by modifying LZ4_DISTANCE_MAX. + * Note that it is a compile-time constant, so all compressions will apply this limit. + * Lower values will reduce compression ratio, except when input_size < LZ4_DISTANCE_MAX, + * so it's a reasonable trick when inputs are known to be small. + * - Require the compressor to deliver a "maximum compressed size". + * This is the `dstCapacity` parameter in `LZ4_compress*()`. + * When this size is < LZ4_COMPRESSBOUND(inputSize), then compression can fail, + * in which case, the return code will be 0 (zero). + * The caller must be ready for these cases to happen, + * and typically design a backup scheme to send data uncompressed. + * The combination of both techniques can significantly reduce + * the amount of margin required for in-place compression. + * + * In-place compression can work in any buffer + * which size is >= (maxCompressedSize) + * with maxCompressedSize == LZ4_COMPRESSBOUND(srcSize) for guaranteed compression success. + * LZ4_COMPRESS_INPLACE_BUFFER_SIZE() depends on both maxCompressedSize and LZ4_DISTANCE_MAX, + * so it's possible to reduce memory requirements by playing with them. + */ + +#define LZ4_DECOMPRESS_INPLACE_MARGIN(compressedSize) (((compressedSize) >> 8) + 32) +#define LZ4_DECOMPRESS_INPLACE_BUFFER_SIZE(decompressedSize) ((decompressedSize) + LZ4_DECOMPRESS_INPLACE_MARGIN(decompressedSize)) /**< note: presumes that compressedSize < decompressedSize. note2: margin is overestimated a bit, since it could use compressedSize instead */ + +#ifndef LZ4_DISTANCE_MAX /* history window size; can be user-defined at compile time */ +# define LZ4_DISTANCE_MAX 65535 /* set to maximum value by default */ +#endif + +#define LZ4_COMPRESS_INPLACE_MARGIN (LZ4_DISTANCE_MAX + 32) /* LZ4_DISTANCE_MAX can be safely replaced by srcSize when it's smaller */ +#define LZ4_COMPRESS_INPLACE_BUFFER_SIZE(maxCompressedSize) ((maxCompressedSize) + LZ4_COMPRESS_INPLACE_MARGIN) /**< maxCompressedSize is generally LZ4_COMPRESSBOUND(inputSize), but can be set to any lower value, with the risk that compression can fail (return code 0(zero)) */ + +#endif /* LZ4_STATIC_3504398509 */ +#endif /* LZ4_STATIC_LINKING_ONLY */ + + + +#ifndef LZ4_H_98237428734687 +#define LZ4_H_98237428734687 + +/*-************************************************************ + * PRIVATE DEFINITIONS + ************************************************************** + * Do not use these definitions directly. + * They are only exposed to allow static allocation of `LZ4_stream_t` and `LZ4_streamDecode_t`. + * Accessing members will expose code to API and/or ABI break in future versions of the library. + **************************************************************/ +#define LZ4_HASHLOG (LZ4_MEMORY_USAGE-2) +#define LZ4_HASHTABLESIZE (1 << LZ4_MEMORY_USAGE) +#define LZ4_HASH_SIZE_U32 (1 << LZ4_HASHLOG) /* required as macro for static allocation */ + +#if defined(__cplusplus) || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */) +#include + +typedef struct LZ4_stream_t_internal LZ4_stream_t_internal; +struct LZ4_stream_t_internal { + uint32_t hashTable[LZ4_HASH_SIZE_U32]; + uint32_t currentOffset; + uint16_t dirty; + uint16_t tableType; + const uint8_t* dictionary; + const LZ4_stream_t_internal* dictCtx; + uint32_t dictSize; +}; + +typedef struct { + const uint8_t* externalDict; + size_t extDictSize; + const uint8_t* prefixEnd; + size_t prefixSize; +} LZ4_streamDecode_t_internal; + +#else + +typedef struct LZ4_stream_t_internal LZ4_stream_t_internal; +struct LZ4_stream_t_internal { + unsigned int hashTable[LZ4_HASH_SIZE_U32]; + unsigned int currentOffset; + unsigned short dirty; + unsigned short tableType; + const unsigned char* dictionary; + const LZ4_stream_t_internal* dictCtx; + unsigned int dictSize; +}; + +typedef struct { + const unsigned char* externalDict; + const unsigned char* prefixEnd; + size_t extDictSize; + size_t prefixSize; +} LZ4_streamDecode_t_internal; + +#endif + +/*! LZ4_stream_t : + * information structure to track an LZ4 stream. + * LZ4_stream_t can also be created using LZ4_createStream(), which is recommended. + * The structure definition can be convenient for static allocation + * (on stack, or as part of larger structure). + * Init this structure with LZ4_initStream() before first use. + * note : only use this definition in association with static linking ! + * this definition is not API/ABI safe, and may change in a future version. + */ +#define LZ4_STREAMSIZE_U64 ((1 << (LZ4_MEMORY_USAGE-3)) + 4 + ((sizeof(void*)==16) ? 4 : 0) /*AS-400*/ ) +#define LZ4_STREAMSIZE (LZ4_STREAMSIZE_U64 * sizeof(unsigned long long)) +union LZ4_stream_u { + unsigned long long table[LZ4_STREAMSIZE_U64]; + LZ4_stream_t_internal internal_donotuse; +} ; /* previously typedef'd to LZ4_stream_t */ + +/*! LZ4_initStream() : v1.9.0+ + * An LZ4_stream_t structure must be initialized at least once. + * This is automatically done when invoking LZ4_createStream(), + * but it's not when the structure is simply declared on stack (for example). + * + * Use LZ4_initStream() to properly initialize a newly declared LZ4_stream_t. + * It can also initialize any arbitrary buffer of sufficient size, + * and will @return a pointer of proper type upon initialization. + * + * Note : initialization fails if size and alignment conditions are not respected. + * In which case, the function will @return NULL. + * Note2: An LZ4_stream_t structure guarantees correct alignment and size. + * Note3: Before v1.9.0, use LZ4_resetStream() instead + */ +LZ4LIB_API LZ4_stream_t* LZ4_initStream (void* buffer, size_t size); + + +/*! LZ4_streamDecode_t : + * information structure to track an LZ4 stream during decompression. + * init this structure using LZ4_setStreamDecode() before first use. + * note : only use in association with static linking ! + * this definition is not API/ABI safe, + * and may change in a future version ! + */ +#define LZ4_STREAMDECODESIZE_U64 (4 + ((sizeof(void*)==16) ? 2 : 0) /*AS-400*/ ) +#define LZ4_STREAMDECODESIZE (LZ4_STREAMDECODESIZE_U64 * sizeof(unsigned long long)) +union LZ4_streamDecode_u { + unsigned long long table[LZ4_STREAMDECODESIZE_U64]; + LZ4_streamDecode_t_internal internal_donotuse; +} ; /* previously typedef'd to LZ4_streamDecode_t */ + + + +/*-************************************ +* Obsolete Functions +**************************************/ + +/*! Deprecation warnings + * + * Deprecated functions make the compiler generate a warning when invoked. + * This is meant to invite users to update their source code. + * Should deprecation warnings be a problem, it is generally possible to disable them, + * typically with -Wno-deprecated-declarations for gcc + * or _CRT_SECURE_NO_WARNINGS in Visual. + * + * Another method is to define LZ4_DISABLE_DEPRECATE_WARNINGS + * before including the header file. + */ +#ifdef LZ4_DISABLE_DEPRECATE_WARNINGS +# define LZ4_DEPRECATED(message) /* disable deprecation warnings */ +#else +# define LZ4_GCC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__) +# if defined (__cplusplus) && (__cplusplus >= 201402) /* C++14 or greater */ +# define LZ4_DEPRECATED(message) [[deprecated(message)]] +# elif (LZ4_GCC_VERSION >= 405) || defined(__clang__) +# define LZ4_DEPRECATED(message) __attribute__((deprecated(message))) +# elif (LZ4_GCC_VERSION >= 301) +# define LZ4_DEPRECATED(message) __attribute__((deprecated)) +# elif defined(_MSC_VER) +# define LZ4_DEPRECATED(message) __declspec(deprecated(message)) +# else +# pragma message("WARNING: You need to implement LZ4_DEPRECATED for this compiler") +# define LZ4_DEPRECATED(message) +# endif +#endif /* LZ4_DISABLE_DEPRECATE_WARNINGS */ + +/* Obsolete compression functions */ +LZ4_DEPRECATED("use LZ4_compress_default() instead") LZ4LIB_API int LZ4_compress (const char* src, char* dest, int srcSize); +LZ4_DEPRECATED("use LZ4_compress_default() instead") LZ4LIB_API int LZ4_compress_limitedOutput (const char* src, char* dest, int srcSize, int maxOutputSize); +LZ4_DEPRECATED("use LZ4_compress_fast_extState() instead") LZ4LIB_API int LZ4_compress_withState (void* state, const char* source, char* dest, int inputSize); +LZ4_DEPRECATED("use LZ4_compress_fast_extState() instead") LZ4LIB_API int LZ4_compress_limitedOutput_withState (void* state, const char* source, char* dest, int inputSize, int maxOutputSize); +LZ4_DEPRECATED("use LZ4_compress_fast_continue() instead") LZ4LIB_API int LZ4_compress_continue (LZ4_stream_t* LZ4_streamPtr, const char* source, char* dest, int inputSize); +LZ4_DEPRECATED("use LZ4_compress_fast_continue() instead") LZ4LIB_API int LZ4_compress_limitedOutput_continue (LZ4_stream_t* LZ4_streamPtr, const char* source, char* dest, int inputSize, int maxOutputSize); + +/* Obsolete decompression functions */ +LZ4_DEPRECATED("use LZ4_decompress_fast() instead") LZ4LIB_API int LZ4_uncompress (const char* source, char* dest, int outputSize); +LZ4_DEPRECATED("use LZ4_decompress_safe() instead") LZ4LIB_API int LZ4_uncompress_unknownOutputSize (const char* source, char* dest, int isize, int maxOutputSize); + +/* Obsolete streaming functions; degraded functionality; do not use! + * + * In order to perform streaming compression, these functions depended on data + * that is no longer tracked in the state. They have been preserved as well as + * possible: using them will still produce a correct output. However, they don't + * actually retain any history between compression calls. The compression ratio + * achieved will therefore be no better than compressing each chunk + * independently. + */ +LZ4_DEPRECATED("Use LZ4_createStream() instead") LZ4LIB_API void* LZ4_create (char* inputBuffer); +LZ4_DEPRECATED("Use LZ4_createStream() instead") LZ4LIB_API int LZ4_sizeofStreamState(void); +LZ4_DEPRECATED("Use LZ4_resetStream() instead") LZ4LIB_API int LZ4_resetStreamState(void* state, char* inputBuffer); +LZ4_DEPRECATED("Use LZ4_saveDict() instead") LZ4LIB_API char* LZ4_slideInputBuffer (void* state); + +/* Obsolete streaming decoding functions */ +LZ4_DEPRECATED("use LZ4_decompress_safe_usingDict() instead") LZ4LIB_API int LZ4_decompress_safe_withPrefix64k (const char* src, char* dst, int compressedSize, int maxDstSize); +LZ4_DEPRECATED("use LZ4_decompress_fast_usingDict() instead") LZ4LIB_API int LZ4_decompress_fast_withPrefix64k (const char* src, char* dst, int originalSize); + +/*! LZ4_decompress_fast() : **unsafe!** + * These functions used to be faster than LZ4_decompress_safe(), + * but it has changed, and they are now slower than LZ4_decompress_safe(). + * This is because LZ4_decompress_fast() doesn't know the input size, + * and therefore must progress more cautiously in the input buffer to not read beyond the end of block. + * On top of that `LZ4_decompress_fast()` is not protected vs malformed or malicious inputs, making it a security liability. + * As a consequence, LZ4_decompress_fast() is strongly discouraged, and deprecated. + * + * The last remaining LZ4_decompress_fast() specificity is that + * it can decompress a block without knowing its compressed size. + * Such functionality could be achieved in a more secure manner, + * by also providing the maximum size of input buffer, + * but it would require new prototypes, and adaptation of the implementation to this new use case. + * + * Parameters: + * originalSize : is the uncompressed size to regenerate. + * `dst` must be already allocated, its size must be >= 'originalSize' bytes. + * @return : number of bytes read from source buffer (== compressed size). + * The function expects to finish at block's end exactly. + * If the source stream is detected malformed, the function stops decoding and returns a negative result. + * note : LZ4_decompress_fast*() requires originalSize. Thanks to this information, it never writes past the output buffer. + * However, since it doesn't know its 'src' size, it may read an unknown amount of input, past input buffer bounds. + * Also, since match offsets are not validated, match reads from 'src' may underflow too. + * These issues never happen if input (compressed) data is correct. + * But they may happen if input data is invalid (error or intentional tampering). + * As a consequence, use these functions in trusted environments with trusted data **only**. + */ + +LZ4_DEPRECATED("This function is deprecated and unsafe. Consider using LZ4_decompress_safe() instead") +LZ4LIB_API int LZ4_decompress_fast (const char* src, char* dst, int originalSize); +LZ4_DEPRECATED("This function is deprecated and unsafe. Consider using LZ4_decompress_safe_continue() instead") +LZ4LIB_API int LZ4_decompress_fast_continue (LZ4_streamDecode_t* LZ4_streamDecode, const char* src, char* dst, int originalSize); +LZ4_DEPRECATED("This function is deprecated and unsafe. Consider using LZ4_decompress_safe_usingDict() instead") +LZ4LIB_API int LZ4_decompress_fast_usingDict (const char* src, char* dst, int originalSize, const char* dictStart, int dictSize); + +/*! LZ4_resetStream() : + * An LZ4_stream_t structure must be initialized at least once. + * This is done with LZ4_initStream(), or LZ4_resetStream(). + * Consider switching to LZ4_initStream(), + * invoking LZ4_resetStream() will trigger deprecation warnings in the future. + */ +LZ4LIB_API void LZ4_resetStream (LZ4_stream_t* streamPtr); + + +#endif /* LZ4_H_98237428734687 */ + + +#if defined (__cplusplus) +} +#endif diff --git a/src/inference/enum/constants.h b/src/inference/enum/constants.h new file mode 100644 index 0000000000000000000000000000000000000000..c4817f145eb1c1bc35b6613e46a4ee506800cf21 --- /dev/null +++ b/src/inference/enum/constants.h @@ -0,0 +1,13 @@ +#pragma once + +#include +#include +#include + +// START OF NAMESPACE +namespace enums { + +const int SAFE_TRACK_MAP[15] = {0,1,2,3,4,5,6,7,8,10,11,12,13,14,15}; + +} +// END OF NAMESPACE \ No newline at end of file diff --git a/src/inference/enum/density.h b/src/inference/enum/density.h new file mode 100644 index 0000000000000000000000000000000000000000..fcc5cbb2d3950b21c62155f746e2845a14ba1514 --- /dev/null +++ b/src/inference/enum/density.h @@ -0,0 +1,142 @@ +#pragma once + +#include +#include + +// START OF NAMESPACE +namespace enums { + +std::unordered_map> DENSITY_QUANTILES = { + {0,{2,3,4,5,6,8,10,12,17,1073741824}}, + {1,{3,4,6,7,9,11,12,16,20,1073741824}}, + {2,{3,4,5,6,8,10,12,16,20,1073741824}}, + {3,{2,4,5,6,8,9,12,15,19,1073741824}}, + {4,{3,3,4,6,6,8,9,12,15,1073741824}}, + {5,{3,3,4,5,6,8,9,12,15,1073741824}}, + {6,{1,2,3,4,5,7,9,12,16,1073741824}}, + {7,{1,1,2,2,3,4,6,8,13,1073741824}}, + {8,{1,2,3,4,5,6,8,10,15,1073741824}}, + {9,{1,2,2,3,4,4,6,8,11,1073741824}}, + {10,{1,2,3,4,5,6,7,8,12,1073741824}}, + {11,{1,2,3,4,4,5,6,7,10,1073741824}}, + {12,{3,4,6,6,8,10,12,15,18,1073741824}}, + {13,{2,3,4,5,6,8,10,12,16,1073741824}}, + {14,{1,1,1,2,2,3,4,5,8,1073741824}}, + {15,{2,3,4,6,7,8,12,16,18,1073741824}}, + {16,{2,3,3,4,5,6,8,10,15,1073741824}}, + {17,{2,3,3,4,5,6,8,12,16,1073741824}}, + {18,{2,3,3,4,5,6,8,10,15,1073741824}}, + {19,{2,2,3,4,5,6,8,10,15,1073741824}}, + {20,{2,2,3,4,5,6,7,9,12,1073741824}}, + {21,{2,3,4,5,6,8,9,12,16,1073741824}}, + {22,{1,2,3,3,4,5,6,7,9,1073741824}}, + {23,{2,3,4,5,6,8,9,12,16,1073741824}}, + {24,{2,4,5,6,8,9,11,14,19,1073741824}}, + {25,{4,6,7,8,11,13,16,21,28,1073741824}}, + {26,{2,3,4,5,6,6,8,10,15,1073741824}}, + {27,{3,4,5,6,8,9,12,14,19,1073741824}}, + {28,{3,4,6,7,8,9,12,15,16,1073741824}}, + {29,{2,3,4,5,6,8,10,13,18,1073741824}}, + {30,{2,3,4,6,7,9,12,15,20,1073741824}}, + {31,{1,2,2,3,4,5,6,8,12,1073741824}}, + {32,{1,2,2,3,4,4,4,5,7,1073741824}}, + {33,{2,3,3,4,4,5,6,8,8,1073741824}}, + {34,{2,3,4,4,6,6,8,8,10,1073741824}}, + {35,{1,2,3,4,4,4,5,6,8,1073741824}}, + {36,{2,3,4,5,5,6,7,8,10,1073741824}}, + {37,{2,3,4,4,5,6,7,8,10,1073741824}}, + {38,{3,4,5,6,8,8,9,12,16,1073741824}}, + {39,{2,3,4,5,6,7,8,8,11,1073741824}}, + {40,{1,2,3,4,4,5,6,8,11,1073741824}}, + {41,{1,2,3,3,4,4,6,7,10,1073741824}}, + {42,{1,1,2,3,3,4,5,6,8,1073741824}}, + {43,{1,1,2,2,3,4,4,6,8,1073741824}}, + {44,{1,2,2,3,4,5,6,8,12,1073741824}}, + {45,{1,2,3,4,5,6,8,12,16,1073741824}}, + {46,{2,3,4,5,6,8,8,12,16,1073741824}}, + {47,{1,1,2,2,3,4,5,8,14,1073741824}}, + {48,{1,2,2,3,4,4,6,8,11,1073741824}}, + {49,{1,2,2,3,3,4,4,6,8,1073741824}}, + {50,{1,2,3,3,3,4,5,6,8,1073741824}}, + {51,{1,2,3,3,3,4,5,6,8,1073741824}}, + {52,{1,2,2,3,4,4,6,7,10,1073741824}}, + {53,{1,2,3,4,4,5,6,8,10,1073741824}}, + {54,{1,2,3,3,4,4,6,7,9,1073741824}}, + {55,{1,1,2,2,3,4,6,7,11,1073741824}}, + {56,{1,2,3,3,4,4,6,7,9,1073741824}}, + {57,{1,2,2,3,4,4,5,6,8,1073741824}}, + {58,{1,2,2,3,3,4,4,6,8,1073741824}}, + {59,{1,2,3,3,4,5,6,7,9,1073741824}}, + {60,{1,2,2,3,3,4,5,6,8,1073741824}}, + {61,{2,2,3,4,5,6,8,10,14,1073741824}}, + {62,{2,3,4,4,6,6,8,10,15,1073741824}}, + {63,{1,2,3,4,4,6,6,8,12,1073741824}}, + {64,{1,2,3,3,4,4,5,6,8,1073741824}}, + {65,{1,2,3,3,4,5,6,6,8,1073741824}}, + {66,{1,2,3,3,4,5,6,6,8,1073741824}}, + {67,{1,2,2,3,3,4,5,6,8,1073741824}}, + {68,{1,2,2,3,4,4,5,6,8,1073741824}}, + {69,{1,2,2,3,4,4,5,6,8,1073741824}}, + {70,{1,2,2,3,3,4,5,6,8,1073741824}}, + {71,{1,2,3,3,4,5,6,7,9,1073741824}}, + {72,{1,2,3,4,4,5,6,7,10,1073741824}}, + {73,{1,2,3,3,4,5,6,7,9,1073741824}}, + {74,{1,2,3,3,4,5,5,7,8,1073741824}}, + {75,{1,2,3,4,4,5,6,7,8,1073741824}}, + {76,{1,2,3,4,5,6,6,8,11,1073741824}}, + {77,{1,2,3,3,4,5,6,6,8,1073741824}}, + {78,{1,2,2,3,4,4,5,6,8,1073741824}}, + {79,{1,2,3,3,4,5,6,8,10,1073741824}}, + {80,{2,3,3,4,5,6,8,9,14,1073741824}}, + {81,{2,3,4,5,6,8,9,12,16,1073741824}}, + {82,{1,2,3,4,4,5,6,7,9,1073741824}}, + {83,{1,2,3,4,4,6,6,8,12,1073741824}}, + {84,{2,2,3,4,5,6,8,10,15,1073741824}}, + {85,{1,2,3,3,4,4,5,6,8,1073741824}}, + {86,{1,2,3,3,4,5,6,8,12,1073741824}}, + {87,{2,3,4,5,6,8,8,12,16,1073741824}}, + {88,{1,2,3,3,4,5,6,7,9,1073741824}}, + {89,{1,2,2,3,3,3,4,5,6,1073741824}}, + {90,{2,3,3,4,5,6,8,10,15,1073741824}}, + {91,{1,2,3,3,4,4,5,6,8,1073741824}}, + {92,{1,2,2,3,3,4,5,6,8,1073741824}}, + {93,{1,1,2,2,3,3,4,6,9,1073741824}}, + {94,{1,2,3,3,4,4,5,7,9,1073741824}}, + {95,{1,1,2,2,3,3,4,5,7,1073741824}}, + {96,{2,3,3,4,5,6,8,12,16,1073741824}}, + {97,{1,1,1,2,3,3,4,6,9,1073741824}}, + {98,{1,2,2,3,4,5,6,8,11,1073741824}}, + {99,{2,3,3,4,5,6,8,9,14,1073741824}}, + {100,{1,2,3,3,4,5,6,7,10,1073741824}}, + {101,{1,1,1,2,2,3,4,6,9,1073741824}}, + {102,{1,2,3,3,4,5,6,8,11,1073741824}}, + {103,{1,2,3,3,4,5,6,7,10,1073741824}}, + {104,{1,2,3,4,5,6,8,9,13,1073741824}}, + {105,{4,6,6,8,8,10,12,16,18,1073741824}}, + {106,{2,4,5,7,8,10,12,16,18,1073741824}}, + {107,{2,3,4,5,6,8,9,12,18,1073741824}}, + {108,{3,4,6,7,8,10,12,15,18,1073741824}}, + {109,{1,2,3,3,4,5,6,8,12,1073741824}}, + {110,{2,3,4,4,6,6,7,8,12,1073741824}}, + {111,{1,2,3,4,4,5,6,8,11,1073741824}}, + {112,{1,2,3,4,4,6,8,10,16,1073741824}}, + {113,{1,2,3,4,5,6,8,10,14,1073741824}}, + {114,{2,4,4,5,6,8,10,12,16,1073741824}}, + {115,{2,2,3,4,5,6,8,10,15,1073741824}}, + {116,{1,2,2,3,4,4,6,8,12,1073741824}}, + {117,{1,2,3,4,4,6,7,9,13,1073741824}}, + {118,{1,2,2,3,4,4,5,7,10,1073741824}}, + {119,{1,1,1,1,1,1,2,2,4,1073741824}}, + {120,{1,1,1,1,2,3,4,8,14,1073741824}}, + {121,{1,2,3,4,5,7,8,12,20,1073741824}}, + {122,{1,1,1,1,1,2,2,4,6,1073741824}}, + {123,{1,1,2,2,3,4,5,7,15,1073741824}}, + {124,{1,1,1,2,2,3,5,8,15,1073741824}}, + {125,{1,1,1,2,2,3,4,6,10,1073741824}}, + {126,{1,1,2,2,3,4,4,6,11,1073741824}}, + {127,{1,1,2,2,3,4,6,8,13,1073741824}}, + {128,{2,3,5,8,10,12,15,18,26,1073741824}} +}; + +} +// END OF NAMESPACE \ No newline at end of file diff --git a/src/inference/enum/encoder_types.h b/src/inference/enum/encoder_types.h new file mode 100644 index 0000000000000000000000000000000000000000..a34afe2e7c9c7a9863135906a3dea5bdf4fe9110 --- /dev/null +++ b/src/inference/enum/encoder_types.h @@ -0,0 +1,51 @@ +#pragma once + +#include "../../common/encoder/encoder_all.h" +#include + +namespace enums { + +enum ENCODER_TYPE { + EXPRESSIVE_ENCODER, + NO_ENCODER +}; + +std::unique_ptr getEncoder(ENCODER_TYPE et) { + switch (et) { + case EXPRESSIVE_ENCODER: return std::make_unique(); + case NO_ENCODER: return NULL; + } + return NULL; +} + +ENCODER_TYPE getEncoderType(const std::string &s) { + if (s == "EXPRESSIVE_ENCODER") return EXPRESSIVE_ENCODER; + return NO_ENCODER; +} + +std::vector getEncoderTypeList() { + std::vector list; + list.push_back("EXPRESSIVE_ENCODER"); + return list; +} + +int getEncoderSize(ENCODER_TYPE et) { + std::unique_ptr encoder = getEncoder(et); + if (!encoder) { + return 0; + } + int size = encoder->rep->max_token(); + return size; +} + +// helper for unit tests +inline bool starts_with(std::string const & value, std::string const & match) { + if (match.size() > value.size()) return false; + return std::equal(match.begin(), match.end(), value.begin()); +} + +std::unique_ptr getEncoderFromString(const std::string &s) { + return getEncoder(getEncoderType(s)); +} + +} diff --git a/src/inference/enum/gm.h b/src/inference/enum/gm.h new file mode 100644 index 0000000000000000000000000000000000000000..c93d23b28440a579a05ce43c1e559ac0527284d2 --- /dev/null +++ b/src/inference/enum/gm.h @@ -0,0 +1,707 @@ +#pragma once + +#include +#include +#include +#include "../../common/data_structures/track_type.h" + +// START OF NAMESPACE +namespace enums { + +std::map> GM_MOD = { + {midi::GM_TYPE::any,{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127}}, + {midi::GM_TYPE::piano,{0,1,2,3,4,5,6,7}}, + {midi::GM_TYPE::chromatic_perc,{9,10,11,12,13,14,15}}, + {midi::GM_TYPE::organ,{16,17,18,19,20,21,22,23}}, + {midi::GM_TYPE::guitar,{24,25,26,27,28,29,30,31}}, + {midi::GM_TYPE::bass,{32,33,34,35,36,37,38,39}}, + {midi::GM_TYPE::strings,{40,41,42,43,44,45,46,47}}, + {midi::GM_TYPE::ensemble,{48,49,50,51,52,53,54,55}}, + {midi::GM_TYPE::brass,{56,57,58,59,60,61,62,63}}, + {midi::GM_TYPE::reed,{64,65,66,67,68,69,70,71}}, + {midi::GM_TYPE::pipe,{72,73,74,75,76,77,78,79}}, + {midi::GM_TYPE::synth_lead,{80,81,82,83,84,85,86,87}}, + {midi::GM_TYPE::synth_pad,{88,89,90,91,92,93,94,95}}, + {midi::GM_TYPE::synth_effects,{96,97,98,99,100,101,102,103}}, + {midi::GM_TYPE::ethnic,{104,105,106,107,108,109,110,111}}, + {midi::GM_TYPE::percussive,{112,113,114,115,116,117,118,119}}, + {midi::GM_TYPE::sound_fx,{120,121,122,123,124,125,126,127}}, + {midi::GM_TYPE::no_drums,{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127}}, + {midi::GM_TYPE::drums,{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127}}, + {midi::GM_TYPE::acoustic_grand_piano,{0}}, + {midi::GM_TYPE::bright_acoustic_piano,{1}}, + {midi::GM_TYPE::electric_grand_piano,{2}}, + {midi::GM_TYPE::honky_tonk_piano,{3}}, + {midi::GM_TYPE::electric_piano_1,{4}}, + {midi::GM_TYPE::electric_piano_2,{5}}, + {midi::GM_TYPE::harpsichord,{6}}, + {midi::GM_TYPE::clavi,{7}}, + {midi::GM_TYPE::celesta,{8}}, + {midi::GM_TYPE::glockenspiel,{9}}, + {midi::GM_TYPE::music_box,{10}}, + {midi::GM_TYPE::vibraphone,{11}}, + {midi::GM_TYPE::marimba,{12}}, + {midi::GM_TYPE::xylophone,{13}}, + {midi::GM_TYPE::tubular_bells,{14}}, + {midi::GM_TYPE::dulcimer,{15}}, + {midi::GM_TYPE::drawbar_organ,{16}}, + {midi::GM_TYPE::percussive_organ,{17}}, + {midi::GM_TYPE::rock_organ,{18}}, + {midi::GM_TYPE::church_organ,{19}}, + {midi::GM_TYPE::reed_organ,{20}}, + {midi::GM_TYPE::accordion,{21}}, + {midi::GM_TYPE::harmonica,{22}}, + {midi::GM_TYPE::tango_accordion,{23}}, + {midi::GM_TYPE::acoustic_guitar_nylon,{24}}, + {midi::GM_TYPE::acoustic_guitar_steel,{25}}, + {midi::GM_TYPE::electric_guitar_jazz,{26}}, + {midi::GM_TYPE::electric_guitar_clean,{27}}, + {midi::GM_TYPE::electric_guitar_muted,{28}}, + {midi::GM_TYPE::overdriven_guitar,{29}}, + {midi::GM_TYPE::distortion_guitar,{30}}, + {midi::GM_TYPE::guitar_harmonics,{31}}, + {midi::GM_TYPE::acoustic_bass,{32}}, + {midi::GM_TYPE::electric_bass_finger,{33}}, + {midi::GM_TYPE::electric_bass_pick,{34}}, + {midi::GM_TYPE::fretless_bass,{35}}, + {midi::GM_TYPE::slap_bass_1,{36}}, + {midi::GM_TYPE::slap_bass_2,{37}}, + {midi::GM_TYPE::synth_bass_1,{38}}, + {midi::GM_TYPE::synth_bass_2,{39}}, + {midi::GM_TYPE::violin,{40}}, + {midi::GM_TYPE::viola,{41}}, + {midi::GM_TYPE::cello,{42}}, + {midi::GM_TYPE::contrabass,{43}}, + {midi::GM_TYPE::tremolo_strings,{44}}, + {midi::GM_TYPE::pizzicato_strings,{45}}, + {midi::GM_TYPE::orchestral_harp,{46}}, + {midi::GM_TYPE::timpani,{47}}, + {midi::GM_TYPE::string_ensemble_1,{48}}, + {midi::GM_TYPE::string_ensemble_2,{49}}, + {midi::GM_TYPE::synth_strings_1,{50}}, + {midi::GM_TYPE::synth_strings_2,{51}}, + {midi::GM_TYPE::choir_aahs,{52}}, + {midi::GM_TYPE::voice_oohs,{53}}, + {midi::GM_TYPE::synth_voice,{54}}, + {midi::GM_TYPE::orchestra_hit,{55}}, + {midi::GM_TYPE::trumpet,{56}}, + {midi::GM_TYPE::trombone,{57}}, + {midi::GM_TYPE::tuba,{58}}, + {midi::GM_TYPE::muted_trumpet,{59}}, + {midi::GM_TYPE::french_horn,{60}}, + {midi::GM_TYPE::brass_section,{61}}, + {midi::GM_TYPE::synth_brass_1,{62}}, + {midi::GM_TYPE::synth_brass_2,{63}}, + {midi::GM_TYPE::soprano_sax,{64}}, + {midi::GM_TYPE::alto_sax,{65}}, + {midi::GM_TYPE::tenor_sax,{66}}, + {midi::GM_TYPE::baritone_sax,{67}}, + {midi::GM_TYPE::oboe,{68}}, + {midi::GM_TYPE::english_horn,{69}}, + {midi::GM_TYPE::bassoon,{70}}, + {midi::GM_TYPE::clarinet,{71}}, + {midi::GM_TYPE::piccolo,{72}}, + {midi::GM_TYPE::flute,{73}}, + {midi::GM_TYPE::recorder,{74}}, + {midi::GM_TYPE::pan_flute,{75}}, + {midi::GM_TYPE::blown_bottle,{76}}, + {midi::GM_TYPE::shakuhachi,{77}}, + {midi::GM_TYPE::whistle,{78}}, + {midi::GM_TYPE::ocarina,{79}}, + {midi::GM_TYPE::lead_1_square,{80}}, + {midi::GM_TYPE::lead_2_sawtooth,{81}}, + {midi::GM_TYPE::lead_3_calliope,{82}}, + {midi::GM_TYPE::lead_4_chiff,{83}}, + {midi::GM_TYPE::lead_5_charang,{84}}, + {midi::GM_TYPE::lead_6_voice,{85}}, + {midi::GM_TYPE::lead_7_fifths,{86}}, + {midi::GM_TYPE::lead_8_bass__lead,{87}}, + {midi::GM_TYPE::pad_1_new_age,{88}}, + {midi::GM_TYPE::pad_2_warm,{89}}, + {midi::GM_TYPE::pad_3_polysynth,{90}}, + {midi::GM_TYPE::pad_4_choir,{91}}, + {midi::GM_TYPE::pad_5_bowed,{92}}, + {midi::GM_TYPE::pad_6_metallic,{93}}, + {midi::GM_TYPE::pad_7_halo,{94}}, + {midi::GM_TYPE::pad_8_sweep,{95}}, + {midi::GM_TYPE::fx_1_rain,{96}}, + {midi::GM_TYPE::fx_2_soundtrack,{97}}, + {midi::GM_TYPE::fx_3_crystal,{98}}, + {midi::GM_TYPE::fx_4_atmosphere,{99}}, + {midi::GM_TYPE::fx_5_brightness,{100}}, + {midi::GM_TYPE::fx_6_goblins,{101}}, + {midi::GM_TYPE::fx_7_echoes,{102}}, + {midi::GM_TYPE::fx_8_sci_fi,{103}}, + {midi::GM_TYPE::sitar,{104}}, + {midi::GM_TYPE::banjo,{105}}, + {midi::GM_TYPE::shamisen,{106}}, + {midi::GM_TYPE::koto,{107}}, + {midi::GM_TYPE::kalimba,{108}}, + {midi::GM_TYPE::bag_pipe,{109}}, + {midi::GM_TYPE::fiddle,{110}}, + {midi::GM_TYPE::shanai,{111}}, + {midi::GM_TYPE::tinkle_bell,{112}}, + {midi::GM_TYPE::agogo,{113}}, + {midi::GM_TYPE::steel_drums,{114}}, + {midi::GM_TYPE::woodblock,{115}}, + {midi::GM_TYPE::taiko_drum,{116}}, + {midi::GM_TYPE::melodic_tom,{117}}, + {midi::GM_TYPE::synth_drum,{118}}, + {midi::GM_TYPE::reverse_cymbal,{119}}, + {midi::GM_TYPE::guitar_fret_noise,{120}}, + {midi::GM_TYPE::breath_noise,{121}}, + {midi::GM_TYPE::seashore,{122}}, + {midi::GM_TYPE::bird_tweet,{123}}, + {midi::GM_TYPE::telephone_ring,{124}}, + {midi::GM_TYPE::helicopter,{125}}, + {midi::GM_TYPE::applause,{126}}, + {midi::GM_TYPE::gunshot,{127}}, + {midi::GM_TYPE::drum_0,{0}}, + {midi::GM_TYPE::drum_1,{1}}, + {midi::GM_TYPE::drum_2,{2}}, + {midi::GM_TYPE::drum_3,{3}}, + {midi::GM_TYPE::drum_4,{4}}, + {midi::GM_TYPE::drum_5,{5}}, + {midi::GM_TYPE::drum_6,{6}}, + {midi::GM_TYPE::drum_7,{7}}, + {midi::GM_TYPE::drum_8,{8}}, + {midi::GM_TYPE::drum_9,{9}}, + {midi::GM_TYPE::drum_10,{10}}, + {midi::GM_TYPE::drum_11,{11}}, + {midi::GM_TYPE::drum_12,{12}}, + {midi::GM_TYPE::drum_13,{13}}, + {midi::GM_TYPE::drum_14,{14}}, + {midi::GM_TYPE::drum_15,{15}}, + {midi::GM_TYPE::drum_16,{16}}, + {midi::GM_TYPE::drum_17,{17}}, + {midi::GM_TYPE::drum_18,{18}}, + {midi::GM_TYPE::drum_19,{19}}, + {midi::GM_TYPE::drum_20,{20}}, + {midi::GM_TYPE::drum_21,{21}}, + {midi::GM_TYPE::drum_22,{22}}, + {midi::GM_TYPE::drum_23,{23}}, + {midi::GM_TYPE::drum_24,{24}}, + {midi::GM_TYPE::drum_25,{25}}, + {midi::GM_TYPE::drum_26,{26}}, + {midi::GM_TYPE::drum_27,{27}}, + {midi::GM_TYPE::drum_28,{28}}, + {midi::GM_TYPE::drum_29,{29}}, + {midi::GM_TYPE::drum_30,{30}}, + {midi::GM_TYPE::drum_31,{31}}, + {midi::GM_TYPE::drum_32,{32}}, + {midi::GM_TYPE::drum_33,{33}}, + {midi::GM_TYPE::drum_34,{34}}, + {midi::GM_TYPE::drum_35,{35}}, + {midi::GM_TYPE::drum_36,{36}}, + {midi::GM_TYPE::drum_37,{37}}, + {midi::GM_TYPE::drum_38,{38}}, + {midi::GM_TYPE::drum_39,{39}}, + {midi::GM_TYPE::drum_40,{40}}, + {midi::GM_TYPE::drum_41,{41}}, + {midi::GM_TYPE::drum_42,{42}}, + {midi::GM_TYPE::drum_43,{43}}, + {midi::GM_TYPE::drum_44,{44}}, + {midi::GM_TYPE::drum_45,{45}}, + {midi::GM_TYPE::drum_46,{46}}, + {midi::GM_TYPE::drum_47,{47}}, + {midi::GM_TYPE::drum_48,{48}}, + {midi::GM_TYPE::drum_49,{49}}, + {midi::GM_TYPE::drum_50,{50}}, + {midi::GM_TYPE::drum_51,{51}}, + {midi::GM_TYPE::drum_52,{52}}, + {midi::GM_TYPE::drum_53,{53}}, + {midi::GM_TYPE::drum_54,{54}}, + {midi::GM_TYPE::drum_55,{55}}, + {midi::GM_TYPE::drum_56,{56}}, + {midi::GM_TYPE::drum_57,{57}}, + {midi::GM_TYPE::drum_58,{58}}, + {midi::GM_TYPE::drum_59,{59}}, + {midi::GM_TYPE::drum_60,{60}}, + {midi::GM_TYPE::drum_61,{61}}, + {midi::GM_TYPE::drum_62,{62}}, + {midi::GM_TYPE::drum_63,{63}}, + {midi::GM_TYPE::drum_64,{64}}, + {midi::GM_TYPE::drum_65,{65}}, + {midi::GM_TYPE::drum_66,{66}}, + {midi::GM_TYPE::drum_67,{67}}, + {midi::GM_TYPE::drum_68,{68}}, + {midi::GM_TYPE::drum_69,{69}}, + {midi::GM_TYPE::drum_70,{70}}, + {midi::GM_TYPE::drum_71,{71}}, + {midi::GM_TYPE::drum_72,{72}}, + {midi::GM_TYPE::drum_73,{73}}, + {midi::GM_TYPE::drum_74,{74}}, + {midi::GM_TYPE::drum_75,{75}}, + {midi::GM_TYPE::drum_76,{76}}, + {midi::GM_TYPE::drum_77,{77}}, + {midi::GM_TYPE::drum_78,{78}}, + {midi::GM_TYPE::drum_79,{79}}, + {midi::GM_TYPE::drum_80,{80}}, + {midi::GM_TYPE::drum_81,{81}}, + {midi::GM_TYPE::drum_82,{82}}, + {midi::GM_TYPE::drum_83,{83}}, + {midi::GM_TYPE::drum_84,{84}}, + {midi::GM_TYPE::drum_85,{85}}, + {midi::GM_TYPE::drum_86,{86}}, + {midi::GM_TYPE::drum_87,{87}}, + {midi::GM_TYPE::drum_88,{88}}, + {midi::GM_TYPE::drum_89,{89}}, + {midi::GM_TYPE::drum_90,{90}}, + {midi::GM_TYPE::drum_91,{91}}, + {midi::GM_TYPE::drum_92,{92}}, + {midi::GM_TYPE::drum_93,{93}}, + {midi::GM_TYPE::drum_94,{94}}, + {midi::GM_TYPE::drum_95,{95}}, + {midi::GM_TYPE::drum_96,{96}}, + {midi::GM_TYPE::drum_97,{97}}, + {midi::GM_TYPE::drum_98,{98}}, + {midi::GM_TYPE::drum_99,{99}}, + {midi::GM_TYPE::drum_100,{100}}, + {midi::GM_TYPE::drum_101,{101}}, + {midi::GM_TYPE::drum_102,{102}}, + {midi::GM_TYPE::drum_103,{103}}, + {midi::GM_TYPE::drum_104,{104}}, + {midi::GM_TYPE::drum_105,{105}}, + {midi::GM_TYPE::drum_106,{106}}, + {midi::GM_TYPE::drum_107,{107}}, + {midi::GM_TYPE::drum_108,{108}}, + {midi::GM_TYPE::drum_109,{109}}, + {midi::GM_TYPE::drum_110,{110}}, + {midi::GM_TYPE::drum_111,{111}}, + {midi::GM_TYPE::drum_112,{112}}, + {midi::GM_TYPE::drum_113,{113}}, + {midi::GM_TYPE::drum_114,{114}}, + {midi::GM_TYPE::drum_115,{115}}, + {midi::GM_TYPE::drum_116,{116}}, + {midi::GM_TYPE::drum_117,{117}}, + {midi::GM_TYPE::drum_118,{118}}, + {midi::GM_TYPE::drum_119,{119}}, + {midi::GM_TYPE::drum_120,{120}}, + {midi::GM_TYPE::drum_121,{121}}, + {midi::GM_TYPE::drum_122,{122}}, + {midi::GM_TYPE::drum_123,{123}}, + {midi::GM_TYPE::drum_124,{124}}, + {midi::GM_TYPE::drum_125,{125}}, + {midi::GM_TYPE::drum_126,{126}}, + {midi::GM_TYPE::drum_127,{127}}, +}; + +std::map GM_REV = { + {0,midi::GM_TYPE::acoustic_grand_piano}, + {1,midi::GM_TYPE::bright_acoustic_piano}, + {2,midi::GM_TYPE::electric_grand_piano}, + {3,midi::GM_TYPE::honky_tonk_piano}, + {4,midi::GM_TYPE::electric_piano_1}, + {5,midi::GM_TYPE::electric_piano_2}, + {6,midi::GM_TYPE::harpsichord}, + {7,midi::GM_TYPE::clavi}, + {8,midi::GM_TYPE::celesta}, + {9,midi::GM_TYPE::glockenspiel}, + {10,midi::GM_TYPE::music_box}, + {11,midi::GM_TYPE::vibraphone}, + {12,midi::GM_TYPE::marimba}, + {13,midi::GM_TYPE::xylophone}, + {14,midi::GM_TYPE::tubular_bells}, + {15,midi::GM_TYPE::dulcimer}, + {16,midi::GM_TYPE::drawbar_organ}, + {17,midi::GM_TYPE::percussive_organ}, + {18,midi::GM_TYPE::rock_organ}, + {19,midi::GM_TYPE::church_organ}, + {20,midi::GM_TYPE::reed_organ}, + {21,midi::GM_TYPE::accordion}, + {22,midi::GM_TYPE::harmonica}, + {23,midi::GM_TYPE::tango_accordion}, + {24,midi::GM_TYPE::acoustic_guitar_nylon}, + {25,midi::GM_TYPE::acoustic_guitar_steel}, + {26,midi::GM_TYPE::electric_guitar_jazz}, + {27,midi::GM_TYPE::electric_guitar_clean}, + {28,midi::GM_TYPE::electric_guitar_muted}, + {29,midi::GM_TYPE::overdriven_guitar}, + {30,midi::GM_TYPE::distortion_guitar}, + {31,midi::GM_TYPE::guitar_harmonics}, + {32,midi::GM_TYPE::acoustic_bass}, + {33,midi::GM_TYPE::electric_bass_finger}, + {34,midi::GM_TYPE::electric_bass_pick}, + {35,midi::GM_TYPE::fretless_bass}, + {36,midi::GM_TYPE::slap_bass_1}, + {37,midi::GM_TYPE::slap_bass_2}, + {38,midi::GM_TYPE::synth_bass_1}, + {39,midi::GM_TYPE::synth_bass_2}, + {40,midi::GM_TYPE::violin}, + {41,midi::GM_TYPE::viola}, + {42,midi::GM_TYPE::cello}, + {43,midi::GM_TYPE::contrabass}, + {44,midi::GM_TYPE::tremolo_strings}, + {45,midi::GM_TYPE::pizzicato_strings}, + {46,midi::GM_TYPE::orchestral_harp}, + {47,midi::GM_TYPE::timpani}, + {48,midi::GM_TYPE::string_ensemble_1}, + {49,midi::GM_TYPE::string_ensemble_2}, + {50,midi::GM_TYPE::synth_strings_1}, + {51,midi::GM_TYPE::synth_strings_2}, + {52,midi::GM_TYPE::choir_aahs}, + {53,midi::GM_TYPE::voice_oohs}, + {54,midi::GM_TYPE::synth_voice}, + {55,midi::GM_TYPE::orchestra_hit}, + {56,midi::GM_TYPE::trumpet}, + {57,midi::GM_TYPE::trombone}, + {58,midi::GM_TYPE::tuba}, + {59,midi::GM_TYPE::muted_trumpet}, + {60,midi::GM_TYPE::french_horn}, + {61,midi::GM_TYPE::brass_section}, + {62,midi::GM_TYPE::synth_brass_1}, + {63,midi::GM_TYPE::synth_brass_2}, + {64,midi::GM_TYPE::soprano_sax}, + {65,midi::GM_TYPE::alto_sax}, + {66,midi::GM_TYPE::tenor_sax}, + {67,midi::GM_TYPE::baritone_sax}, + {68,midi::GM_TYPE::oboe}, + {69,midi::GM_TYPE::english_horn}, + {70,midi::GM_TYPE::bassoon}, + {71,midi::GM_TYPE::clarinet}, + {72,midi::GM_TYPE::piccolo}, + {73,midi::GM_TYPE::flute}, + {74,midi::GM_TYPE::recorder}, + {75,midi::GM_TYPE::pan_flute}, + {76,midi::GM_TYPE::blown_bottle}, + {77,midi::GM_TYPE::shakuhachi}, + {78,midi::GM_TYPE::whistle}, + {79,midi::GM_TYPE::ocarina}, + {80,midi::GM_TYPE::lead_1_square}, + {81,midi::GM_TYPE::lead_2_sawtooth}, + {82,midi::GM_TYPE::lead_3_calliope}, + {83,midi::GM_TYPE::lead_4_chiff}, + {84,midi::GM_TYPE::lead_5_charang}, + {85,midi::GM_TYPE::lead_6_voice}, + {86,midi::GM_TYPE::lead_7_fifths}, + {87,midi::GM_TYPE::lead_8_bass__lead}, + {88,midi::GM_TYPE::pad_1_new_age}, + {89,midi::GM_TYPE::pad_2_warm}, + {90,midi::GM_TYPE::pad_3_polysynth}, + {91,midi::GM_TYPE::pad_4_choir}, + {92,midi::GM_TYPE::pad_5_bowed}, + {93,midi::GM_TYPE::pad_6_metallic}, + {94,midi::GM_TYPE::pad_7_halo}, + {95,midi::GM_TYPE::pad_8_sweep}, + {96,midi::GM_TYPE::fx_1_rain}, + {97,midi::GM_TYPE::fx_2_soundtrack}, + {98,midi::GM_TYPE::fx_3_crystal}, + {99,midi::GM_TYPE::fx_4_atmosphere}, + {100,midi::GM_TYPE::fx_5_brightness}, + {101,midi::GM_TYPE::fx_6_goblins}, + {102,midi::GM_TYPE::fx_7_echoes}, + {103,midi::GM_TYPE::fx_8_sci_fi}, + {104,midi::GM_TYPE::sitar}, + {105,midi::GM_TYPE::banjo}, + {106,midi::GM_TYPE::shamisen}, + {107,midi::GM_TYPE::koto}, + {108,midi::GM_TYPE::kalimba}, + {109,midi::GM_TYPE::bag_pipe}, + {110,midi::GM_TYPE::fiddle}, + {111,midi::GM_TYPE::shanai}, + {112,midi::GM_TYPE::tinkle_bell}, + {113,midi::GM_TYPE::agogo}, + {114,midi::GM_TYPE::steel_drums}, + {115,midi::GM_TYPE::woodblock}, + {116,midi::GM_TYPE::taiko_drum}, + {117,midi::GM_TYPE::melodic_tom}, + {118,midi::GM_TYPE::synth_drum}, + {119,midi::GM_TYPE::reverse_cymbal}, + {120,midi::GM_TYPE::guitar_fret_noise}, + {121,midi::GM_TYPE::breath_noise}, + {122,midi::GM_TYPE::seashore}, + {123,midi::GM_TYPE::bird_tweet}, + {124,midi::GM_TYPE::telephone_ring}, + {125,midi::GM_TYPE::helicopter}, + {126,midi::GM_TYPE::applause}, + {127,midi::GM_TYPE::gunshot}, + {128,midi::GM_TYPE::drum_0}, + {129,midi::GM_TYPE::drum_1}, + {130,midi::GM_TYPE::drum_2}, + {131,midi::GM_TYPE::drum_3}, + {132,midi::GM_TYPE::drum_4}, + {133,midi::GM_TYPE::drum_5}, + {134,midi::GM_TYPE::drum_6}, + {135,midi::GM_TYPE::drum_7}, + {136,midi::GM_TYPE::drum_8}, + {137,midi::GM_TYPE::drum_9}, + {138,midi::GM_TYPE::drum_10}, + {139,midi::GM_TYPE::drum_11}, + {140,midi::GM_TYPE::drum_12}, + {141,midi::GM_TYPE::drum_13}, + {142,midi::GM_TYPE::drum_14}, + {143,midi::GM_TYPE::drum_15}, + {144,midi::GM_TYPE::drum_16}, + {145,midi::GM_TYPE::drum_17}, + {146,midi::GM_TYPE::drum_18}, + {147,midi::GM_TYPE::drum_19}, + {148,midi::GM_TYPE::drum_20}, + {149,midi::GM_TYPE::drum_21}, + {150,midi::GM_TYPE::drum_22}, + {151,midi::GM_TYPE::drum_23}, + {152,midi::GM_TYPE::drum_24}, + {153,midi::GM_TYPE::drum_25}, + {154,midi::GM_TYPE::drum_26}, + {155,midi::GM_TYPE::drum_27}, + {156,midi::GM_TYPE::drum_28}, + {157,midi::GM_TYPE::drum_29}, + {158,midi::GM_TYPE::drum_30}, + {159,midi::GM_TYPE::drum_31}, + {160,midi::GM_TYPE::drum_32}, + {161,midi::GM_TYPE::drum_33}, + {162,midi::GM_TYPE::drum_34}, + {163,midi::GM_TYPE::drum_35}, + {164,midi::GM_TYPE::drum_36}, + {165,midi::GM_TYPE::drum_37}, + {166,midi::GM_TYPE::drum_38}, + {167,midi::GM_TYPE::drum_39}, + {168,midi::GM_TYPE::drum_40}, + {169,midi::GM_TYPE::drum_41}, + {170,midi::GM_TYPE::drum_42}, + {171,midi::GM_TYPE::drum_43}, + {172,midi::GM_TYPE::drum_44}, + {173,midi::GM_TYPE::drum_45}, + {174,midi::GM_TYPE::drum_46}, + {175,midi::GM_TYPE::drum_47}, + {176,midi::GM_TYPE::drum_48}, + {177,midi::GM_TYPE::drum_49}, + {178,midi::GM_TYPE::drum_50}, + {179,midi::GM_TYPE::drum_51}, + {180,midi::GM_TYPE::drum_52}, + {181,midi::GM_TYPE::drum_53}, + {182,midi::GM_TYPE::drum_54}, + {183,midi::GM_TYPE::drum_55}, + {184,midi::GM_TYPE::drum_56}, + {185,midi::GM_TYPE::drum_57}, + {186,midi::GM_TYPE::drum_58}, + {187,midi::GM_TYPE::drum_59}, + {188,midi::GM_TYPE::drum_60}, + {189,midi::GM_TYPE::drum_61}, + {190,midi::GM_TYPE::drum_62}, + {191,midi::GM_TYPE::drum_63}, + {192,midi::GM_TYPE::drum_64}, + {193,midi::GM_TYPE::drum_65}, + {194,midi::GM_TYPE::drum_66}, + {195,midi::GM_TYPE::drum_67}, + {196,midi::GM_TYPE::drum_68}, + {197,midi::GM_TYPE::drum_69}, + {198,midi::GM_TYPE::drum_70}, + {199,midi::GM_TYPE::drum_71}, + {200,midi::GM_TYPE::drum_72}, + {201,midi::GM_TYPE::drum_73}, + {202,midi::GM_TYPE::drum_74}, + {203,midi::GM_TYPE::drum_75}, + {204,midi::GM_TYPE::drum_76}, + {205,midi::GM_TYPE::drum_77}, + {206,midi::GM_TYPE::drum_78}, + {207,midi::GM_TYPE::drum_79}, + {208,midi::GM_TYPE::drum_80}, + {209,midi::GM_TYPE::drum_81}, + {210,midi::GM_TYPE::drum_82}, + {211,midi::GM_TYPE::drum_83}, + {212,midi::GM_TYPE::drum_84}, + {213,midi::GM_TYPE::drum_85}, + {214,midi::GM_TYPE::drum_86}, + {215,midi::GM_TYPE::drum_87}, + {216,midi::GM_TYPE::drum_88}, + {217,midi::GM_TYPE::drum_89}, + {218,midi::GM_TYPE::drum_90}, + {219,midi::GM_TYPE::drum_91}, + {220,midi::GM_TYPE::drum_92}, + {221,midi::GM_TYPE::drum_93}, + {222,midi::GM_TYPE::drum_94}, + {223,midi::GM_TYPE::drum_95}, + {224,midi::GM_TYPE::drum_96}, + {225,midi::GM_TYPE::drum_97}, + {226,midi::GM_TYPE::drum_98}, + {227,midi::GM_TYPE::drum_99}, + {228,midi::GM_TYPE::drum_100}, + {229,midi::GM_TYPE::drum_101}, + {230,midi::GM_TYPE::drum_102}, + {231,midi::GM_TYPE::drum_103}, + {232,midi::GM_TYPE::drum_104}, + {233,midi::GM_TYPE::drum_105}, + {234,midi::GM_TYPE::drum_106}, + {235,midi::GM_TYPE::drum_107}, + {236,midi::GM_TYPE::drum_108}, + {237,midi::GM_TYPE::drum_109}, + {238,midi::GM_TYPE::drum_110}, + {239,midi::GM_TYPE::drum_111}, + {240,midi::GM_TYPE::drum_112}, + {241,midi::GM_TYPE::drum_113}, + {242,midi::GM_TYPE::drum_114}, + {243,midi::GM_TYPE::drum_115}, + {244,midi::GM_TYPE::drum_116}, + {245,midi::GM_TYPE::drum_117}, + {246,midi::GM_TYPE::drum_118}, + {247,midi::GM_TYPE::drum_119}, + {248,midi::GM_TYPE::drum_120}, + {249,midi::GM_TYPE::drum_121}, + {250,midi::GM_TYPE::drum_122}, + {251,midi::GM_TYPE::drum_123}, + {252,midi::GM_TYPE::drum_124}, + {253,midi::GM_TYPE::drum_125}, + {254,midi::GM_TYPE::drum_126}, + {255,midi::GM_TYPE::drum_127}, +}; + +std::map gm_inst_to_category = { + {midi::GM_TYPE::acoustic_grand_piano, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::bright_acoustic_piano, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::electric_grand_piano, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::honky_tonk_piano, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::electric_piano_1, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::electric_piano_2, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::harpsichord, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::clavi, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::celesta, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::glockenspiel, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::music_box, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::vibraphone, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::marimba, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::xylophone, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::tubular_bells, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::dulcimer, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::drawbar_organ, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::percussive_organ, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::rock_organ, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::church_organ, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::reed_organ, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::accordion, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::harmonica, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::tango_accordion, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::acoustic_guitar_nylon, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::acoustic_guitar_steel, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::electric_guitar_jazz, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::electric_guitar_clean, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::electric_guitar_muted, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::overdriven_guitar, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::distortion_guitar, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::guitar_harmonics, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::acoustic_bass, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::electric_bass_finger, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::electric_bass_pick, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::fretless_bass, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::slap_bass_1, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::slap_bass_2, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::synth_bass_1, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::synth_bass_2, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::violin, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::viola, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::cello, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::contrabass, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::tremolo_strings, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::pizzicato_strings, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::orchestral_harp, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::timpani, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::string_ensemble_1, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::string_ensemble_2, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::synth_strings_1, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::synth_strings_2, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::choir_aahs, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::voice_oohs, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::synth_voice, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::orchestra_hit, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::trumpet, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::trombone, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::tuba, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::muted_trumpet, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::french_horn, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::brass_section, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::synth_brass_1, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::synth_brass_2, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::soprano_sax, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::alto_sax, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::tenor_sax, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::baritone_sax, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::oboe, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::english_horn, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::bassoon, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::clarinet, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::piccolo, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::flute, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::recorder, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::pan_flute, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::blown_bottle, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::shakuhachi, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::whistle, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::ocarina, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::lead_1_square, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::lead_2_sawtooth, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::lead_3_calliope, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::lead_4_chiff, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::lead_5_charang, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::lead_6_voice, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::lead_7_fifths, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::lead_8_bass__lead, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::pad_1_new_age, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::pad_2_warm, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::pad_3_polysynth, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::pad_4_choir, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::pad_5_bowed, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::pad_6_metallic, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::pad_7_halo, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::pad_8_sweep, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::fx_1_rain, midi::GM_CATEGORY_SOUND_FX}, + {midi::GM_TYPE::fx_2_soundtrack, midi::GM_CATEGORY_SOUND_FX}, + {midi::GM_TYPE::fx_3_crystal, midi::GM_CATEGORY_SOUND_FX}, + {midi::GM_TYPE::fx_4_atmosphere, midi::GM_CATEGORY_SOUND_FX}, + {midi::GM_TYPE::fx_5_brightness, midi::GM_CATEGORY_SOUND_FX}, + {midi::GM_TYPE::fx_6_goblins, midi::GM_CATEGORY_SOUND_FX}, + {midi::GM_TYPE::fx_7_echoes, midi::GM_CATEGORY_SOUND_FX}, + {midi::GM_TYPE::fx_8_sci_fi, midi::GM_CATEGORY_SOUND_FX}, + {midi::GM_TYPE::sitar, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::banjo, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::shamisen, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::koto, midi::GM_CATEGORY_POLY}, + {midi::GM_TYPE::kalimba, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::bag_pipe, midi::GM_CATEGORY_MONO}, // technically two tones? + {midi::GM_TYPE::fiddle, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::shanai, midi::GM_CATEGORY_MONO}, + {midi::GM_TYPE::tinkle_bell, midi::GM_CATEGORY_PERC}, + {midi::GM_TYPE::agogo, midi::GM_CATEGORY_PERC}, + {midi::GM_TYPE::steel_drums, midi::GM_CATEGORY_PERC}, + {midi::GM_TYPE::woodblock, midi::GM_CATEGORY_PERC}, + {midi::GM_TYPE::taiko_drum, midi::GM_CATEGORY_PERC}, + {midi::GM_TYPE::melodic_tom, midi::GM_CATEGORY_PERC}, + {midi::GM_TYPE::synth_drum, midi::GM_CATEGORY_PERC}, + {midi::GM_TYPE::reverse_cymbal, midi::GM_CATEGORY_PERC}, + {midi::GM_TYPE::guitar_fret_noise, midi::GM_CATEGORY_SOUND_FX}, + {midi::GM_TYPE::breath_noise, midi::GM_CATEGORY_SOUND_FX}, + {midi::GM_TYPE::seashore, midi::GM_CATEGORY_SOUND_FX}, + {midi::GM_TYPE::bird_tweet, midi::GM_CATEGORY_SOUND_FX}, + {midi::GM_TYPE::telephone_ring, midi::GM_CATEGORY_SOUND_FX}, + {midi::GM_TYPE::helicopter, midi::GM_CATEGORY_SOUND_FX}, + {midi::GM_TYPE::applause, midi::GM_CATEGORY_SOUND_FX}, + {midi::GM_TYPE::gunshot, midi::GM_CATEGORY_SOUND_FX} +}; + +std::vector get_instruments_by_category(std::string category) { + std::vector result; + const google::protobuf::EnumDescriptor *descriptor = midi::GM_TYPE_descriptor(); + const google::protobuf::EnumDescriptor *category_descriptor = midi::GM_CATEGORY_descriptor(); + auto value = category_descriptor->FindValueByName(category); + if (value == NULL) { + return result; + } + for (auto kv : gm_inst_to_category) { + if (kv.second == value->number()) { + result.push_back( descriptor->FindValueByNumber(kv.first)->name() ); + } + } + return result; +} + +std::pair get_instrument_and_track_type_from_gm_inst(std::string gm_inst) { + const google::protobuf::EnumDescriptor *descriptor = midi::GM_TYPE_descriptor(); + auto value = descriptor->FindValueByName(gm_inst); + if (value == NULL) { + throw std::runtime_error("Invalid GM instrument name"); + } + int instrument_number = GM_MOD[static_cast(value->number())][0]; + return std::make_pair(instrument_number, instrument_number < 128 ? "STANDARD_TRACK" : "STANDARD_DRUM_TRACK"); +} + + +} +// END OF NAMESPACE diff --git a/src/inference/enum/model_type.h b/src/inference/enum/model_type.h new file mode 100644 index 0000000000000000000000000000000000000000..0203faebd60ae21d530ee73ca7dec28ac61c8e40 --- /dev/null +++ b/src/inference/enum/model_type.h @@ -0,0 +1,13 @@ +#pragma once + + +// START OF NAMESPACE +namespace enums { + +enum MODEL_TYPE { + TRACK_MODEL, + BAR_INFILL_MODEL +}; + +} +// END OF NAMESPACE \ No newline at end of file diff --git a/src/inference/enum/pretrain_group.h b/src/inference/enum/pretrain_group.h new file mode 100644 index 0000000000000000000000000000000000000000..a0f0a70de2e43adbe83c1599c78bbcbb8af51497 --- /dev/null +++ b/src/inference/enum/pretrain_group.h @@ -0,0 +1,271 @@ +#pragma once + +#include + +// START OF NAMESPACE +namespace enums { + +const std::map PRETRAIN_GROUPING = { + {0,0}, + {1,0}, + {2,0}, + {3,1}, + {4,2}, + {5,2}, + {6,3}, + {7,4}, + {8,5}, + {9,6}, + {10,7}, + {11,8}, + {12,9}, + {13,10}, + {14,11}, + {15,12}, + {16,13}, + {17,13}, + {18,13}, + {19,14}, + {20,14}, + {21,15}, + {22,16}, + {23,17}, + {24,18}, + {25,19}, + {26,20}, + {27,21}, + {28,22}, + {29,23}, + {30,24}, + {31,25}, + {32,26}, + {33,27}, + {34,27}, + {35,28}, + {36,29}, + {37,29}, + {38,30}, + {39,30}, + {40,31}, + {41,32}, + {42,33}, + {43,34}, + {44,35}, + {45,36}, + {46,37}, + {47,38}, + {48,39}, + {49,39}, + {50,40}, + {51,40}, + {52,41}, + {53,42}, + {54,43}, + {55,44}, + {56,45}, + {57,46}, + {58,47}, + {59,48}, + {60,49}, + {61,50}, + {62,51}, + {63,51}, + {64,52}, + {65,53}, + {66,54}, + {67,55}, + {68,56}, + {69,57}, + {70,58}, + {71,59}, + {72,60}, + {73,61}, + {74,62}, + {75,63}, + {76,64}, + {77,65}, + {78,66}, + {79,67}, + {80,68}, + {81,69}, + {82,70}, + {83,71}, + {84,72}, + {85,73}, + {86,74}, + {87,75}, + {88,76}, + {89,76}, + {90,76}, + {91,76}, + {92,76}, + {93,76}, + {94,76}, + {95,76}, + {96,77}, + {97,78}, + {98,79}, + {99,80}, + {100,81}, + {101,82}, + {102,83}, + {103,84}, + {104,85}, + {105,86}, + {106,87}, + {107,88}, + {108,89}, + {109,90}, + {110,91}, + {111,92}, + {112,93}, + {113,94}, + {114,95}, + {115,96}, + {116,97}, + {117,98}, + {118,99}, + {119,100}, + {120,101}, + {121,102}, + {122,103}, + {123,104}, + {124,105}, + {125,106}, + {126,107}, + {127,108}, +}; + +const std::map PRETRAIN_GROUPING_V2 = { + {0,0}, + {1,0}, + {2,0}, + {3,3}, + {4,4}, + {5,4}, + {6,6}, + {7,7}, + {8,8}, + {9,9}, + {10,10}, + {11,11}, + {12,12}, + {13,13}, + {14,14}, + {15,15}, + {16,16}, + {17,16}, + {18,16}, + {19,19}, + {20,19}, + {21,21}, + {22,22}, + {23,23}, + {24,24}, + {25,25}, + {26,26}, + {27,27}, + {28,28}, + {29,29}, + {30,30}, + {31,31}, + {32,32}, + {33,33}, + {34,33}, + {35,35}, + {36,36}, + {37,36}, + {38,38}, + {39,38}, + {40,40}, + {41,41}, + {42,42}, + {43,43}, + {44,44}, + {45,45}, + {46,46}, + {47,47}, + {48,48}, + {49,48}, + {50,50}, + {51,50}, + {52,52}, + {53,53}, + {54,54}, + {55,55}, + {56,56}, + {57,57}, + {58,58}, + {59,59}, + {60,60}, + {61,61}, + {62,62}, + {63,62}, + {64,64}, + {65,65}, + {66,66}, + {67,67}, + {68,68}, + {69,69}, + {70,70}, + {71,71}, + {72,72}, + {73,73}, + {74,74}, + {75,75}, + {76,76}, + {77,77}, + {78,78}, + {79,79}, + {80,80}, + {81,81}, + {82,82}, + {83,83}, + {84,84}, + {85,85}, + {86,86}, + {87,87}, + {88,88}, + {89,88}, + {90,88}, + {91,88}, + {92,88}, + {93,88}, + {94,88}, + {95,88}, + {96,96}, + {97,97}, + {98,98}, + {99,99}, + {100,100}, + {101,101}, + {102,102}, + {103,103}, + {104,104}, + {105,105}, + {106,106}, + {107,107}, + {108,108}, + {109,109}, + {110,110}, + {111,111}, + {112,112}, + {113,113}, + {114,114}, + {115,115}, + {116,116}, + {117,117}, + {118,118}, + {119,119}, + {120,120}, + {121,121}, + {122,122}, + {123,123}, + {124,124}, + {125,125}, + {126,126}, + {127,127}, +}; + +} +// END OF NAMESPACE \ No newline at end of file diff --git a/src/inference/enum/timesigs.h b/src/inference/enum/timesigs.h new file mode 100644 index 0000000000000000000000000000000000000000..4abee99160e250198e16cb819080887d97b2d3cd --- /dev/null +++ b/src/inference/enum/timesigs.h @@ -0,0 +1,47 @@ +#pragma once + +#include +#include + +namespace enums { + + const std::map, int> YELLOW_TS_MAP = { + {{4,4},0}, + {{3,4},1}, + {{2,4},2}, + {{6,8},3}, + {{2,2},4}, + {{1,4},5}, + {{6,4},6}, + {{3,8},7}, + {{5,4},8}, + {{4,2},9}, + {{1,8},10}, + {{3,2},11}, + {{9,8},12}, + {{5,8},13}, + {{7,8},14}, + {{12,8},15}, + {{8,4},16}, + {{7,4},17}, + {{4,8},18}, + {{3,1},19}, + {{1,2},20}, + {{8,8},21}, + {{11,8},22}, + {{2,8},23}, + {{6,2},24}, + {{9,4},25}, + {{2,1},26}, + {{9,16},27}, + {{12,4},28}, + {{10,4},29}, + {{13,16},30}, + {{15,16},31}, + {{17,16},32}, + {{1,16},33}, + {{10,8},34}, + {{16,4},35}, + }; + +} \ No newline at end of file diff --git a/src/inference/enum/velocity.h b/src/inference/enum/velocity.h new file mode 100644 index 0000000000000000000000000000000000000000..1b9a7665fa32bec64bc555c07286a06c12ea9f73 --- /dev/null +++ b/src/inference/enum/velocity.h @@ -0,0 +1,140 @@ +#pragma once +#include + +// START OF NAMESPACE +namespace enums { + + +const std::map DEFAULT_VELOCITY_MAP = { + {0,0}, + {1,1}, + {2,1}, + {3,1}, + {4,1}, + {5,2}, + {6,2}, + {7,2}, + {8,2}, + {9,3}, + {10,3}, + {11,3}, + {12,3}, + {13,4}, + {14,4}, + {15,4}, + {16,4}, + {17,5}, + {18,5}, + {19,5}, + {20,5}, + {21,6}, + {22,6}, + {23,6}, + {24,6}, + {25,7}, + {26,7}, + {27,7}, + {28,7}, + {29,8}, + {30,8}, + {31,8}, + {32,8}, + {33,8}, + {34,9}, + {35,9}, + {36,9}, + {37,9}, + {38,10}, + {39,10}, + {40,10}, + {41,10}, + {42,11}, + {43,11}, + {44,11}, + {45,11}, + {46,12}, + {47,12}, + {48,12}, + {49,12}, + {50,13}, + {51,13}, + {52,13}, + {53,13}, + {54,14}, + {55,14}, + {56,14}, + {57,14}, + {58,15}, + {59,15}, + {60,15}, + {61,15}, + {62,16}, + {63,16}, + {64,16}, + {65,16}, + {66,16}, + {67,17}, + {68,17}, + {69,17}, + {70,17}, + {71,18}, + {72,18}, + {73,18}, + {74,18}, + {75,19}, + {76,19}, + {77,19}, + {78,19}, + {79,20}, + {80,20}, + {81,20}, + {82,20}, + {83,21}, + {84,21}, + {85,21}, + {86,21}, + {87,22}, + {88,22}, + {89,22}, + {90,22}, + {91,23}, + {92,23}, + {93,23}, + {94,23}, + {95,24}, + {96,24}, + {97,24}, + {98,24}, + {99,24}, + {100,25}, + {101,25}, + {102,25}, + {103,25}, + {104,26}, + {105,26}, + {106,26}, + {107,26}, + {108,27}, + {109,27}, + {110,27}, + {111,27}, + {112,28}, + {113,28}, + {114,28}, + {115,28}, + {116,29}, + {117,29}, + {118,29}, + {119,29}, + {120,30}, + {121,30}, + {122,30}, + {123,30}, + {124,31}, + {125,31}, + {126,31}, + {127,31}, +}; + +} +// END OF NAMESPACE diff --git a/src/inference/protobuf/rsj.h b/src/inference/protobuf/rsj.h new file mode 100644 index 0000000000000000000000000000000000000000..24f88bb0d6981a2438b640575974fd22cf69a3b1 --- /dev/null +++ b/src/inference/protobuf/rsj.h @@ -0,0 +1,786 @@ +/** ************************************************************************************** +* * +* A Ridiculously Simple JSON Parser for C++ (RSJp-cpp) * +* Version 2.x * +* ---------------------------------------------------------- * +* Copyright (C) 2018 Subhrajit Bhattacharya * +* * +* This program is free software: you can redistribute it and/or modify * +* it under the terms of the GNU General Public License as published by * +* the Free Software Foundation, either version 3 of the License, or * +* (at your option) any later version. * +* * +* This program is distributed in the hope that it will be useful, * +* but WITHOUT ANY WARRANTY; without even the implied warranty of * +* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * +* GNU General Public License for more details . * +* * +* * +* Contact: subhrajit@gmail.com * +* https://www.lehigh.edu/~sub216/ , http://subhrajit.net/ * +* * +* * +*************************************************************************************** **/ + +#ifndef __DOSL_RSJPARSE_TCC +#define __DOSL_RSJPARSE_TCC + +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +static char const* RSJobjectbrackets = "{}"; +static char const* RSJarraybrackets = "[]"; +static char RSJobjectassignment = ':'; +static char RSJarraydelimiter = ','; + +static std::vector RSJbrackets = {RSJobjectbrackets, RSJarraybrackets}; +static std::vector RSJstringquotes = {"\"\"", "''"}; +static char RSJcharescape = '\\'; +static std::string RSJlinecommentstart = "//"; + +static std::string RSJprinttab = " "; + +enum RSJresourceType { RSJ_UNINITIATED, RSJ_UNKNOWN, RSJ_OBJECT, RSJ_ARRAY, RSJ_LEAF }; + +// ============================================================ +// Direct string manipulation functions + +inline +std::string to_string (RSJresourceType rt) { + switch (rt) { + case RSJ_UNINITIATED: return("RSJ_UNINITIATED"); + case RSJ_UNKNOWN: return("RSJ_UNKNOWN"); + case RSJ_OBJECT: return("RSJ_OBJECT"); + case RSJ_ARRAY: return("RSJ_ARRAY"); + case RSJ_LEAF: return("RSJ_LEAF"); + } +} + +enum StrTrimDir { STRTRIM_L=1, STRTRIM_R=2, STRTRIM_LR=3 }; + +inline +std::string strtrim (std::string str, std::string chars=" \t\n\r", int max_count=-1, StrTrimDir dirs=STRTRIM_LR) { + if (str.empty()) return(str); + if (max_count<0) max_count = str.length(); + + if (dirs & STRTRIM_L) { // left trim + int p; + for (p=0; p& bracks, int indx=0) { + for (int b=0; b<(int)bracks.size(); ++b) + if (c==bracks[b][indx]) + return (b); + return (-1); +} + +inline +std::vector split_RSJ_array (const std::string& str) { // TODO: Make efficient. This function is speed bottleneck. + // splits, while respecting brackets and escapes + std::vector ret; + + std::string current; + std::vector bracket_stack; + std::vector quote_stack; + bool escape_active = false; + int bi; + + for (int a=0; a<(int)str.length(); ++a) { // * + + // delimiter + if ( bracket_stack.size()==0 && quote_stack.size()==0 && str[a]==RSJarraydelimiter ) { + ret.push_back (current); + current.clear(); bracket_stack.clear(); quote_stack.clear(); escape_active = false; + continue; // to * + } + + // ------------------------------------ + // checks for string + + if (quote_stack.size() > 0) { // already inside string + if (str[a]==RSJcharescape) // an escape character + escape_active = !escape_active; + else if (!escape_active && str[a]==RSJstringquotes[quote_stack.back()][1] ) { // close quote + quote_stack.pop_back(); + escape_active = false; + } + else + escape_active = false; + + current.push_back (str[a]); + continue; // to * + } + + if (quote_stack.size()==0) { // check for start of string + if ((bi = is_bracket (str[a], RSJstringquotes)) >= 0) { + quote_stack.push_back (bi); + current.push_back (str[a]); + continue; // to * + } + } + + // ------------------------------------ + // checks for comments + + if (quote_stack.size()==0) { // comment cannot start inside string + + // single-line commenst + if (str.compare (a, RSJlinecommentstart.length(), RSJlinecommentstart) == 0) { + // ignore until end of line + int newline_pos = str.find ("\n", a); + if (newline_pos == (int)std::string::npos) + newline_pos = str.find ("\r", a); + + if (newline_pos != (int)std::string::npos) + a = newline_pos; // point to the newline character (a will be incremented) + else // the comment continues until EOF + a = str.length(); + continue; + } + } + + // ------------------------------------ + // checks for brackets + + if ( bracket_stack.size()>0 && str[a]==RSJbrackets[bracket_stack.back()][1] ) { // check for closing bracket + bracket_stack.pop_back(); + current.push_back (str[a]); + continue; + } + + if ((bi = is_bracket (str[a], RSJbrackets)) >= 0) { + bracket_stack.push_back (bi); + current.push_back (str[a]); + continue; // to * + } + + // ------------------------------------ + // otherwise + current.push_back (str[a]); + } + + if (current.length() > 0) + ret.push_back (current); + + return (ret); +} + +inline +std::string insert_tab_after_newlines (std::string str) { + for (int a=0; a<(int)str.length(); ++a) + if (str[a]=='\n') { + str.insert (a+1, RSJprinttab); + a += RSJprinttab.length(); + } + return (str); +} + + +// ============================================================ + +// forward declarations +class RSJparsedData; +class RSJresource; + +// Objet and array typedefs +typedef std::unordered_map RSJobject; +typedef std::vector RSJarray; + +// ------------------------------------ +// Main classes + +class RSJresource { +/* Use: RSJresource("RSJ_string_data").as()["keyName"].as()[2].as() + RSJresource("RSJ_string_data")["keyName"][2].as() */ +private: + // main data + std::string data; // can be object, vector or leaf data + bool _exists; // whether the RSJ resource exists. + + // parsed data + RSJparsedData* parsed_data_p; + +public: + // constructor + RSJresource () : _exists (false), parsed_data_p (NULL) { } // no data field. + + RSJresource (std::string str) : data (str), _exists (true), parsed_data_p (NULL) { } + RSJresource (const char* str) : RSJresource(std::string(str)) { } + + // other convertion + template + RSJresource (dataType d) : RSJresource(std::to_string(d)) { } + + // read from file and stream + RSJresource (std::istream& is) : _exists (true), parsed_data_p (NULL) { + data = std::string ( (std::istreambuf_iterator(is)), (std::istreambuf_iterator()) ); + } + RSJresource (std::ifstream& ifs) : _exists (true), parsed_data_p (NULL) { + std::istream& is = ifs; + data = std::string ( (std::istreambuf_iterator(is)), (std::istreambuf_iterator()) ); + } + + // free allocated memory for parsed data + ~RSJresource(); + + // deep copy + RSJresource (const RSJresource& r); + RSJresource& operator= (const RSJresource& r); + + // ------------------------------------ + // parsers (old) + RSJresourceType parse (bool force=false); + void parse_full (bool force=false, int max_depth=INT_MAX, int* parse_count_for_verbose_p=NULL); // recursively parse the entire JSON text + // parser (new) + void fast_parse (std::string* str_p=NULL, bool copy_string=false, int max_depth=INT_MAX, int* parse_start_str_pos=NULL); // TODO: finish. + + RSJobject& as_object (bool force=false); + RSJarray& as_array (bool force=false); + + // ------------------------------------ + + // access raw data and other attributes + int size(void); + std::string& raw_data (void) { return (data); } + bool exists (void) { return (_exists); } + bool is_parsed (void) { return (parsed_data_p!=NULL); } + RSJresourceType type (void); + // emitter + std::string as_str (bool print_comments=false, bool update_data=true); + void print (bool print_comments=false, bool update_data=true) + { std::cout << as_str(print_comments,update_data) << std::endl; } + + // opertor[] + RSJresource& operator[] (std::string key); // object + RSJresource& operator[] (int indx); // array + + // ------------------------------------ + + // as + template + dataType as (const dataType& def = dataType()) { // specialized outside class declaration + if (!exists()) return (def); + return dataType (data); // default behavior for unknown types: invoke 'dataType(std::string)' + } + + // as_vector + template > // vectorType should have push_back method + vectorType as_vector (const vectorType& def = vectorType()); + + // as_map + template > // mapType should have operator[] defined + mapType as_map (const mapType& def = mapType()); +}; + +// ------------------------------------------------------------ + +class RSJparsedData { +public: + RSJobject object; + RSJarray array; + + RSJresourceType type; + RSJparsedData() : type(RSJ_UNKNOWN) {} + + // parser (single-level) + void parse (const std::string& data, RSJresourceType typ = RSJ_UNKNOWN) { + std::string content = strtrim(data); + + if (typ==RSJ_OBJECT || typ==RSJ_UNKNOWN) { + // parse as object: + content = strtrim (strtrim (content, "{", 1, STRTRIM_L ), "}", 1, STRTRIM_R ); + if (content.length() != data.length()) { // a valid object + std::vector nvPairs = split_RSJ_array (content); + for (int a=0; a<(int)nvPairs.size(); ++a) { + std::size_t assignmentPos = nvPairs[a].find (RSJobjectassignment); + object.insert (make_pair( + strip_outer_quotes (nvPairs[a].substr (0,assignmentPos) ) , + RSJresource (strtrim (nvPairs[a].substr (assignmentPos+1) ) ) + ) ); + } + if (object.size() > 0) { + type = RSJ_OBJECT; + return; + } + } + } + + if (typ==RSJ_ARRAY || typ==RSJ_UNKNOWN) { + // parse as array + content = strtrim (strtrim (content, "[", 1, STRTRIM_L ), "]", 1, STRTRIM_R ); + if (content.length() != data.length()) { // a valid array + std::vector nvPairs = split_RSJ_array (content); + for (int a=0; a<(int)nvPairs.size(); ++a) + array.push_back (RSJresource (strtrim (nvPairs[a]) ) ); + if (array.size() > 0) { + type = RSJ_ARRAY; + return; + } + } + } + + if (typ==RSJ_UNKNOWN) + type = RSJ_LEAF; + } + + + // remove non-existing items inserted due to accessing + int cleanup(void) { + + if (type==RSJ_OBJECT) { + bool found = true; + while (found) { + found = false; + for (auto it=object.begin(); it!=object.end(); ++it) + if (!(it->second.exists())) { + object.erase(it); + found = true; + break; // break for loop since it is now invalid + } + } + return (object.size()); + } + + if (type==RSJ_ARRAY) { // erases only the non-existent elements at the tail + while (!(array[array.size()-1].exists())) + array.pop_back(); + return (array.size()); + } + + if (type==RSJ_LEAF) + return (1); + + return (0); + } + + // size + int size(void) { return (cleanup()); } +}; + + +// ------------------------------------------------------------ +// RSJresource member functions + +inline +RSJresource::~RSJresource (){ + if (parsed_data_p) delete parsed_data_p; +} + +inline +RSJresource::RSJresource (const RSJresource& r) { + data=r.data; + _exists = r._exists; + if(r.parsed_data_p) parsed_data_p = new RSJparsedData(*(r.parsed_data_p)); + else parsed_data_p = NULL; +} + +inline +RSJresource& RSJresource::operator= (const RSJresource& r) { + data=r.data; + _exists = r._exists; + if(r.parsed_data_p) parsed_data_p = new RSJparsedData(*(r.parsed_data_p)); + else parsed_data_p = NULL; + return *this; +} + +inline +int RSJresource::size (void) { + if (!exists()) return (0); + parse(); // parse if not parsed + return (parsed_data_p->size()); +} + +inline +RSJresourceType RSJresource::type (void) { + if (!exists()) return (RSJ_UNINITIATED); + parse(); // parse if not parsed + return (parsed_data_p->type); +} + +inline +std::string RSJresource::as_str (bool print_comments, bool update_data) { + if (exists()) { + std::string ret; + parse(); // parse if not parsed + parsed_data_p->cleanup(); + + if (parsed_data_p->type==RSJ_OBJECT) { + ret = "{\n"; + for (auto it=parsed_data_p->object.begin(); it!=parsed_data_p->object.end(); ++it) { + ret += RSJprinttab + "'" + it->first + "': " + insert_tab_after_newlines( it->second.as_str (print_comments, update_data) ); + if (std::next(it) != parsed_data_p->object.end()) ret += ","; + if (print_comments) + ret += " // " + to_string(it->second.type()); + ret += "\n"; + } + ret += "}"; + } + else if (parsed_data_p->type==RSJ_ARRAY) { + ret = "[\n"; + for (auto it=parsed_data_p->array.begin(); it!=parsed_data_p->array.end(); ++it) { + ret += RSJprinttab + insert_tab_after_newlines( it->as_str (print_comments, update_data) ); + if (std::next(it) != parsed_data_p->array.end()) ret += ","; + if (print_comments) + ret += " // " + to_string(it->type()); + ret += "\n"; + } + ret += "]"; + } + else // RSJ_LEAF or RSJ_UNKNOWN + ret = strtrim (data); + + if (update_data) data = ret; + return (ret); + } + else + return (""); +} + +// Parsers + +inline +RSJresourceType RSJresource::parse (bool force) { + if (!parsed_data_p) parsed_data_p = new RSJparsedData; + if (parsed_data_p->type==RSJ_UNKNOWN || force) parsed_data_p->parse (data, RSJ_UNKNOWN); + return (parsed_data_p->type); +} + +inline +void RSJresource::parse_full (bool force, int max_depth, int* parse_count_for_verbose_p) { // recursive parsing (slow) + if (max_depth==0) return; + if (!parsed_data_p) parsed_data_p = new RSJparsedData; + if (parsed_data_p->type==RSJ_UNKNOWN || force) parsed_data_p->parse (data, RSJ_UNKNOWN); + // verbose + if (parse_count_for_verbose_p) { + (*parse_count_for_verbose_p)++; + if ( (*parse_count_for_verbose_p) % 100 == 0) + std::cout << "parse_full: " << (*parse_count_for_verbose_p) << " calls." << std::endl; + } + // recursive parse children if not already parsed + if (parsed_data_p->type==RSJ_OBJECT) + for (auto it=parsed_data_p->object.begin(); it!=parsed_data_p->object.end(); ++it) + it->second.parse_full (force, max_depth-1, parse_count_for_verbose_p); + else if (parsed_data_p->type==RSJ_ARRAY) + for (auto it=parsed_data_p->array.begin(); it!=parsed_data_p->array.end(); ++it) + it->parse_full (force, max_depth-1, parse_count_for_verbose_p); +} + +// ------------------------------------------------------------ +// ============================================================ +// FAST PARSER (Under construction. DO NOT use the following functions in your application.) + +inline +int seek_next (std::string* str_p, int start_pos, char character) { + +} + +inline +void RSJresource::fast_parse (std::string* str_p, bool copy_string, int max_depth, int* parse_start_str_pos) { + // TODO: UNDER CONSTRUCTION... + + if (!str_p) + str_p = &data; + std::string& str = *str_p; + + // splits, while respecting brackets and escapes + //std::vector ret; + + //std::string current; + std::vector bracket_stack; + std::vector quote_stack; + bool escape_active = false; + int bi; + + bool initial_whitespaces = true; + bool isroot = false; + + if (!parse_start_str_pos) { + parse_start_str_pos = new int; + *parse_start_str_pos = 0; + isroot = true; + } + + int a = *parse_start_str_pos; + + while (*parse_start_str_pos < (int)str_p->length()) { // * + + // initial whitespace characters + if (initial_whitespaces) { + if (str[a] == ' ' || str[a] == '\n' || str[a] == '\r' || str[a] == '\t' ) { + ++a; + continue; + } + else { + if (str[a] == '{') // start of object + // ... TODO: seek_next ':' + + initial_whitespaces = false; + } + } + + + // delimiter + if ( bracket_stack.size()==0 && quote_stack.size()==0 && str[a]==RSJarraydelimiter ) { + //ret.push_back (current); + + //current.clear(); + bracket_stack.clear(); quote_stack.clear(); escape_active = false; + continue; // to * + } + + // ------------------------------------ + // checks for string + + if (quote_stack.size() > 0) { // already inside string + if (str[a]==RSJcharescape) // an escape character + escape_active = !escape_active; + else if (!escape_active && str[a]==RSJstringquotes[quote_stack.back()][1] ) { // close quote + quote_stack.pop_back(); + escape_active = false; + } + else + escape_active = false; + + //current.push_back (str[a]); + continue; // to * + } + + if (quote_stack.size()==0) { // check for start of string + if ((bi = is_bracket (str[a], RSJstringquotes)) >= 0) { + quote_stack.push_back (bi); + //current.push_back (str[a]); + continue; // to * + } + } + + // ------------------------------------ + // checks for comments + + if (quote_stack.size()==0) { // comment cannot start inside string + + // single-line commenst + if (str.compare (a, RSJlinecommentstart.length(), RSJlinecommentstart) == 0) { + // ignore until end of line + int newline_pos = str.find ("\n", a); + if (newline_pos == (int)std::string::npos) + newline_pos = str.find ("\r", a); + + if (newline_pos != (int)std::string::npos) + a = newline_pos; // point to the newline character (a will be incremented) + else // the comment continues until EOF + a = str.length(); + continue; + } + } + + // ------------------------------------ + // checks for brackets + + if ( bracket_stack.size()>0 && str[a]==RSJbrackets[bracket_stack.back()][1] ) { // check for closing bracket + bracket_stack.pop_back(); + //current.push_back (str[a]); + continue; + } + + if ((bi = is_bracket (str[a], RSJbrackets)) >= 0) { + bracket_stack.push_back (bi); + //current.push_back (str[a]); + continue; // to * + } + + // ------------------------------------ + // otherwise + //current.push_back (str[a]); + } + + /*if (current.length() > 0) + ret.push_back (current); */ + + if (isroot) + delete parse_start_str_pos; + + // return (ret); +} + +// ============================================================ + +// ------------------------------------------------------------ + +inline +RSJobject& RSJresource::as_object (bool force) { + if (!parsed_data_p) parsed_data_p = new RSJparsedData; + if (parsed_data_p->type==RSJ_UNKNOWN || force) parsed_data_p->parse (data, RSJ_OBJECT); + return (parsed_data_p->object); +} + +inline +RSJresource& RSJresource::operator[] (std::string key) { // returns reference + return ( (as_object())[key] ); // will return empty resource (with _exists==false) if + // either this resource does not exist, is not an object, or the key does not exist +} + +inline +RSJarray& RSJresource::as_array (bool force) { + if (!parsed_data_p) parsed_data_p = new RSJparsedData; + if (parsed_data_p->type==RSJ_UNKNOWN || force) parsed_data_p->parse (data, RSJ_ARRAY); + return (parsed_data_p->array); +} + +inline +RSJresource& RSJresource::operator[] (int indx) { // returns reference + as_array(); + if (indx >= (int)parsed_data_p->array.size()) + parsed_data_p->array.resize(indx+1); // insert empty resources + return (parsed_data_p->array[indx]); // will return empty resource (with _exists==false) if + // either this resource does not exist, is not an object, or the key does not exist +} + +// ------------------------------------------------------------ +// special 'as': + +template inline +vectorType RSJresource::as_vector (const vectorType& def) { // returns copy -- for being consistent with other 'as' specializations + if (!exists()) return (def); + vectorType ret; + as_array(); + for (auto it=parsed_data_p->array.begin(); it!=parsed_data_p->array.end(); ++it) + ret.push_back (it->as()); + return (ret); +} + +template inline +mapType RSJresource::as_map (const mapType& def) { // returns copy -- for being consistent with other 'as' specializations + if (!exists()) return (def); + mapType ret; + as_object(); + for (auto it=parsed_data_p->object.begin(); it!=parsed_data_p->object.end(); ++it) + ret[it->first] = it->second.as(); + return (ret); +} + +// ============================================================ +// Specialized .as() member functions + +// Helper preprocessor directives +#define rsjObject as() +#define rsjArray as() +#define rsjAs(t) as() + + +// RSJobject +template <> inline +RSJobject RSJresource::as (const RSJobject& def) { // returns copy -- for being consistent with other 'as' specializations + if (!exists()) return (def); + return (as_object()); +} + +// RSJarray +template <> inline +RSJarray RSJresource::as (const RSJarray& def) { // returns copy -- for being consistent with other 'as' specializations + if (!exists()) return (def); + return (as_array()); +} + +// ------------------------------------ +// Elementary types + +// String +template <> inline +std::string RSJresource::as (const std::string& def) { + if (!exists()) return (def); + + char qq = '\0'; + std::string ret = strip_outer_quotes (data, &qq); + + std::vector< std::vector > escapes = { {"\\n","\n"}, {"\\r","\r"}, {"\\t","\t"}, {"\\\\","\\"} }; + if (qq=='"') + escapes.push_back ({"\\\"","\""}); + else if (qq=='\'') + escapes.push_back ({"\\'","'"}); + + for (int a=0; a<(int)escapes.size(); ++a) + for ( std::size_t start_pos=ret.find(escapes[a][0]); start_pos!=(size_t)std::string::npos; start_pos=ret.find(escapes[a][0],start_pos) ) { + ret.replace (start_pos, escapes[a][0].length(), escapes[a][1]); + start_pos += escapes[a][1].length(); + } + + return (ret); +} + +// integer +template <> inline +int RSJresource::as (const int& def) { + if (!exists()) return (def); + return (atoi (strip_outer_quotes(data).c_str() ) ); +} + +// double +template <> inline +double RSJresource::as (const double& def) { + if (!exists()) return (def); + return (atof (strip_outer_quotes(data).c_str() ) ); +} + +// bool +template <> inline +bool RSJresource::as (const bool& def) { + if (!exists()) return (def); + std::string cleanData = strip_outer_quotes (data); + if (cleanData=="true" || cleanData=="TRUE" || cleanData=="True" || atoi(cleanData.c_str())!=0) return (true); + return (false); +} + +// ------------------------------------ +// Other types + +/*template <> template inline +bool RSJresource::as< std::vector > (const std::vector& def) { + return as_vector (def); +} + +template <> template inline +std::unordered_map RSJresource::as< std::unordered_map > + (const std::unordered_map& def) { + return as_map (def); +}*/ + +#endif \ No newline at end of file diff --git a/src/inference/protobuf/validate.h b/src/inference/protobuf/validate.h new file mode 100644 index 0000000000000000000000000000000000000000..ea16c07e316dc9a3dccf1f895fdba5d8afa16050 --- /dev/null +++ b/src/inference/protobuf/validate.h @@ -0,0 +1,467 @@ +#pragma once + +#include +#include "../enum/gm.h" + +#include "rsj.h" + +// START OF NAMESPACE +namespace util_protobuf { + +std::string convert_to_snake_case(const std::string &x) { + std::string o; + for (int i = 0; i < (int)x.length(); i++) { + if ((isalpha(x.at(i))) && (x.at(i) == toupper(x.at(i)))) { + o.push_back('_'); + o.push_back(tolower(x.at(i))); + } + else if (x.at(i) == ' ') { + o.push_back('_'); + } + else { + o.push_back(x.at(i)); + } + } + return o; +} + +template +void validate_protobuf_fields_inner(const T &x, RSJresource &raw_json) { + + const google::protobuf::Reflection* reflection = x.GetReflection(); + const google::protobuf::Descriptor* descriptor = x.GetDescriptor(); + + std::map key_map; + std::map key_map_rev; + for (auto &kv : raw_json.as_object()) { + auto c = convert_to_snake_case(kv.first); + key_map[c] = kv.first; + key_map_rev[kv.first] = c; + } + + for (auto &kv : raw_json.as_object()) { + if (descriptor->FindFieldByName(key_map_rev[kv.first]) == NULL) { + std::ostringstream buffer; + buffer << "PROTOBUF ERROR : " << "invalid field name = " << kv.first << " (" << key_map_rev[kv.first] << ")" << std::endl; + throw std::invalid_argument(buffer.str()); + } + } + for (int i=0; ifield_count(); i++) { + const google::protobuf::FieldDescriptor *fd = descriptor->field(i); + bool is_repeated = fd->is_repeated(); + int field_count = is_repeated ? reflection->FieldSize(x, fd) : 1; + if ((is_repeated) && (reflection->FieldSize(x, fd) != (int)raw_json[key_map[fd->name()]].as_array().size())) { + std::ostringstream buffer; + buffer << "PROTOBUF ERROR : " << "invalid repeated field value :: " << fd->name() << " = " << raw_json[key_map[fd->name()]].as() << std::endl; + throw std::invalid_argument(buffer.str()); + } + for (int index=0; indexHasField(x, fd)) && (raw_json[key_map[fd->name()]].exists())) { + std::ostringstream buffer; + buffer << "PROTOBUF ERROR : " << "invalid field value :: " << fd->name() << " = " << raw_json[key_map[fd->name()]].as() << std::endl; + throw std::invalid_argument(buffer.str()); + } + if (fd->type() == google::protobuf::FieldDescriptor::Type::TYPE_MESSAGE) { + if (is_repeated) { + validate_protobuf_fields_inner(reflection->GetRepeatedMessage(x,fd,index), raw_json[key_map[fd->name()]][index]); + } + else { + validate_protobuf_fields_inner(reflection->GetMessage(x,fd), raw_json[key_map[fd->name()]]); + } + } + } + } +} + +template +void validate_protobuf_fields(const T *x, std::string &s) { + RSJresource raw_json(s); + validate_protobuf_fields_inner(*x, raw_json); +} + +template +void validate_protobuf_inner(const T &x, bool ignore_internal) { + + const google::protobuf::Reflection* reflection = x.GetReflection(); + const google::protobuf::Descriptor* descriptor = x.GetDescriptor(); + + for (int i=0; ifield_count(); i++) { + const google::protobuf::FieldDescriptor *fd = descriptor->field(i); + const google::protobuf::FieldOptions opt = fd->options(); + google::protobuf::FieldDescriptor::Type ft = fd->type(); + + if ((fd->name().rfind("internal_", 0)) || (!ignore_internal)) { + + bool is_repeated = fd->is_repeated(); + int field_count = 1; + if (is_repeated) { + field_count = reflection->FieldSize(x, fd); + } + + for (int index=0; indexGetRepeatedFloat(x,fd,index); + } + else { + value = reflection->GetFloat(x,fd); + } + if ((value < minval) || (value > maxval)) { + std::ostringstream buffer; + buffer << "PROTOBUF ERROR : " << fd->name() << " not on range [" << minval << "," << maxval << ")."; + throw std::invalid_argument(buffer.str()); + } + } + } + else if (ft == google::protobuf::FieldDescriptor::Type::TYPE_INT32) { + int minval = opt.GetExtension(midi::minval); + int maxval = opt.GetExtension(midi::maxval); + if (minval == 0 && maxval == 0) { + // do nothing if not set + } + else { + int value; + if (is_repeated) { + value = reflection->GetRepeatedInt32(x,fd,index); + } + else { + value = reflection->GetInt32(x,fd); + } + if ((value < minval) || (value > maxval)) { + std::ostringstream buffer; + buffer << "PROTOBUF ERROR : " << fd->name() << " not on range [" << minval << "," << maxval << ")."; + throw std::invalid_argument(buffer.str()); + } + } + } + else if (ft == google::protobuf::FieldDescriptor::Type::TYPE_MESSAGE) { + if (is_repeated) { + validate_protobuf_inner( + reflection->GetRepeatedMessage(x,fd,index), ignore_internal); + } + else { + validate_protobuf_inner( + reflection->GetMessage(x,fd), ignore_internal); + } + } + } + } + } +} + + +// this function is validating the range of each variable +// if the field has a min and max defined +template +void validate_protobuf_field_ranges(T *x, bool ignore_internal=true) { + validate_protobuf_inner(*x, ignore_internal); +} + +bool operator< (const midi::Event &a, const midi::Event &b) { + if (a.time() != b.time()) { + return a.time() < b.time(); + } + if (std::min(a.velocity(),1) != std::min(b.velocity(),1)) { + return std::min(a.velocity(),1) < std::min(b.velocity(),1); + } + return a.pitch() < b.pitch(); +} + +void sort_piece_events(midi::Piece *p) { + + // find the re-indexing of the events using argsort + std::vector idx = arange(p->events_size()); + std::sort(idx.begin(), idx.end(), + [&p](size_t i, size_t j) {return p->events(i) < p->events(j);}); + + // make a map from old-index to new-index + std::map index_map; + int count = 0; + for (const auto &i : idx) { + index_map.insert(std::make_pair(i,count)); + count++; + } + + // replace the events in piece->events + midi::Piece orig(*p); + p->clear_events(); + for (const auto &i : idx) { + midi::Event *e = p->add_events(); + e->CopyFrom(orig.events(i)); + } + + // replace indices in piece->tracks + for (int track_num=0; track_numtracks_size(); track_num++) { + midi::Track *t = p->mutable_tracks(track_num); + for (int bar_num=0; bar_numbars_size(); bar_num++) { + midi::Bar *b = t->mutable_bars(bar_num); + b->clear_events(); + std::vector bar_events; + for (const auto &e : orig.tracks(track_num).bars(bar_num).events()) { + bar_events.push_back( index_map[e] ); + } + std::sort(bar_events.begin(), bar_events.end()); + for (const auto &e : bar_events) { + b->add_events( e ); + } + } + } +} + +// 1. check that each event is within the bar +void validate_events(midi::Piece *p) { + for (const auto &track : p->tracks()) { + for (const auto &bar : track.bars()) { + int barlength = bar.internal_beat_length() * p->resolution(); + for (const auto &index : bar.events()) { + if ((index < 0) || (index >= p->events_size())) { + throw std::invalid_argument("EVENT INDEX IN BAR IS OUT OF RANGE!"); + } + int time = p->events(index).time(); + bool is_onset = (p->events(index).velocity()>0); + if ((time < 0) || ((time >= barlength) && (is_onset)) || ((time > barlength) && (!is_onset))) { + std::string event_type = "ONSET"; + if (!is_onset) { + event_type = "OFFSET"; + } + std::ostringstream buffer; + buffer << "NOTE " << event_type << " TIME (" << time << ") IS BEYOND EXTENTS OF BAR (" << barlength << ")"; + throw std::invalid_argument(buffer.str()); + } + } + } + } +} + +void check_track_lengths(midi::Piece *x) { + int num_tracks = x->tracks_size(); + if (num_tracks > 0) { + int num_bars = x->tracks(0).bars_size(); + for (int track_num=1; track_numtracks(track_num).bars_size()) { + throw std::invalid_argument("NUMBER OF BARS DIFFERS BETWEEN TRACKS!"); + } + } + } +} + +void check_time_sigs(midi::Piece *x) { + int track_num = 0; + std::vector numerators; + std::vector denominators; + for (const auto &track : x->tracks()) { + int bar_num = 0; + for (const auto &bar : track.bars()) { + if (track_num == 0) { + numerators.push_back( bar.ts_numerator() ); + denominators.push_back( bar.ts_denominator() ); + } + else { + if ((numerators[bar_num] != bar.ts_numerator()) || (denominators[bar_num] != bar.ts_denominator())) { + throw std::invalid_argument( + "TIME SIGNATURES FOR EACH BAR MUST BE THE SAME ACROSS ALL TRACKS."); + } + } + bar_num++; + } + track_num++; + } +} + +void set_beat_lengths(midi::Piece *x) { + for (int track_num=0; track_numtracks_size(); track_num++) { + midi::Track *t = x->mutable_tracks(track_num); + for (int bar_num=0; bar_numbars_size(); bar_num++) { + midi::Bar *b = t->mutable_bars(bar_num); + b->set_internal_beat_length( + (double)b->ts_numerator() / b->ts_denominator() * 4); + } + } +} + + +void validate_piece(midi::Piece *x) { + + if (!x) { + throw std::invalid_argument("PIECE IS NULL. CANNOT VALIDATE!"); + } + + // check that piece has resolution + if (x->resolution() == 0) { + throw std::invalid_argument("PIECE RESOLUTION CAN NOT BE 0"); + } + + // validate range of fields + validate_protobuf_field_ranges(x); + + // set the beat length using the time signature information + set_beat_lengths(x); + + // to be kind we sort events + sort_piece_events(x); + + // check that there are the same number of bars in each track + check_track_lengths(x); + + // make sure time signatures are the same in each bar + check_time_sigs(x); + + // check events are valid + // event_index should reference valid event + // event times should be within each bar + validate_events(x); + +} + +// update has notes information +void prepare_piece(midi::Piece *x) { + for (int track_num=0; track_numtracks_size(); track_num++) { + midi::Track *track = x->mutable_tracks(track_num); + for (int bar_num=0; bar_numbars_size(); bar_num++) { + } + } +} + +template +void check_range(T value, T minv, T maxv, const char *field) { + if ((value < minv) || (value >= maxv)) { + std::ostringstream buffer; + buffer << field << " not on range [" << minv << "," << maxv << ")."; + throw std::invalid_argument(buffer.str()); + } +} + +template +void check_all_same(std::set &values, const char *field) { + if ((int)values.size() != 1) { + std::ostringstream buffer; + buffer << field << " values must all be the same. {"; + for (const auto &val : values) { + buffer << val << ","; + } + buffer << "}"; + throw std::invalid_argument(buffer.str()); + } +} + +template +void check_all_different(std::set &values, int n, const char *field) { + if ((int)values.size() != n) { + std::ostringstream buffer; + buffer << field << " values must all be different."; + throw std::invalid_argument(buffer.str()); + } +} + +template +void check_in_domain(T value, std::set domain, const char *field) { + if (domain.find( value ) == domain.end()) { + std::ostringstream buffer; + buffer << field << " not in domain."; + throw std::invalid_argument(buffer.str()); + } +} + +int count_selected_bars(const midi::StatusTrack &track) { + int count = 0; + for (const auto &selected : track.selected_bars()) { + count += (int)selected; + } + return count; +} + +enum STATUS_TRACK_TYPE { + CONDITION, + RESAMPLE, + INFILL +}; + +STATUS_TRACK_TYPE infer_track_type(const midi::StatusTrack &track) { + int num_bars = track.selected_bars_size(); + int bar_count = count_selected_bars(track); + if (bar_count == 0) { + return CONDITION; + } + else if (bar_count != num_bars) { + return INFILL; + } + return RESAMPLE; +} + +void validate_param(midi::HyperParam *param) { + + validate_protobuf_field_ranges(param); + +} + +void validate_status(midi::Status *status, midi::Piece *piece, midi::HyperParam *param) { + + if ((!status) || (!piece)) { + throw std::invalid_argument("PIECE OR STATUS IS NULL. CANNOT VALIDATE!"); + } + + if (status->tracks_size() == 0) { + throw std::invalid_argument("STATUS IS EMPTY"); + } + + // validate range of fields + validate_protobuf_field_ranges(status); + + int track_num = 0; + for (const auto &track : status->tracks()) { + if (track.selected_bars_size() == 0) { + throw std::invalid_argument("NO SELECTED BARS"); + } + if (track.selected_bars_size() < param->model_dim()) { + throw std::invalid_argument("SELECTED BARS MUST BE ATLEAST MODEL_DIM"); + } + + // if track is conditioning it must be within range + STATUS_TRACK_TYPE tt = infer_track_type(track); + if ((tt == CONDITION) || (tt == INFILL)) { + check_range(track.track_id(), 0, piece->tracks_size(), "track_id"); + } + + // check that if resample all the bars are selected + if (track.autoregressive() == 1) { + for (const auto &bar : track.selected_bars()) { + if (!bar) { + throw std::invalid_argument("WHEN RESAMPLE IS ENABLED ALL BARS MUST BE SELECTED!"); + } + } + } + + // check that if ignore the mode is condition + if ((track.ignore()) && ((tt == INFILL) || (tt == RESAMPLE))) { + throw std::invalid_argument("CANNOT IGNORE TRACK WITH SELECTED BARS."); + } + + track_num++; + } + + // check track lengths and track_ids + std::set track_lengths; + std::set track_ids; + for (const auto &track : status->tracks()) { + track_lengths.insert( track.selected_bars_size() ); + track_ids.insert( track.track_id() ); + } + check_all_same(track_lengths, "sample_bars (length)"); + check_all_different(track_ids, status->tracks_size(), "track_id"); +} + + +void validate_inputs(midi::Piece *piece, midi::Status *status, midi::HyperParam *param) { + validate_piece(piece); + validate_status(status, piece, param); + validate_param(param); +} +} +// END OF NAMESPACE diff --git a/src/inference/random.h b/src/inference/random.h new file mode 100644 index 0000000000000000000000000000000000000000..5417546b8d02d04ec63749c257014fa82c2f570c --- /dev/null +++ b/src/inference/random.h @@ -0,0 +1,67 @@ +#pragma once + +#include + +template +std::vector arange(T start, T stop, T step = 1) { + std::vector values; + for (T value = start; value < stop; value += step) + values.push_back(value); + return values; +} + +template +std::vector arange(T stop) { + return arange(0, stop, 1); +} + +int random_on_range(int n, std::mt19937 *engine) { + std::uniform_int_distribution dist(0,n-1); + return dist(*engine); +} + +double random_on_unit(std::mt19937 *engine) { + std::uniform_real_distribution dist(0.,1.); + return dist(*engine); +} + +double random_on_range(double mi, double ma, std::mt19937 *engine) { + std::uniform_real_distribution dist(mi,ma); + return dist(*engine); +} + +int random_on_range(int mi, int ma, std::mt19937 *engine) { + if (mi > ma) { + std::ostringstream buffer; + buffer << "random_on_range: min=" << mi << " > max=" << ma; + throw std::invalid_argument(buffer.str()); + } + if (mi == ma) { + return mi; + } + std::uniform_int_distribution dist(mi,ma); + return dist(*engine); +} + +template +T random_element(std::vector &items, std::mt19937 *engine) { + int index = random_on_range(items.size(), engine); + return items[index]; +} + +int random_element_int(std::vector items, std::mt19937 *engine) { + int index = random_on_range(items.size(), engine); + return items[index]; +} + +template +std::vector random_subset(std::vector &items, std::mt19937 *engine) { + int n = random_on_range(items.size(), engine) + 1; + std::vector idx = arange((int)items.size()); + std::shuffle(idx.begin(), idx.end(), *engine); + std::vector output; + for (int i=0; i +#include + +namespace sampling { + + // Base class for callbacks + class CallbackBase { + public: + CallbackBase () { } + virtual ~CallbackBase () { } + virtual void on_bar_end () {} + virtual void on_prediction (std::vector &logits, int next_token) {} + virtual void on_start () {} + virtual float update_temperature(float current_temperature) { + return current_temperature; + } + virtual bool is_cancelled() { + return false; + } + }; + + // Class that manages call all callbacks + class CallbackManager { + public: + CallbackManager () {} + ~CallbackManager () {} + void add_callback_ptr(std::shared_ptr x) { + callbacks.push_back(x); + } + void on_bar_end () { + for (auto &x : callbacks) { + x->on_bar_end(); + } + } + void on_prediction (std::vector &logits, int next_token) { + for (auto &x : callbacks) { + x->on_prediction(logits, next_token); + } + } + void on_start () { + for (auto &x : callbacks) { + x->on_start(); + } + } + float update_temperature (float current_temperature) { + for (auto &x : callbacks) { + float value = x->update_temperature(current_temperature); + if (value > current_temperature) { + return value; + } + } + return current_temperature; + } + bool is_cancelled() { + for (auto &x : callbacks) { + if (x->is_cancelled()) { + return true; + } + } + return false; + } + std::vector> callbacks; + }; + + + // Callback examples + class TemperatureIncreaseCallback : public CallbackBase { + public: + TemperatureIncreaseCallback (float _increase, float _current_temperature) { + increase = _increase; + current_temperature = _current_temperature; + } + float update_temperature(float temp) { + current_temperature = temp + increase; + std::cout << "CURRENT TEMPERATURE : " << current_temperature << std::endl; + return current_temperature; + } + float increase; + float current_temperature; + }; + + + class LogLikelihoodCallback : public CallbackBase { + public: + LogLikelihoodCallback () { + loglik = 0; + sequence_length = 0; + } + void on_prediction(std::vector &logits, int next_token) { + loglik += logits[next_token]; + sequence_length++; + } + void on_start() { + loglik = 0; + sequence_length = 0; + } + double loglik; + int sequence_length; + }; + + class RecordTokenSequenceCallback : public CallbackBase { + public: + RecordTokenSequenceCallback () {} + void on_start() { + tokens.clear(); + } + void on_prediction(std::vector &logits, int next_token) { + tokens.push_back(next_token); + } + std::vector tokens; + }; + + class CancelCallback : public CallbackBase { + public: + CancelCallback () { + cancel = false; + } + void set_cancel(bool cancel_value) { + cancel = cancel_value; + } + bool is_cancelled() { + return cancel; + } + bool cancel; + }; + +} \ No newline at end of file diff --git a/src/inference/sampling/control.h b/src/inference/sampling/control.h new file mode 100644 index 0000000000000000000000000000000000000000..31f4ce59b4c5d3b91f37a19cf0e534cdf9d9cece --- /dev/null +++ b/src/inference/sampling/control.h @@ -0,0 +1,788 @@ +// this code tracks progress + +// constrain bars per track +// constrain timesteps per bar +// constrain max polyphony +// constrain offsets to notes that have been onset + +#pragma once + +#include +#include +#include +#include + +#include "../../common/encoder/representation.h" +#include "../../common/encoder/encoder_base.h" +#include "../enum/encoder_types.h" +#include "../../common/encoder/attribute_control.h" +#include "../../common/data_structures/verbosity.h" +#include "graph.h" + +namespace sampling { + +using TOKEN_EDGE = std::pair; +using CKPT_MAP_TYPE = std::map,std::tuple>; + +class CONDITIONAL_REP_GRAPH { +public: + + virtual ~CONDITIONAL_REP_GRAPH() {} + + virtual bool is_active(midi::TRACK_TYPE track_type) { + throw std::runtime_error("CONDITIONAL_REP_GRAPH::is_active() : NOT IMPLEMENTED"); + } + + template + void show(const std::vector &x, const std::string &s) { + std::ostringstream buffer; + buffer << s << " :: ["; + for (const auto &e : x) { + buffer << util_protobuf::enum_to_string(e) << ", "; + } + buffer << "]"; + data_structures::LOGGER(buffer.str()); + } + + midi::TOKEN_TYPE possibly_skip(midi::TRACK_TYPE track_type, int last_token, const std::unique_ptr &rg, std::shared_ptr &rep, std::vector & mask) { + + auto target_node = std::make_tuple(midi::TOKEN_NONE, 0); + + if (is_active(track_type)) { + auto inferred_node = rg->graph.infer_node(last_token, rg->enc); + auto next_nodes = graph->graph.get_next_nodes(inferred_node); + auto next_global_nodes = rg->graph.get_next_nodes(inferred_node); + + if ((next_nodes.size() == 1)) { + int loop_count = 0; + while ((next_global_nodes.size() == 1) && (next_nodes[0] != next_global_nodes[0])) { + target_node = next_global_nodes[0]; + next_global_nodes = rg->graph.get_next_nodes(target_node); + loop_count++; + + if (loop_count > 100) { + throw std::runtime_error("CONDITIONAL_REP_GRAPH::possibly_skip() : INFINITE LOOP"); + } + } + + if (std::get<0>(target_node) != midi::TOKEN_NONE) { + auto msg = data_structures::to_str("CONDITIONAL_REP_GRAPH::possibly_skip() : skip ", util_protobuf::enum_to_string(std::get<0>(target_node)), std::get<1>(target_node)); + std::cout << msg << std::endl; + data_structures::LOGGER(msg); + rg->graph.skip(rg->graph.get_previous_nodes(target_node)[0]); + rg->set_mask(rep->encode(std::get<0>(target_node), 0), mask); + + } + + } + } + return std::get<0>(target_node); + } + + std::unique_ptr graph; + +}; + + +class INSTRUMENT_CONDITIONAL_REP_GRAPH : public CONDITIONAL_REP_GRAPH { +public: + + INSTRUMENT_CONDITIONAL_REP_GRAPH(encoder::ENCODER *e, enums::MODEL_TYPE mt) { + graph = std::make_unique(e,mt,encoder::get_drum_exclusive_token_types()); + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "INSTRUMENT_CONDITIONAL_REP_GRAPH" ); + } + + bool is_active(midi::TRACK_TYPE track_type) { + return (data_structures::is_drum_track(track_type) == false); + } + +}; + + +class DRUM_CONDITIONAL_REP_GRAPH : public CONDITIONAL_REP_GRAPH { +public: + + DRUM_CONDITIONAL_REP_GRAPH(encoder::ENCODER *e, enums::MODEL_TYPE mt) { + graph = std::make_unique(e,mt,encoder::get_instrument_exclusive_token_types()); + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "DRUM_CONDITIONAL_REP_GRAPH" ); + } + + bool is_active(midi::TRACK_TYPE track_type) { + return data_structures::is_drum_track(track_type); + } + +}; + + +class SAMPLE_CONTROL { +public: + SAMPLE_CONTROL(midi::Piece *piece, midi::Status *status, midi::HyperParam *param, midi::ModelMetadata *meta) { + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "SAMPLE_CONTROL" ); + + verbose = param->verbose(); + + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, util_protobuf::protobuf_to_string(status)); + + initialize(piece, status, param, meta); + rep = enc->rep; + rg = std::make_unique(enc.get(), model_type); + instrument_rg = std::make_unique(enc.get(), model_type); + drum_rg = std::make_unique(enc.get(), model_type); + + if ((!rg) || (!instrument_rg) || (!drum_rg)) { + std::runtime_error("REP GRAPH CONSTRUCTOR FAILED"); + } + else { + data_structures::LOGGER( "REP GRAPH CONSTRUCTOR SUCCESS" ); + } + + parse_status(status); + initialize_members(); + + } + + ~SAMPLE_CONTROL() {} + + void initialize_members() { + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "initialize_members" ); + barlength = 4 * enc->config->resolution; + timestep = 0; + absolute_timestep = 0; + bar_count = 0; + track_count = 0; + infill_bar_count = 0; + finished = false; + token_position = 0; + num_delta_tokens = 0; + } + + void set_bar_infill_prompt(std::vector> &bars, midi::Piece *p, midi::Status *status, midi::HyperParam *param) { + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "set_bar_infill_prompt" ); + + if (p) { + std::set> barset; + std::copy(bars.begin(), bars.end(), std::inserter(barset, barset.end())); + enc->config->do_multi_fill = true; + enc->config->multi_fill = barset; + + if (param->internal_skip_preprocess()) { + util_protobuf::calculate_note_durations(p); + prompt = enc->encode_wo_preprocess(p); + } + else { + prompt = enc->encode(p); + } + + data_structures::LOGGER( "FULL PROMPT " ); + for (int i=0; i<(int)prompt.size(); i++) { + data_structures::LOGGER( enc->rep->pretty(prompt[i]) ); + } + data_structures::LOGGER( "FULL PROMPT " ); + + int fill_start = enc->rep->encode(midi::TOKEN_FILL_IN_START,0); + for (int index=0; index<(int)prompt.size(); index++) { + if (prompt[index] == fill_start) { + prompt.resize(index+1); + break; + } + } + } + else { + throw std::runtime_error("MUST PROVIDE midi::Piece FOR BAR INFILL MODE"); + } + } + + void set_autoregressive_prompt(std::vector &tracks, midi::Piece *p, midi::Status *status, midi::HyperParam *param) { + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "set_autoregressive_prompt" ); + + enc->config->do_multi_fill = false; + + if (p->tracks_size()) { + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "SET AUTOREGRESSIVE PROMPT" ); + prompt = enc->encode(p); + } + else { + prompt.push_back( enc->rep->encode(midi::TOKEN_PIECE_START,0) ); + } + } + + void initialize(midi::Piece *piece, midi::Status *status, midi::HyperParam *param, midi::ModelMetadata *meta) { + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "initialize" ); + + util_protobuf::UpdateHasNotes(piece); + + std::vector tracks; + std::vector> bars; + int num_cond_tracks = 0; + int num_resample_tracks = 0; + int num_infill_tracks = 0; + std::vector track_types; + std::vector order; + std::vector cond_tracks; + + int track_num = 0; + for (const auto &track : status->tracks()) { + util_protobuf::STATUS_TRACK_TYPE tt = util_protobuf::infer_track_type(track); + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, data_structures::to_str("STATUS TRACK TYPE FOR ",track.track_id()," : ", tt)); + switch( tt ) { + case util_protobuf::CONDITION: + order.push_back( num_cond_tracks ); + cond_tracks.push_back( track.track_id() ); + num_cond_tracks++; + break; + case util_protobuf::RESAMPLE: + order.push_back( num_resample_tracks ); + tracks.push_back( track ); + num_resample_tracks++; + break; + case util_protobuf::INFILL : + num_infill_tracks++; + break; + } + track_types.push_back( tt ); + int bar_num = 0; + for (const auto &selected : track.selected_bars()) { + if (selected) { + bars.push_back( std::make_pair(track_num, bar_num) ); + } + bar_num++; + } + track_num++; + } + + // provide overview of tracks for sampling + int verbose_track_num = 0; + for (const auto &track_type : track_types) { + data_structures::LOGGER(data_structures::to_str("TRACK ", verbose_track_num, " -> ", track_type)); + verbose_track_num++; + } + + // select the correct model + int nb = status->tracks(0).selected_bars_size(); + + enc = enums::getEncoderFromString(meta->encoder()); + + if (num_infill_tracks > 0) { + data_structures::LOGGER( "INFILL" ); + model_type = enums::BAR_INFILL_MODEL; + + // remove excess bars if any + util_protobuf::prune_tracks( + piece, arange(0,piece->tracks_size(),1), arange(0,nb,1)); + + // here track ordering are preserved + inverse_order = arange(piece->tracks_size()); + data_structures::LOGGER(data_structures::to_str("GENERATING ", bars.size(), " BARS")); + set_bar_infill_prompt(bars, piece, status, param); + + } + else { + data_structures::LOGGER( "TRACK" ); + model_type = enums::TRACK_MODEL; + + data_structures::LOGGER(data_structures::to_str("GENERATING ", num_resample_tracks, " TRACKS")); + + // fix the order + // order is the output position for each track + for (track_num=0; track_numtracks_size(); track_num++) { + if (track_types[track_num] == util_protobuf::RESAMPLE) { + order[track_num] = order[track_num] + num_cond_tracks; + } + } + inverse_order.resize(order.size()); + for (int i=0; i<(int)order.size(); i++) { + inverse_order[order[i]] = i; + } + + + // prune unneeded tracks + util_protobuf::prune_tracks(piece, cond_tracks, arange(0,nb,1)); + + data_structures::LOGGER( "AFTER PRUNE TRACKS ...." ); + util_protobuf::print_piece_summary(piece); + data_structures::LOGGER( "============================" ); + + set_autoregressive_prompt(tracks, piece, status, param); + + } + } + + void finalize(midi::Piece *piece) { + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "finalize" ); + if (model_type == enums::TRACK_MODEL) { + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "Reordering tracks" ); + util_protobuf::reorder_tracks(piece, inverse_order); + } + } + + void parse_status(midi::Status *status) { + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "parse_status" ); + + // for bar-infilling we have to determine one thing + // 1) the number of bars to be infilled + int tnum = 0; + num_infill_bars = 0; + for (const auto &track : status->tracks()) { + int bnum = 0; + for (const auto &bar : track.selected_bars()) { + // keep track of selected bars + selected_bars.push_back( std::pair(tnum,bnum) ); + num_infill_bars += (int)bar; + bnum++; + } + tnum++; + } + + // for track generation we have to determine two things + // 1) the number of bars per track + // 2) the token restrictions for attribute control + // for this we can use a static mask for each track + // as there should be no overlap + num_bars = status->tracks(0).selected_bars_size(); + num_tracks = status->tracks_size(); + + for (int i=0; itracks_size(); i++) { + midi::StatusTrack track = status->tracks(inverse_order[i]); + std::vector mask(rep->max_token(),0); + + // add polyphony hard limit + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, data_structures::to_str("TRACK: ", i, " - POLYPHONY HARD LIMIT: ", track.polyphony_hard_limit())); + polyphony_hard_limits.push_back( track.polyphony_hard_limit() ); + + // add per-track temperature + track_temperatures.push_back( track.temperature() ); + + // num bars + rep->set_mask(midi::TOKEN_NUM_BARS, {track.selected_bars_size()}, mask, 1); + + // track type + int tt = track.track_type(); + if (tt == midi::STANDARD_BOTH) { + rep->set_mask(midi::TOKEN_TRACK, {-1}, mask, 1); + } + else { + rep->set_mask(midi::TOKEN_TRACK, {tt}, mask, 1); + } + + // instrument + std::vector insts = enums::GM_MOD[track.instrument()]; + if ((enc->rep->has_pretrain_instrument_mapping()) && (tt == midi::STANDARD_TRACK)) { + // on drum tracks we don't map instruments + for (int i=0; i<(int)insts.size(); i++) { + auto it = enums::PRETRAIN_GROUPING_V2.find(insts[i]); + if (it != enums::PRETRAIN_GROUPING_V2.end()) { + insts[i] = it->second; + } + else { + throw std::runtime_error("CAN NOT FIND INSTRUMENT IN PRETRAIN GROUPING"); + } + } + } + rep->set_mask(midi::TOKEN_INSTRUMENT, insts, mask, 1); + + // density level + rep->set_mask(midi::TOKEN_DENSITY_LEVEL, {track.density()-1}, mask, 1); + + // min-max polyphony + rep->set_mask(midi::TOKEN_MIN_POLYPHONY, {track.min_polyphony_q()-1}, mask, 1); + rep->set_mask(midi::TOKEN_MAX_POLYPHONY, {track.max_polyphony_q()-1}, mask, 1); + + // min-max duration + rep->set_mask( + midi::TOKEN_MIN_NOTE_DURATION, {track.min_note_duration_q()-1}, mask, 1); + rep->set_mask( + midi::TOKEN_MAX_NOTE_DURATION, {track.max_note_duration_q()-1}, mask, 1); + + set_track_masks(rep, mask, &track); + + std::set fixed = { + midi::TOKEN_NUM_BARS, + midi::TOKEN_TRACK, + midi::TOKEN_GENRE, + midi::TOKEN_INSTRUMENT, + midi::TOKEN_TIME_SIGNATURE, + + midi::TOKEN_DENSITY_LEVEL, + midi::TOKEN_MIN_POLYPHONY, + midi::TOKEN_MAX_POLYPHONY, + midi::TOKEN_MIN_NOTE_DURATION, + midi::TOKEN_MAX_NOTE_DURATION, + + }; + + auto track_control_tokens = encoder::get_track_attribute_control_graph(); + for (int i=0; i<(int)track_control_tokens.size()-1; i++) { + if (fixed.find(track_control_tokens[i]) == fixed.end()) { + fixed.insert(track_control_tokens[i]); + } + } + + if (verbose) { + // show the attribute mask + data_structures::LOGGER( "=======================" ); + data_structures::LOGGER( "ATTRIBUTE MASK : " ); + for (int i=0; i<(int)mask.size(); i++) { + if (mask[i]) { + data_structures::LOGGER( rep->pretty(i) ); + } + } + data_structures::LOGGER( "=======================" ); + } + + for (const auto &kv : rep->token_domains) { + if (fixed.find(kv.first) == fixed.end()) { + rep->set_mask(kv.first, {-1}, mask, 1); + } + } + + attribute_masks.push_back( mask ); + + // changing to use status bar instead + std::vector> bar_masks; + for (int bn=0; bn bar_mask(mask); + midi::StatusBar bar = track.bars(bn); + if (rep->has_token_type(midi::TOKEN_TIME_SIGNATURE)) { + int tstoken = rep->encode(midi::TOKEN_TIME_SIGNATURE, std::make_tuple(bar.ts_numerator(), bar.ts_denominator())); + bar_mask[tstoken] = 1; // only allow time signature + } + set_bar_masks(rep, bar_mask, &bar); + bar_masks.push_back( bar_mask ); + } + attribute_bar_masks.push_back( bar_masks ); + + } + } + + void update(int token) { + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "controlhSAMPLECONTROL update" ); + midi::TOKEN_TYPE tt = rep->get_token_type(token); + switch (tt) { + case midi::TOKEN_TRACK: { + bar_count = 0; + absolute_timestep = 0; + bar_start_timestep = 0; + onsets.clear(); + note_expiry.clear(); + current_track_type = static_cast(rep->decode(token)); + break; + } + case midi::TOKEN_TRACK_END: { + track_count += 1; + break; + } + case midi::TOKEN_BAR: { + timestep = 0; + barlength = 4 * enc->config->resolution; + absolute_timestep = bar_start_timestep; + break; + } + case midi::TOKEN_BAR_END: { + bar_count += 1; + bar_start_timestep += barlength; + break; + } + case midi::TOKEN_FILL_IN_START: { + // clear onsets and read in the token sequence + // we backfill events in between FILL_IN_PLACEHOLDERs + // so that we don't skip context + int fillp_token = enc->rep->encode(midi::TOKEN_FILL_IN_PLACEHOLDER,0); + auto it = history.begin(); + auto prev = it; + for (int i=0; i<=infill_bar_count; i++) { + prev = it; + it = find(next(it), history.end(), fillp_token); + } + for (auto i=next(prev); i!=it; i++) { + if (verbose) { + data_structures::LOGGER(data_structures::to_str("BACKFILLING :: ", enc->rep->pretty(*i))); + } + update(*i); + } + break; + } + case midi::TOKEN_FILL_IN_END: { + infill_bar_count += 1; + break; + } + case midi::TOKEN_TIME_SIGNATURE: { + std::tuple ts = rep->decode_timesig(token); + double ts_ratio = ((double)std::get<0>(ts) / std::get<1>(ts)); + barlength = ts_ratio * 4 * enc->config->resolution; + break; + } + case midi::TOKEN_TIME_ABSOLUTE_POS: { + int t = rep->decode(token); + timestep = t; + absolute_timestep = bar_start_timestep + t; + break; + } + case midi::TOKEN_DELTA: { + break; + } + case midi::TOKEN_NOTE_ONSET: { + int pitch = rep->decode(token); + onsets.insert( pitch ); + + if (data_structures::is_drum_track(current_track_type)) { + // artificially add the note duration of 1 + last_token = token; + update(rep->encode(midi::TOKEN_NOTE_DURATION,0)); + } + + break; + } + case midi::TOKEN_NOTE_DURATION: { + int dur = rep->decode(token) + 1; + int pitch = rep->decode(last_token); + note_expiry[dur + absolute_timestep].push_back(pitch); + break; + } + default: + break; + } + + // remove notes that have "expired" + std::vector to_remove; + for (const auto &kv : note_expiry) { + if (kv.first <= absolute_timestep) { + for (const auto &pitch : kv.second) { + onsets.erase( pitch ); + } + to_remove.push_back( kv.first ); + } + } + // remove these lists from note expiry + for (const auto &t : to_remove) { + note_expiry.erase( t ); + } + + last_token = token; + + if (verbose) { + data_structures::LOGGER(data_structures::to_str("ONSETS : ", onsets.size())); + } + + + } + + void set_mask(int last_token, std::vector &mask) { + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "controlhSAMPLECONTROL set_mask" ); + + // basic constraints of the representation + midi::TOKEN_TYPE last_tt = rep->get_token_type(last_token); + bool is_drum = data_structures::is_drum_track(current_track_type); + + // automatically handle skipping tokens when necessary for drum or instrument tracks + midi::TOKEN_TYPE inst_skip = instrument_rg->possibly_skip(current_track_type, last_token, rg, rep, mask); + midi::TOKEN_TYPE drum_skip = drum_rg->possibly_skip(current_track_type, last_token, rg, rep, mask); + + if ((is_drum) && (last_tt == midi::TOKEN_NOTE_ONSET)) { + // fast forward past NOTE_DURATION token + rg->skip(midi::TOKEN_NOTE_ONSET); + rg->set_mask(rep->encode(midi::TOKEN_NOTE_DURATION,0), mask); + } + else if ((inst_skip == midi::TOKEN_NONE) && (drum_skip == midi::TOKEN_NONE)) { + rg->set_mask(last_token, mask); + } + + // can't have onset for note that is already sounding + for (const auto &pitch : onsets) { + mask[rep->encode(midi::TOKEN_NOTE_ONSET,pitch)] = 0; + } + + // can't have note onsets when timestep == barlength + // you can only have note offsets + if (timestep == barlength) { + if (verbose) { + data_structures::LOGGER( "HIT TIME LIMIT >>>>>>>>>>>>>>>>>>>> " ); + } + rep->set_mask(midi::TOKEN_NOTE_ONSET, {-1}, mask, 0); + rep->set_mask(midi::TOKEN_VELOCITY_LEVEL, {-1}, mask, 0); + } + + // determine what the hard limit is + // the hard limit may be smaller when a limited domain + int hard_limit = 0; + if (model_type == enums::TRACK_MODEL) { + int index = std::min(track_count, num_tracks-1); + hard_limit = polyphony_hard_limits[index]; + } + else { + int index = std::min(infill_bar_count, num_infill_bars-1); + hard_limit = polyphony_hard_limits[std::get<0>(selected_bars[index])]; + } + + // can't have more than n simultaneous notes + if ((int)onsets.size() >= hard_limit) { + data_structures::LOGGER(data_structures::to_str("HIT HARD LIMIT ( ",(int)onsets.size()," >= ", hard_limit, " ) >>>>>>>>>>>>>>>>>>>> ")); + rep->set_mask(midi::TOKEN_NOTE_ONSET, {-1}, mask, 0); + rep->set_mask(midi::TOKEN_VELOCITY_LEVEL, {-1}, mask, 0); + // will be ignored if token doesn't exist + } + + //Check if microtiming is allowed + int delta_domain_limit = rep->get_domain_size(midi::TOKEN_DELTA); + + if (!enc->config->use_microtiming) { + for (int td=0; td ","MASKING DELTA :: ", td)); + mask[rep->encode(midi::TOKEN_DELTA,td)] = 0; + } + } else { + if (last_tt == midi::TOKEN_DELTA) { + num_delta_tokens += 1; + mask[rep->encode(midi::TOKEN_DELTA_DIRECTION,0)] = 0; + } else { + num_delta_tokens = 0; + } + //Check if max number microtiming tokens achieved + if (num_delta_tokens > 0) { + for (int td=0; td ","MASKING DELTA :: ", td)); + mask[rep->encode(midi::TOKEN_DELTA,td)] = 0; + } + } + + //Keep delta within the bar when forward + if (delta_domain_limit) { + int max_step = enc->config->step_to_delta(barlength - timestep, enc->config->resolution); + if ((last_tt == midi::TOKEN_DELTA) && (rep->decode(last_token) == 0)){ + max_step = enc->config->step_to_delta(timestep, enc->config->resolution); + } + int max_td = std::max(std::min(max_step, delta_domain_limit), 0); + for (int td=max_td; td ","MASKING DELTA :: ", td)); + mask[rep->encode(midi::TOKEN_DELTA,td)] = 0; + } + } + + //Forward delta only at start of bar + if (timestep == 0) { + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, data_structures::to_str("AT START -> ","MASKING DELTA DIRECTION")); + mask[rep->encode(midi::TOKEN_DELTA_DIRECTION,0)] = 0; + } + + //Backward delta only at end of bar + if (timestep == barlength) { + for (int td=1; td ","MASKING DELTA :: ", 0)); + mask[rep->encode(midi::TOKEN_DELTA,td)] = 0; + } + } + } + + // also restrict this for absolute time + // but this should be limited based on time_signature / barlength + int domain_limit = rep->get_domain_size(midi::TOKEN_TIME_ABSOLUTE_POS); + if (domain_limit) { + for (int td=0; td<=timestep; td++) { + mask[rep->encode(midi::TOKEN_TIME_ABSOLUTE_POS,td)] = 0; + } + for (int td=barlength+1; tdencode(midi::TOKEN_TIME_ABSOLUTE_POS,td)] = 0; + } + } + + if (model_type == enums::TRACK_MODEL) { + // limit number of bars + if (bar_count != num_bars) { + rep->set_mask(midi::TOKEN_TRACK_END, {-1}, mask, 0); + } + else { + rep->set_mask(midi::TOKEN_BAR, {-1}, mask, 0); + } + // limit the track count + if (track_count >= num_tracks) { + std::fill(mask.begin(), mask.end(), 0); + finished = true; + } + // only add attribute mask if not finished + // otherwise it will crash with track_count out of range + if (!finished) { + for (int i=0; i<(int)mask.size(); i++) { + int num_bars = attribute_bar_masks[track_count].size(); + int safe_bar_index = std::min(bar_count, num_bars - 1); + mask[i] *= attribute_bar_masks[track_count][safe_bar_index][i]; + } + } + } + else if ((model_type == enums::BAR_INFILL_MODEL)) { + // limit the bar infill count + if (infill_bar_count >= num_infill_bars) { + std::fill(mask.begin(), mask.end(), 0); + finished = true; + } + } + + // if mask is all zeros we have a problem as the model has + // no 'valid' path forward + if ((std::find(mask.begin(), mask.end(), 1) == mask.end()) && (!finished)) { + throw std::runtime_error("FATAL ERROR : EVERY TOKEN IS MASKED"); + } + + } + + std::vector get_mask(std::vector &tokens) { + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "get_mask" ); + std::vector mask(enc->rep->max_token(), 0); + for (int t=token_position; t<(int)tokens.size(); t++) { + if (verbose) { + data_structures::LOGGER(data_structures::to_str("UPDATING [", token_position, "] :: ", enc->rep->pretty(tokens[t]))); + } + update( tokens[t] ); + history.push_back( tokens[t] ); + token_position++; + } + + set_mask(tokens.back(), mask); + return mask; + } + + // map from pitch to when it expires + std::map> note_expiry; // time -> list of pitches + std::set onsets; + int last_token; + int num_delta_tokens; + midi::TRACK_TYPE current_track_type; + + int barlength; + int timestep; + int absolute_timestep; + int bar_start_timestep; + + int bar_count; + int track_count; + int infill_bar_count; + + int num_bars; + int num_tracks; + int num_infill_bars; + std::vector> attribute_masks; + std::vector>> attribute_bar_masks; + + std::vector prompt; + std::vector inverse_order; + std::string ckpt_path; + + int token_position; + bool finished; + enums::MODEL_TYPE model_type; + std::vector history; + + bool verbose; + int polyphony_hard_limit; + std::vector polyphony_hard_limits; + std::vector track_temperatures; + std::vector> selected_bars; + + std::unique_ptr enc; + std::shared_ptr rep; + std::unique_ptr rg; + std::unique_ptr instrument_rg; + std::unique_ptr drum_rg; + +}; + + +} \ No newline at end of file diff --git a/src/inference/sampling/graph.h b/src/inference/sampling/graph.h new file mode 100644 index 0000000000000000000000000000000000000000..a7880f167bcb12c7a25a46db7462d13663d21d7a --- /dev/null +++ b/src/inference/sampling/graph.h @@ -0,0 +1,381 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace sampling { + +// define printing methods +std::ostream& operator<<(std::ostream& os, const std::tuple &obj) { + return os << "(" << util_protobuf::enum_to_string(std::get<0>(obj)) << "," << std::get<1>(obj) << ")"; +} + +std::string toString(const std::tuple &obj) { + return std::string("(") + util_protobuf::enum_to_string(std::get<0>(obj)) + "," + std::to_string(std::get<1>(obj)) + ")"; +} + +template +class DIGRAPH_NODE { +public: + DIGRAPH_NODE(const T &x) { + node_id = x; + } + std::set edges; + std::set in_edges; + T node_id; +}; + +template +class DIGRAPH { +public: + DIGRAPH() { + traversal_started = false; + } + DIGRAPH(const std::vector> &paths) { + traversal_started = false; + build_from_paths(paths); + } + void remove_edges_to_node(const T &v) { + data_structures::LOGGER(data_structures::to_str("REMOVING EDGES TO NODE ", toString(v))); + for (auto kv : nodes) { + nodes.find(kv.first)->second.edges.erase(v); + nodes.find(kv.first)->second.in_edges.erase(v); + } + } + void remove_node(const T &v) { + data_structures::LOGGER(data_structures::to_str("REMOVING NODE ", toString(v))); + auto node = nodes.find(v); + if (node != nodes.end()) { + std::set out_edges = node->second.edges; + for (const auto &pre : node->second.in_edges) { + for (const auto &e : out_edges) { + nodes.find(pre)->second.edges.insert( e ); + } + nodes.find(pre)->second.edges.erase( v ); + } + std::set in_edges = node->second.in_edges; + for (const auto &post : node->second.edges) { + for (const auto &e : in_edges) { + nodes.find(post)->second.in_edges.insert( e ); + } + nodes.find(post)->second.in_edges.erase( v ); + } + remove_edges_to_node(v); + nodes.erase(v); + } + } + void remove_nodes(const std::vector &vs) { + for (const auto &v : vs) { + remove_node(v); + } + } + void remove_nodes_wo_connecting(const std::vector &vs) { + for (const auto &v : vs) { + remove_edges_to_node(v); + nodes.erase(v); + } + } + void add_node(const T &v) { + if (nodes.find(v) == nodes.end()) { + data_structures::LOGGER(data_structures::to_str("ADDING NODE ", toString(v))); + nodes.insert( std::make_pair(v,DIGRAPH_NODE(v)) ); + } + } + void add_edge(const T &u, const T &v) { + data_structures::LOGGER(data_structures::to_str("ADDING EDGE ", toString(u), " -> ", toString(v))); + add_node(u); + add_node(v); + nodes.find(u)->second.edges.insert(v); + nodes.find(v)->second.in_edges.insert(u); + } + void add_path(const std::vector &path) { + for (int i=0; i<(int)path.size()-1; i++) { + add_edge(path[i], path[i+1]); + } + } + void build_from_paths(const std::vector> &paths) { + for (const auto &path : paths) { + add_path(path); + } + } + bool check_path(const T &u, const T &v, int depth, int max_depth) { + if (depth < max_depth) { + auto choices = get_next_nodes(u); + for (const auto &e : choices) { + if (e == v) { + return true; + } + } + for (const auto &e : choices) { + if (check_path(e, v, depth + 1, max_depth)) { + return true; + } + } + } + return false; + } + std::vector get_previous_nodes(const T &v) { + auto it = nodes.find(v); + if (it == nodes.end()) { + std::ostringstream buffer; + buffer << "ERROR : INVALID NODE IN DIGRAPH (" << v << ")"; + throw std::runtime_error(buffer.str()); + } + std::vector previous_tokens; + for (const auto &e : it->second.in_edges) { + previous_tokens.push_back(e); + } + return previous_tokens; + } + std::vector get_next_nodes(const T &v) { + auto it = nodes.find(v); + if (it == nodes.end()) { + std::ostringstream buffer; + buffer << "ERROR [get_next_nodes()] : INVALID NODE IN DIGRAPH (" << v << ")"; + throw std::runtime_error(buffer.str()); + } + std::vector next_tokens; + for (const auto &e : it->second.edges) { + next_tokens.push_back(e); + } + return next_tokens; + } + T infer_node(const int &last_token, encoder::ENCODER *enc) { + data_structures::LOGGER(data_structures::to_str("INFERRING NODE FROM TOKEN ", (enc->rep->pretty(last_token)))); + data_structures::LOGGER(data_structures::to_str("CURRENT NODE ", toString(current_node))); + std::vector next_nodes = get_next_nodes(current_node); + midi::TOKEN_TYPE tt = enc->rep->get_token_type(last_token); + if (tt == midi::TOKEN_PIECE_START) { + return std::make_tuple(midi::TOKEN_PIECE_START,0); // special case at the start of the token sequence + } + if (next_nodes.size() == 0) { + std::ostringstream buffer; + buffer << "ERROR : NO NEXT TOKENS IN DIGRAPH (" << current_node << ")"; + throw std::runtime_error(buffer.str()); + } + if (next_nodes.size() == 1) { + return next_nodes[0]; + } + for (const auto &e : next_nodes) { + if (std::get<0>(e) == tt) { + return e; + } + } + throw std::runtime_error("ERROR : CANNOT INFER NODE"); + } + + /// some code to visualize a graph in graphviz format + void print_graphviz() { + std::cout << "digraph G {" << std::endl; + for (const auto &kv : nodes) { + for (const auto &e : kv.second.edges) { + std::cout << kv.first << " -> " << e << std::endl; + } + } + std::cout << "}" << std::endl; + } + + // handle graph traversal + void traverse(const T &node) { + if (traversal_started) { + if (!check_path(current_node, node, 0, 1)) { + std::ostringstream buffer; + buffer << "ERROR : INVALID PATH IN DIGRAPH (" << current_node << " --> " << node << ")"; + throw std::runtime_error(buffer.str()); + } + } + current_node = node; + traversal_started = true; + } + void skip(const T &node) { + if (!traversal_started) { + throw std::runtime_error("ERROR : CANNOT SKIP BEFORE TRAVERSAL STARTED"); + } + if (!check_path(current_node, node, 0, 20)) { + std::ostringstream buffer; + buffer << "ERROR : INVALID PATH IN DIGRAPH (" << current_node << " --> ... -->" << node << ")"; + throw std::runtime_error(buffer.str()); + } + current_node = node; + } + + std::map> nodes; + T current_node; + bool traversal_started; +}; + +// how to handle graph traversal + +using NODE_TYPE = std::tuple; + +std::vector> DEF_GRAPH = { + {midi::TOKEN_PIECE_START, midi::TOKEN_NUM_BARS, midi::TOKEN_TRACK}, + {midi::TOKEN_TIME_SIGNATURE, midi::TOKEN_TIME_ABSOLUTE_POS}, + {midi::TOKEN_TIME_SIGNATURE, midi::TOKEN_VELOCITY_LEVEL}, + {midi::TOKEN_TIME_SIGNATURE, midi::TOKEN_FILL_IN_PLACEHOLDER}, + {midi::TOKEN_FILL_IN_PLACEHOLDER, midi::TOKEN_BAR_END}, + {midi::TOKEN_VELOCITY_LEVEL, midi::TOKEN_NOTE_ONSET}, + {midi::TOKEN_VELOCITY_LEVEL, midi::TOKEN_DELTA}, + {midi::TOKEN_DELTA_DIRECTION, midi::TOKEN_DELTA}, + {midi::TOKEN_DELTA, midi::TOKEN_DELTA}, + {midi::TOKEN_DELTA, midi::TOKEN_DELTA_DIRECTION}, + {midi::TOKEN_DELTA, midi::TOKEN_NOTE_ONSET}, + {midi::TOKEN_DELTA, midi::TOKEN_FILL_IN_END}, + {midi::TOKEN_NOTE_ONSET, midi::TOKEN_NOTE_DURATION}, + {midi::TOKEN_NOTE_DURATION, midi::TOKEN_TIME_ABSOLUTE_POS}, + {midi::TOKEN_NOTE_DURATION, midi::TOKEN_NOTE_ONSET}, + {midi::TOKEN_NOTE_DURATION, midi::TOKEN_VELOCITY_LEVEL}, + {midi::TOKEN_NOTE_DURATION, midi::TOKEN_BAR_END}, + {midi::TOKEN_NOTE_DURATION, midi::TOKEN_FILL_IN_END}, + {midi::TOKEN_TIME_ABSOLUTE_POS, midi::TOKEN_NOTE_ONSET}, + {midi::TOKEN_TIME_ABSOLUTE_POS, midi::TOKEN_VELOCITY_LEVEL}, + {midi::TOKEN_TIME_ABSOLUTE_POS, midi::TOKEN_BAR_END}, + {midi::TOKEN_TIME_ABSOLUTE_POS, midi::TOKEN_FILL_IN_END}, + {midi::TOKEN_TIME_ABSOLUTE_POS, midi::TOKEN_DELTA}, + {midi::TOKEN_TIME_ABSOLUTE_POS, midi::TOKEN_DELTA_DIRECTION}, + {midi::TOKEN_NOTE_DURATION, midi::TOKEN_DELTA}, + {midi::TOKEN_NOTE_DURATION, midi::TOKEN_DELTA_DIRECTION}, + {midi::TOKEN_DELTA_DIRECTION, midi::TOKEN_DELTA}, + {midi::TOKEN_BAR_END, midi::TOKEN_BAR}, + {midi::TOKEN_BAR_END, midi::TOKEN_TRACK_END}, + {midi::TOKEN_TRACK_END, midi::TOKEN_TRACK}, + {midi::TOKEN_TRACK_END, midi::TOKEN_FILL_IN_START}, + {midi::TOKEN_FILL_IN_START, midi::TOKEN_TIME_ABSOLUTE_POS}, + {midi::TOKEN_FILL_IN_START, midi::TOKEN_VELOCITY_LEVEL}, + {midi::TOKEN_FILL_IN_END, midi::TOKEN_FILL_IN_START}, +}; + +template +std::vector> convert(const std::vector &xs) { + std::vector> ys; + for (const auto &x : xs) { + ys.push_back(x); + } + return ys; +} + +template +std::vector> convert(std::vector &xs, std::tuple defaults) { + std::vector> ys; + for (auto x : xs) { + ys.push_back(std::tuple_cat(std::make_tuple(x), defaults)); + } + return ys; +} + +template +std::vector convert(std::vector> &xs) { + std::vector ys; + for (auto x : xs) { + ys.push_back(std::get(x)); + } + return ys; +} + +class REP_GRAPH { +public: + REP_GRAPH(encoder::ENCODER *e, enums::MODEL_TYPE mt) { + enc = e; + initialize(mt, {}); + } + REP_GRAPH(encoder::ENCODER *e, enums::MODEL_TYPE mt, std::vector> tokens_to_remove) { + enc = e; + initialize(mt, tokens_to_remove); + } + void initialize(enums::MODEL_TYPE mt, std::vector> tokens_to_remove) { + for (auto &path : DEF_GRAPH) { + graph.add_path(convert(path, std::tuple(0))); + } + graph.add_path(encoder::get_bar_attribute_control_graph_v2()); + graph.add_path(encoder::get_track_attribute_control_graph_v2()); + graph.add_path(encoder::get_track_pre_instrument_attribute_control_graph_v2()); + std::vector to_remove; + auto just_tokens = convert<0,midi::TOKEN_TYPE>(tokens_to_remove); + for (const auto &kv : graph.nodes) { + if (!enc->rep->has_token_type(std::get<0>(kv.first))) { + to_remove.push_back(kv.first); + } + if (std::find(just_tokens.begin(), just_tokens.end(), std::get<0>(kv.first)) != just_tokens.end()) { + to_remove.push_back(kv.first); + } + } + graph.remove_nodes(to_remove); + + if (mt == enums::TRACK_MODEL) { + initialize_autoregressive(); + } + else { + initialize_bar_infilling(); + } + } + void initialize_bar_infilling() { + std::set to_keep = { + midi::TOKEN_VELOCITY_LEVEL, + midi::TOKEN_NOTE_ONSET, + midi::TOKEN_NOTE_DURATION, + midi::TOKEN_DELTA, + midi::TOKEN_DELTA_DIRECTION, + midi::TOKEN_TIME_ABSOLUTE_POS, + midi::TOKEN_FILL_IN_START, + midi::TOKEN_FILL_IN_END, + }; + std::vector to_remove; + for (const auto &kv : graph.nodes) { + if (to_keep.find(std::get<0>(kv.first)) == to_keep.end()) { + to_remove.push_back(kv.first); + } + } + graph.remove_nodes_wo_connecting(to_remove); + graph.current_node = std::make_tuple(midi::TOKEN_FILL_IN_END,0); // necessary for graph traversal + } + + void initialize_autoregressive() { + std::vector to_remove = { + midi::TOKEN_FILL_IN_PLACEHOLDER, + midi::TOKEN_FILL_IN_START, + midi::TOKEN_FILL_IN_END + }; + graph.remove_nodes_wo_connecting(convert(to_remove, std::tuple(0))); + } + void set_mask(int last_token, std::vector &mask) { + auto tt = graph.infer_node(last_token, enc); + data_structures::LOGGER(data_structures::to_str("GRAPH INFERENCE : ", toString(tt))); + graph.traverse(tt); // validate graph traversals + for (const auto &e : graph.get_next_nodes(tt)) { + enc->rep->set_mask(std::get<0>(e), {-1}, mask, 1); + } + } + + // helpers + std::vector get_next_nodes(midi::TOKEN_TYPE ttt) { + NODE_TYPE tt = std::make_tuple(ttt, 0); + std::vector next_nodes = graph.get_next_nodes(tt); + std::vector next_tokens; + for (const auto &e : next_nodes) { + next_tokens.push_back(std::get<0>(e)); + } + return next_tokens; + } + + std::vector get_previous_nodes(midi::TOKEN_TYPE ttt) { + NODE_TYPE tt = std::make_tuple(ttt, 0); + std::vector next_nodes = graph.get_previous_nodes(tt); + std::vector next_tokens; + for (const auto &e : next_nodes) { + next_tokens.push_back(std::get<0>(e)); + } + return next_tokens; + } + + void skip(const midi::TOKEN_TYPE &t) { + graph.skip(std::make_tuple(t, 0)); + } + + encoder::ENCODER *enc; + DIGRAPH graph; +}; + +} \ No newline at end of file diff --git a/src/inference/sampling/multi_step.h b/src/inference/sampling/multi_step.h new file mode 100644 index 0000000000000000000000000000000000000000..755bf06a1fe399943ffa2b34f9b729ac4f0c643b --- /dev/null +++ b/src/inference/sampling/multi_step.h @@ -0,0 +1,536 @@ +#include +#include +#include +#include + +template +using vmatrix = std::vector>; + +template +class cmatrix { +public: + cmatrix() { + N = 0; + M = 0; + m_data = vmatrix(0, std::vector(0, 0)); + } + cmatrix(int n, int m, int value) { + N = n; + M = m; + m_data = vmatrix(n, std::vector(m, value)); + } + cmatrix(const cmatrix &x) { + N = x.N; + M = x.M; + m_data = vmatrix(N, std::vector(M, 0)); + for (int i=0; i & operator=(const cmatrix &x) { + N = x.N; + M = x.M; + m_data = vmatrix(N, std::vector(M, 0)); + for (int i=0; i> &x) { + assert(x.size() > 0); + N = x.size(); + M = x[0].size(); + m_data = vmatrix(N, std::vector(M, 0)); + for (int i=0; i &x) { + return (N == x.N) && (M == x.M); + } + bool is_shape(int n, int m) { + return (N == n) && (M == m); + } + cmatrix transpose() { + cmatrix c(M, N, 0); + for (int i=0; i m_data; +}; + + +template +vmatrix random_boolean_matrix(int n, int m, double p, std::mt19937 *e) { + std::uniform_real_distribution dist(0.0, 1.0); + vmatrix x(n, std::vector(m, 0)); + for (int i=0; i +vmatrix ones(int n, int m) { + return vmatrix(n, std::vector(m, 1)); +} + +template +vmatrix zeros(int n, int m) { + return vmatrix(n, std::vector(m, 0)); +} + +template +vmatrix binary_op(const vmatrix &a, const vmatrix &b, F &&func) { + assert(a.size() == b.size()); + vmatrix c(a.size(), std::vector(a[0].size(), 0)); + for (int i=0; i<(int)a.size(); i++) { + assert(a[i].size() == b[i].size()); + for (int j=0; j<(int)a[i].size(); j++) { + c[i][j] = func(a[i][j], b[i][j]); + } + } + return c; +} + +template +vmatrix operator& (const vmatrix &a, const vmatrix &b) { + return binary_op(a, b, [](T x, T y) { return x & y; }); +} + +template +vmatrix operator| (const vmatrix &a, const vmatrix &b) { + return binary_op(a, b, [](T x, T y) { return x | y; }); +} + + +template +bool all(const vmatrix &x) { + for (const auto &row : x) { + for (const auto &elem : row) { + if (!elem) { + return false; + } + } + } + return true; +} + +template +bool all(const cmatrix &x) { + return all(x.m_data); +} + +template +bool any(const std::vector &x) { + for (const auto &elem : x) { + if (elem) { + return true; + } + } + return false; +} + +template +bool any(const vmatrix &x) { + for (const auto &row : x) { + for (const auto &elem : row) { + if (elem) { + return true; + } + } + } + return false; +} + +template +bool any(const cmatrix &x) { + return any(x.m_data); +} + +template +int sum(const vmatrix &x) { + int total = 0; + for (const auto &row : x) { + for (const auto &elem : row) { + total += (int)elem; + } + } + return total; +} + +int sum(const cmatrix &x) { + return sum(x.m_data); +} + +template +bool equal(const vmatrix &a, const vmatrix &b) { + if(a.size() != b.size()) { + return false; + } + for (int i=0; i<(int)a.size(); i++) { + if(a[i].size() != b[i].size()) { + return false; + } + for (int j=0; j<(int)a[i].size(); j++) { + if (a[i][j] != b[i][j]) { + return false; + } + } + } + return true; +} + +template +bool operator==(const vmatrix &a, const vmatrix &b) { + return equal(a, b); +} + +template +void show(const vmatrix &x) { + for (const auto &row : x) { + for (const auto &elem : row) { + std::cout << elem << " "; + } + std::cout << std::endl; + } + std::cout << std::endl; +} + +template +void show(const cmatrix &x) { + show(x.m_data); +} + +template +void show(const std::vector &x) { + for (const auto &elem : x) { + std::cout << elem << " "; + } + std::cout << std::endl; +} + +template +T clamp(T x, T min, T max) { + return std::max(std::min(x, max), min); +} + +template +cmatrix unary_op(cmatrix a, F &&func) { + cmatrix c(a.N, a.M, 0); + for (int i=0; i +cmatrix binary_op(cmatrix a, cmatrix b, F &&func) { + assert(a.same_shape(b)); + cmatrix c(a.N, a.M, 0); + for (int i=0; i +cmatrix operator~(cmatrix a) { + return unary_op(a, [](T x) { return !x; }); +} + +template +cmatrix operator&(cmatrix a, cmatrix b) { + return binary_op(a, b, [](T x, T y) { return x & y; }); +} + +template +cmatrix operator*(cmatrix a, cmatrix b) { + return binary_op(a, b, [](T x, T y) { return x * y; }); +} + +template +cmatrix operator|(cmatrix a, cmatrix b) { + return binary_op(a, b, [](T x, T y) { return x | y; }); +} + +// operation along axis +template +cmatrix op_axis(cmatrix a, int axis, F &&func) { + assert(axis == 0 || axis == 1); + auto x(a); + if (axis == 0) { + x = x.transpose(); + } + cmatrix y(x.N, x.M, 0); + for (int i=0; i +cmatrix max_along_axis(cmatrix a, int axis) { + return op_axis(a, axis, [](const std::vector& x) { return *std::max_element(x.begin(), x.end()); }); +} + +template +cmatrix getrange(cmatrix a, int is, int ie, int js, int je) { + assert(a.N >= ie && a.M >= je); + cmatrix b(ie-is, je-js, 0); + for (int i=is; i +void setrange(cmatrix &x, cmatrix y, int is, int ie, int js, int je) { + assert(x.N >= ie && x.M >= je); + assert(y.is_shape(ie-is, je-js)); + for (int i=is; i +void setrange(cmatrix &x, T y, int is, int ie, int js, int je) { + for (int i=is; i +cmatrix vector_to_matrix(std::vector x, int M) { + cmatrix y(x.size(), M, 0); + for (int i=0; i<(int)x.size(); i++) { + for (int j=0; j &sstep, vmatrix &ccontext) { + start = sstart; + end = eend; + step = sstep; + context = ccontext; + initialize(); + } + + STEP (const STEP &old) { + start = old.start; + end = old.end; + step = old.step; + context = old.context; + initialize(); + } + + STEP () { + start = 0; + end = 0; + throw std::runtime_error("STEP constructor called with no arguments"); + } + + virtual ~STEP () {} + + void initialize() { + int track_count = 0; + int num_tracks = step.size(); + std::set track_set; + for (int i=0; i(track_set.begin(), track_set.end()); + assert(bars_to_generate.size() > 0); + } + + std::vector get_tracks() const { + return tracks; + } + + std::set> get_bars_to_generate() const{ + return bars_to_generate; + } + + std::vector> get_bar_mapping() const { + return bar_mapping; + } + + int generated_bar_count() const { + return sum(step); + } + + int start; + int end; + vmatrix step; + vmatrix context; + +private: + std::set> bars_to_generate; + std::vector> bar_mapping; + std::vector tracks; +}; + +class HyperParam { +public: + HyperParam () { + _model_dim = 4; + _tracks_per_step = 1; + _bars_per_step = 4; + _shuffle = false; + _percentage = 100; + } + int model_dim() const { + return _model_dim; + } + int tracks_per_step() const { + return _tracks_per_step; + } + int bars_per_step() const { + return _bars_per_step; + } + bool shuffle() const { + return _shuffle; + } + int percentage() const { + return _percentage; + } + int _model_dim; + int _tracks_per_step; + int _bars_per_step; + bool _shuffle; + int _percentage; +}; + +void find_steps_inner(std::vector &steps, cmatrix &selection_matrix, cmatrix &resample_mask, cmatrix &ignore_mask, bool autoregressive, cmatrix &generated, midi::HyperParam *param) { + + int model_dim = param->model_dim(); + int tracks_per_step = clamp(param->tracks_per_step(), 1, selection_matrix.N); + int bars_per_step = clamp(param->bars_per_step(), 1, model_dim); + int current_num_steps = steps.size(); + int num_context = autoregressive ? model_dim - bars_per_step : (model_dim - bars_per_step) / 2; + int nt = selection_matrix.N; + int nb = selection_matrix.M; + + auto sel(selection_matrix); + auto covered = cmatrix(nt,nb,0); + auto tracks_to_consider = arange(0, nt, 1); + + sel = autoregressive ? sel & resample_mask : sel & ~resample_mask; + + std::vector> ijs; + for (int i=0; i<(int)sel.N; i=i+tracks_per_step) { + for (int j=0; j<(int)sel.M; j=j+bars_per_step) { + ijs.push_back( std::make_tuple(i,j) ); + } + } + + for (const auto &ij : ijs) { + int i = std::get<0>(ij); + int j = std::get<1>(ij); + int num_tracks = std::min(tracks_per_step,(int)sel.N-i); + auto kernel = cmatrix(num_tracks,model_dim,0); + auto step = cmatrix(nt,nb,0); + auto context = cmatrix(nt,nb,0); + + int t = 0; + if (autoregressive) { + // for the first step we have no generated material to + // condition on so we use entire model window + // after the first step (j>0) we only generate bars_per_step bars + int right_offset = std::max((j + model_dim) - nb,0); + t = std::min(j, nb - model_dim); + setrange(kernel, true, 0, num_tracks, (j>0)*(num_context+right_offset), model_dim); + } + else { + // we want to have the generated bars at the center + // this is not possible at beginning and end so we adjust for those cases + t = clamp(j - num_context, 0, nb - model_dim); + setrange(kernel, true, 0, num_tracks, j-t, j-t+bars_per_step); + } + + int a = i + num_tracks; + int b = t + model_dim; + setrange(step, getrange(sel, i, a, t, b) * kernel, i, a, t, b); + if (autoregressive) { + setrange(step, getrange(step, i, a, t, b) & ~getrange(generated, i, a, t, b), i, a, t, b); + } + setrange(context, ~getrange(ignore_mask, 0, nt, t, b) & ~getrange(step, 0, nt, t, b), 0, nt, t, b); + if (autoregressive) { + auto h = max_along_axis(sel, 1); + setrange(context, getrange((h * generated) | (~h * context), 0, nt, t, b), 0, nt, t, b); + } + + if (any(step)) { + steps.push_back(STEP(t, t+model_dim, step.m_data, context.m_data)); + } + + setrange(generated, getrange(generated, i, a, t, b) | getrange(step, i, a, t, b), i, a, t, b); + setrange(covered, getrange(covered, i, a, t, b) | kernel, i, a, t, b); + + } + + if (!all(covered)) { + throw std::runtime_error("PIECE IS ONLY PARTIALLY COVERED"); + } + + if ((!autoregressive) && (param->shuffle())) { + std::random_device rd; + std::mt19937 g(rd()); + std::shuffle(steps.begin() + current_num_steps, steps.end(), g); + } + if ((!autoregressive) && (param->percentage() < 100) && ((int)steps.size() > current_num_steps)) { + int non_autoreg_steps = steps.size() - current_num_steps; + int new_size = non_autoreg_steps * ((float)param->percentage() / 100.); + steps.resize(current_num_steps + std::max(new_size,1)); + } + +} \ No newline at end of file diff --git a/src/inference/sampling/multi_step_sample.h b/src/inference/sampling/multi_step_sample.h new file mode 100644 index 0000000000000000000000000000000000000000..494fbdf2268dbe3c90c082734aab38c7389700f1 --- /dev/null +++ b/src/inference/sampling/multi_step_sample.h @@ -0,0 +1,475 @@ +#pragma once + +#include +#include + +#include "callback_base.h" +#include "sample_internal.h" +#include "../../common/midi_parsing/util_protobuf.h" + +#include +#include + +#include "multi_step.h" + +namespace sampling { + +// Converts the status message into a track & bar matrix indicating which bars are selected +std::vector> status_to_selection_mask(midi::Status *status) { + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "status_to_selection_mask" ); + int ntracks = status->tracks_size(); + int nbars = status->tracks(0).selected_bars_size(); + std::vector> x(ntracks, std::vector(nbars,false)); + int track_num = 0; + for (const auto &track : status->tracks()) { + int bar_num = 0; + for (const auto &bar : track.selected_bars()) { + x[track_num][bar_num] = bar; + bar_num++; + } + track_num++; + } + return x; +} + +// Returns a boolean vector indicating which tracks to sample +std::vector status_to_resample_mask(midi::Status *status) { + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "status_to_resample_mask" ); + // get a boolean vector that indicates which tracks to resample + std::vector resample_mask; + for (const auto &track : status->tracks()) { + resample_mask.push_back( track.autoregressive() ); + } + return resample_mask; +} + +// Returns a boolean vector indicating which tracks to ignore +std::vector status_to_ignore_mask(midi::Status *status) { + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "status_to_ignore_mask" ); + std::vector ignore_mask; + for (const auto &track : status->tracks()) { + ignore_mask.push_back( track.ignore() ); + } + return ignore_mask; +} + + +void status_rehighlight(midi::Status *status, const std::set> &bar_list) { + int num_tracks = status->tracks_size(); + for (int track_num=0; track_nummutable_tracks(track_num); + int num_bars = track->selected_bars_size(); + track->clear_selected_bars(); + for (int bar_num=0; bar_numadd_selected_bars(x); + if ((track->autoregressive()) && (!x)) { + track->set_autoregressive( false ); + } + } + } +} + +midi::Status status_subset(midi::Status *status, int start_bar, int end_bar, const std::vector &track_indices) { + midi::Status subset; + subset.set_decode_final(status->decode_final()); + int track_count = 0; + for (const auto &track_index : track_indices) { + const midi::StatusTrack track = status->tracks(track_index); + midi::StatusTrack *t = subset.add_tracks(); + t->CopyFrom(track); + t->set_track_id(track_count); + t->clear_selected_bars(); + t->clear_bars(); + for (int i=start_bar; iadd_bars(); + b->CopyFrom(track.bars(i)); + t->add_selected_bars( track.selected_bars(i) ); + } + track_count++; + } + return subset; +} + +// Retrieve a subset of the Piece +midi::Piece piece_subset(midi::Piece* piece, int start_bar, int end_bar, const std::vector& track_indices) { + midi::Piece subset; + subset.set_resolution( piece->resolution() ); + subset.set_tempo( piece->tempo() ); + int track_count = 0; + for (const auto &track_index : track_indices) { + if (track_index >= piece->tracks_size()) { + throw std::runtime_error("TRYING TO ACCESS TRACK OUT OF RANGE. PIECE IS LIKELY MALFORMED"); + } + const midi::Track track = piece->tracks(track_index); + midi::Track *t = subset.add_tracks(); + t->CopyFrom(track); + t->clear_bars(); + for (int i=start_bar; iadd_bars(); + b->CopyFrom( track.bars(i) ); + b->clear_events(); + + for (const auto &event : track.bars(i).events()) { + b->add_events( subset.events_size() ); + midi::Event *e = subset.add_events(); + e->CopyFrom( piece->events(event) ); + } + } + track_count++; + } + return subset; +} + +void add_timesigs_to_status(midi::Piece *piece, midi::Status *status) { + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "add_timesigs_to_status" ); + int track_num = 0; + for (const auto &track : piece->tracks()) { + int bar_num = 0; + midi::StatusTrack *st = status->mutable_tracks(track_num); + for (const auto &bar : track.bars()) { + midi::StatusBar *sb; + if (st->bars_size() <= bar_num) { + sb = st->add_bars(); + } + else { + sb = st->mutable_bars(bar_num); + } + sb->set_ts_numerator( bar.ts_numerator() ); + sb->set_ts_denominator( bar.ts_denominator() ); + bar_num++; + } + track_num++; + } +} + +// We compute features first and then only override if the controls are not "ANY" +void override_piece_features(midi::Piece *piece, midi::Status *status, const std::shared_ptr &rep) { + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "override_piece_features" ); + compute_attribute_controls(rep, piece); + + // new override + override_attribute_controls(rep, piece, status); + + // legacy override + for (const auto &track : status->tracks()) { + midi::TrackFeatures *f = util_protobuf::GetTrackFeatures(piece, track.track_id()); + if (track.density() > 0) { + f->set_note_density_v2( track.density() - 1); + } + if (track.min_polyphony_q() > 0) { + f->set_min_polyphony_q( track.min_polyphony_q() - 1 ); + } + if (track.max_polyphony_q() > 0) { + f->set_max_polyphony_q( track.max_polyphony_q() - 1 ); + } + if (track.min_note_duration_q() > 0) { + f->set_min_note_duration_q( track.min_note_duration_q() - 1 ); + } + if (track.max_note_duration_q() > 0) { + f->set_max_note_duration_q( track.max_note_duration_q() - 1 ); + } + } +} + +void piece_insert(midi::Piece *piece, midi::Piece *x, const std::vector> &bar_mapping, bool verbose) { + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "piece_insert" ); + + for (const auto &ii : bar_mapping) { + if (std::get<0>(ii) >= x->tracks_size()) { + data_structures::LOGGER(data_structures::to_str("PIECE INSERT :: INVALID TRACK INDEX ", std::get<0>(ii), " FOR X")); + throw std::runtime_error("PIECE INSERT :: INVALID TRACK INDEX FOR X"); + } + if (std::get<2>(ii) >= piece->tracks_size()) { + throw std::runtime_error("PIECE INSERT :: INVALID TRACK INDEX FOR PIECE"); + } + const midi::Track src_track = x->tracks(std::get<0>(ii)); + const midi::Bar src = src_track.bars(std::get<1>(ii)); + midi::Track *dst_track = piece->mutable_tracks(std::get<2>(ii)); + midi::Bar *dst = dst_track->mutable_bars(std::get<3>(ii)); + + if (verbose) { + data_structures::LOGGER(data_structures::to_str("INSERTING (", std::get<0>(ii), ",", std::get<1>(ii), ") into (", std::get<2>(ii), ",", std::get<3>(ii), ")")); + } + + // overwrite instrument and track type (for autoregressive) + dst_track->set_track_type( src_track.track_type() ); + dst_track->set_instrument( src_track.instrument() ); + + // overwrite bar from src + dst->clear_events(); + for (const auto &event_index : src.events()) { + dst->add_events( piece->events_size() ); + midi::Event *e = piece->add_events(); + e->CopyFrom( x->events(event_index) ); + } + } +} + +// This function resamples and recomputes the event times using the delta values +void resample_delta(midi::Piece *p, std::shared_ptr ec) { + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_VERBOSE, "Resampling Piece with Delta values"); + int current_res = ec->resolution; + int target_res = ec->decode_resolution; + p->set_resolution(target_res); + p->set_internal_ticks_per_quarter(target_res); + int old_time, new_time, delta; + std::vector> events_cache; + // Get all events and store in cache vector + + int num_events = p->events_size(); + for (int event_index=0; event_indexevents(event_index); + old_time = e.time(); + delta = e.delta(); + // We round down to be safe + new_time = (int)(target_res * old_time / current_res); + //exclude negative times + new_time = std::max(new_time + delta, 0); + // Set new resampled time + e.set_time(new_time); + events_cache.push_back(std::make_tuple(event_index, e)); + } + // Sort events to replace in the correct order + sort(events_cache.begin(), events_cache.end(), [](std::tuple a, std::tuple b) { + return std::get<0>(a) < std::get<0>(b); + }); + // Clear all events now that they're cached + p->clear_events(); + // Reinject resampled events + for (const std::tuple &oe : events_cache) { + midi::Event *ne = p->add_events(); + ne->CopyFrom( std::get<1>(oe) ); + } + assert(num_events == p->events_size()); +} + + +std::vector find_steps(const std::vector> &sel, const std::vector &resample_mask, const std::vector &ignore_mask, midi::HyperParam *param) { + if ((sel.size() != resample_mask.size()) || (sel.size() != ignore_mask.size())) { + throw std::invalid_argument("find_steps :: selection, resample_mask and ignore_mask must be the same size"); + } + std::vector steps; + cmatrix selection(sel); + cmatrix generated = cmatrix(selection.N, selection.M, 0); + cmatrix resample = vector_to_matrix(resample_mask, selection.M); + cmatrix ignore = vector_to_matrix(ignore_mask, selection.M); + find_steps_inner(steps, selection, resample, ignore, true, generated, param); + find_steps_inner(steps, selection, resample, ignore, false, generated, param); + return steps; +} + +void sample_step(midi::Piece *piece, midi::Status *status, midi::HyperParam *param, const std::unique_ptr &model, const STEP *s, CallbackManager *callbacks) { + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "sample_step" ); + + // prepare the inputs for generation + midi::Piece step_piece = piece_subset(piece, s->start, s->end, s->get_tracks()); + midi::Status step_status = status_subset(status, s->start, s->end, s->get_tracks()); + status_rehighlight(&step_status, s->get_bars_to_generate()); + + // do generation + midi::Piece gen_piece = generate(&step_status, &step_piece, param, model, callbacks)[0]; + // NOTE : this inserts tracks that are just conditioned on as well + // insert generation into global piece + piece_insert(piece, &gen_piece, s->get_bar_mapping(), param->verbose()); + std::unique_ptr enc = enums::getEncoderFromString(model->meta.encoder()); + if (!enc.get()) { + throw std::invalid_argument("INVALID ENCODER"); + } + if (enc->config->use_microtiming && status->decode_final()) { + //resample_delta(piece, enc->config); + enc->resample_delta(piece); + } + override_piece_features(piece, status, enc->rep); +} + +// ============================== +// MAIN INFERENCE ENTRYPOINT +void sample(midi::Piece* piece, midi::Status* raw_status, midi::HyperParam* param, CallbackManager *callbacks) { + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "sample" ); + + //CheckIfDataExists + if ((!piece) || (!raw_status) || (!param)) { + throw std::invalid_argument("Piece, Status or HyperParam is malformed"); + } + + if ((callbacks) && (callbacks->is_cancelled())) { + return; + } + + // We create a new status with raw_status info, and then a pointer to access it indirectly. + midi::Status status_object(*raw_status); + midi::Status* status_pointer = &status_object; + + // try to load model + std::unique_ptr model = load_model(param); + + // Check if encoder exists + std::unique_ptr enc = enums::getEncoderFromString(model->meta.encoder()); + if (!enc.get()) { + throw std::invalid_argument("INVALID ENCODER"); + } + piece->set_resolution(enc->config->resolution); + param->set_internal_skip_preprocess(true); + param->set_batch_size(1); + + util_protobuf::validate_inputs(piece, status_pointer, param); + // before we start pad the piece if status references tracks + // that do not exist yet + util_protobuf::pad_piece_with_status(piece, status_pointer, param->model_dim()); + // add time-signatures from piece into the status + add_timesigs_to_status(piece, status_pointer); + // add features to piece when we are sampling auto-regressively + // as these are perhaps not yet in the piece + override_piece_features(piece, status_pointer, enc->rep); + + std::vector> selection_mask = status_to_selection_mask(status_pointer); + if (!any(selection_mask)) { + return; // nothing to do + } + + std::vector resample_mask = status_to_resample_mask(status_pointer); + std::vector ignore_mask = status_to_ignore_mask(status_pointer); + std::vector steps = find_steps(selection_mask, resample_mask, ignore_mask, param); + + if (steps.size() == 0) { + return; // nothing to be done + } + + // find the total number of bars to be generated + int bar_count = 0; + for (const auto &step : steps) { + bar_count += step.generated_bar_count(); + } + + // get order and reverse order of tracks + int nt = status_pointer->tracks_size(); + std::vector order(nt, 0); + std::vector reverse_order = arange(nt); + for (int track_num = 0; track_num < nt; track_num++) { + midi::StatusTrack* st = status_pointer->mutable_tracks(track_num); + order[st->track_id()] = track_num; + st->set_track_id(track_num); // now the mapping is the identity + } + std::sort(reverse_order.begin(), reverse_order.end(), + [&order](size_t i, size_t j) {return order[i] < order[j]; }); + util_protobuf::reorder_tracks(piece, order); + + for (int i=0; iset_decode_final(true); + } else { + status_pointer->set_decode_final(false); + } + STEP step = steps[i]; + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, data_structures::to_str("Sampling step :: decoding final = ", status_pointer->decode_final())); + sample_step(piece, status_pointer, param, model, &step, callbacks); + } + util_protobuf::reorder_tracks(piece, reverse_order); + std::string json_string_res = util_protobuf::protobuf_to_string(piece); +} + +std::vector> get_notes_py(std::string &piece_json, int track_start, int track_end, int bar_start, int bar_end, bool onset_only_drums) { + midi::Piece piece; + util_protobuf::string_to_protobuf(piece_json, &piece); + std::vector notes = util_protobuf::getNotes(&piece, track_start, track_end, bar_start, bar_end, onset_only_drums); + std::vector> notes_py; + for (const auto ¬e : notes) { + notes_py.push_back(std::make_tuple(note.start(), note.end(), note.pitch())); + } + return notes_py; +} + +void sort_notes(std::vector ¬es) { + std::sort(notes.begin(), notes.end(), [](const midi::Note &a, const midi::Note &b) { + if (a.start() == b.start()) { + return a.pitch() < b.pitch(); + } + return a.start() < b.start(); + }); +} + +// function that determines if two bars are equivalent +bool bars_are_equivalent(midi::Piece *pa, midi::Piece *pb, int track_num, int bar_num) { + std::vector notes_a = util_protobuf::getNotes(pa, track_num, track_num+1, bar_num, bar_num+1, true); + std::vector notes_b = util_protobuf::getNotes(pb, track_num, track_num+1, bar_num, bar_num+1, true); + if (notes_a.size() != notes_b.size()) { + return false; + } + sort_notes(notes_a); + sort_notes(notes_b); + for (int i=0; i<(int)notes_a.size(); i++) { + if ((notes_a[i].start() != notes_b[i].start()) || (notes_a[i].pitch() != notes_b[i].pitch())) { + return false; + } + } + return true; +} + +// function that determines if something has changed +// it returns a list of bars that are identical +std::vector> find_identical_bars(midi::Piece *input, midi::Piece *output, midi::Status *status) { + std::vector> identical_bars; + for (int track_num=0; track_numtracks_size(); track_num++) { + midi::StatusTrack track = status->tracks(track_num); + for (int bar_num=0; bar_num> identical_bars = find_identical_bars(&input, ¤t, status); + attempts++; + if (identical_bars.size() == 0) { + piece->CopyFrom(current); + return attempts; + } + if (callbacks) { + param->set_temperature( callbacks->update_temperature(param->temperature()) ); + } + } + return attempts; +} + +std::tuple sample_multi_step_py(std::string &piece_json, std::string &status_json, std::string ¶m_json, int max_attempts, sampling::CallbackManager *callbacks) { + midi::Piece piece; + midi::Status status; + midi::HyperParam hyperParam; + + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "to_proto"); + + util_protobuf::string_to_protobuf(piece_json, &piece); + util_protobuf::string_to_protobuf(status_json, &status); + util_protobuf::string_to_protobuf(param_json, &hyperParam); + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "validating"); + + util_protobuf::validate_protobuf_fields(&piece, piece_json); + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "piece"); + util_protobuf::validate_protobuf_fields(&status, status_json); + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "status"); + util_protobuf::validate_protobuf_fields(&hyperParam, param_json); + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "param"); + + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_VERBOSE, util_protobuf::protobuf_to_string(&status)); + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_VERBOSE, util_protobuf::protobuf_to_string(&hyperParam)); + + int attempts = sample_multi_attempts(&piece, &status, &hyperParam, callbacks, max_attempts); + return std::make_tuple(util_protobuf::protobuf_to_string(&piece), attempts); +} + +} diff --git a/src/inference/sampling/sample_internal.h b/src/inference/sampling/sample_internal.h new file mode 100644 index 0000000000000000000000000000000000000000..b775dd2476df72d0b2f855c9601ae3d66ac8699d --- /dev/null +++ b/src/inference/sampling/sample_internal.h @@ -0,0 +1,230 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "../enum/model_type.h" +#include "../../common/data_structures/verbosity.h" +#include "control.h" +#include "callback_base.h" + +namespace sampling { + + class ModelMeta { + public: + torch::jit::Module model; + midi::ModelMetadata meta; + }; + + static const int NUM_LAYERS = 6; + + void load_checkpoint(const std::string &ckpt_path, const std::unique_ptr &m) { + try { + std::unordered_map loaded_extra_files; + loaded_extra_files["metadata.json"] = ""; + m->model = torch::jit::load(ckpt_path, torch::kCPU, loaded_extra_files); + if (loaded_extra_files["metadata.json"].size() == 0) { + throw std::runtime_error("ERROR LOADING MODEL : MODEL CONTAINS NO METADATA!"); + } + util_protobuf::string_to_protobuf(loaded_extra_files["metadata.json"], &m->meta); + data_structures::LOGGER( "MODEL METADATA :" ); + } + catch (const c10::Error& e) { + data_structures::LOGGER( e.what() ); + throw std::runtime_error("ERROR LOADING MODEL."); + } + } + + std::unique_ptr load_model(midi::HyperParam *param) { + auto model = std::make_unique(); + load_checkpoint(param->ckpt(), model); + if (model->meta.model_dim() != -1) { + param->set_model_dim(model->meta.model_dim()); + } + + model->meta.set_num_heads(8); + model->meta.set_num_layers(6); + + return model; + } + + void sample_inner(std::vector> &scon, std::vector> &seqs, torch::jit::Module *model, std::vector &inputs, midi::HyperParam *param, CallbackManager *callbacks) { + + if (!model) { + throw std::runtime_error("ERROR : MODEL IS INVALID."); + } + + torch::Tensor logits; + torch::jit::IValue past_key_values; + + auto outputs = model->forward(inputs).toTuple(); + logits = outputs->elements()[0].toTensor().index( + {torch::indexing::Slice(),-1,torch::indexing::Slice()}); + past_key_values = outputs->elements()[1]; + + + // get logits for first in batch + std::vector> masks_copy; + std::vector> logits_copy; + for (int i=0; i<(int)seqs.size(); i++) { + logits_copy.push_back(std::vector(logits[i].data_ptr(), logits[i].data_ptr() + logits[i].numel())); + } + + // set masks + std::vector> masked_tts; + int num_masked = 0; + for (int i=0; i<(int)seqs.size(); i++) { + std::vector unmasked_types; + std::vector mask = scon[i]->get_mask( seqs[i] ); + masks_copy.push_back( mask ); + masked_tts.push_back( scon[i]->rep->get_mask_token_types(mask) ); + scon[i]->rep->show_mask_token_types(mask); + if ((!scon[i]->finished) && (!param->internal_disable_masking())) { + for (int j=0; j<(int)mask.size(); j++) { + if (mask[j] == 0) { + logits[i][j] = -1 * std::numeric_limits::max(); // set this to a very small possibility + num_masked++; + } else { + unmasked_types.push_back(scon[i]->enc->rep->pretty_type(j)); + } + } + } + std::set s( unmasked_types.begin(), unmasked_types.end() ); + unmasked_types.assign( s.begin(), s.end() ); + for (auto strr : unmasked_types) { + std::cout << "NOT MASKED: " << strr << std::endl; + } + + if (param->mask_top_k() > 0) { + + std::mt19937 engine(time(NULL)); + + // optionally mask the top k tokens + bool can_mask = false; + std::vector token_types_to_mask = {midi::TOKEN_NOTE_ONSET, midi::TOKEN_TIME_ABSOLUTE_POS, midi::TOKEN_NOTE_DURATION}; + for (const auto &t : token_types_to_mask) { + if (masked_tts[i].count(t) > 0) { + can_mask = true; + break; + } + } + if ((can_mask) && (random_on_unit(&engine) < param->mask_top_k())) { + std::vector V(mask.size()); + std::iota(V.begin(),V.end(),0); + std::sort( V.begin(),V.end(), [&](int ii,int jj){ return (logits[i][ii] > logits[i][jj]).item(); }); + + for (int j=0; j<10; j++) { + if (j==0) { + logits[i][V[j]] = -1 * std::numeric_limits::max(); + num_masked++; + } + } + } + } + } + + if (param->sampling_seed() != -1) { + torch::manual_seed(param->sampling_seed()); + } + + float temperature = param->temperature(); + auto probs = (logits / temperature).softmax(1); + auto next_tokens = probs.multinomial(1); + + inputs.clear(); + inputs.push_back( next_tokens ); + inputs.push_back( past_key_values ); + + // add next token to the sequences + for (int i=0; i<(int)seqs.size(); i++) { + if (!scon[i]->finished) { + int next_token = next_tokens[i][0].item(); + data_structures::LOGGER(data_structures::to_str("SAMPLED :: ", scon[i]->enc->rep->pretty(next_token))); + seqs[i].push_back( next_token ); + + + if (callbacks) { + if ((scon[i]->enc->rep->is_token_type(next_token, midi::TOKEN_BAR_END)) || (scon[i]->enc->rep->is_token_type(next_token, midi::TOKEN_FILL_IN_END))) { + callbacks->on_bar_end(); + } + callbacks->on_prediction(logits_copy[i], next_token); + } + } + } + } + + void make_state(std::vector *state, int batch_size, midi::ModelMetadata *meta) { + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, "make_state" ); + for (int i=0; inum_layers(); i++) { + std::vector tuple; + for (int j=0; j<2; j++) { + tuple.push_back( torch::zeros({batch_size, meta->num_heads(), 0, meta->num_hidden()}) ); + } + state->push_back( torch::ivalue::Tuple::create(tuple) ); + } + } + + std::vector generate(midi::Status *status, midi::Piece *piece, midi::HyperParam *param, const std::unique_ptr &mm, CallbackManager *callbacks) { + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_DEBUG, "generate"); + data_structures::LOGGER(data_structures::VERBOSITY_LEVEL_TRACE, util_protobuf::protobuf_to_string(status)); + param->set_temperature( std::max((double)param->temperature(), 1e-6) ); // CAN'T HAVE ZERO TEMPERATURE + std::vector> scon; + for (int i=0; ibatch_size(); i++) { + scon.push_back( std::make_unique(piece, status, param, &mm->meta) ); + } + for (auto &sc : scon) { + data_structures::LOGGER("REG GRAPH" ); + sc->rg->graph.print_graphviz(); + } + std::vector prompt = scon[0]->prompt; + std::vector inputs; + std::vector> seqs = std::vector>(param->batch_size(), prompt); + scon[0]->rep->show(prompt); + + auto opts = torch::TensorOptions().dtype(torch::kInt64); + torch::Tensor x = torch::zeros({param->batch_size(), (int)prompt.size()}, opts); + for (int k=0; kbatch_size(); k++) { + for (int i=0; i<(int)prompt.size(); i++) { + x[k][i] = prompt[i]; + } + } + inputs.push_back( x ); + std::vector state; + if ((param) && (mm->meta.new_state())) { + make_state(&state, param->batch_size(), &mm->meta); + } + inputs.push_back(torch::ivalue::Tuple::create(state)); + + + bool terminated = false; + int num_steps = 0; + while (!scon[0]->finished) { + sample_inner(scon, seqs, &mm->model, inputs, param, callbacks); + num_steps++; + if ((param->max_steps() > 0) && (num_steps >= param->max_steps())) { + terminated = true; + break; + } + if ((callbacks) && (callbacks->is_cancelled())) { + terminated = true; + break; + } + } + scon[0]->enc->config->decode_final = status->decode_final(); + scon[0]->rep->show(seqs[0]); + std::vector output(param->batch_size()); + if (!terminated) { + scon[0]->enc->tokens_to_json_array(seqs, output); + scon[0]->finalize(&output[0]); // batch size should be 1 anyways + } + return output; + } + +} \ No newline at end of file diff --git a/src/inference/version.h b/src/inference/version.h new file mode 100644 index 0000000000000000000000000000000000000000..793a3cca79fdea1f9900b5a55f592e298dead896 --- /dev/null +++ b/src/inference/version.h @@ -0,0 +1,2 @@ +#include +std::string version() { return "2020-08-19_09-17-11"; } \ No newline at end of file diff --git a/src/inference/xxh/xxh3.h b/src/inference/xxh/xxh3.h new file mode 100644 index 0000000000000000000000000000000000000000..882da0e77e4b4f4c26a2f5634d32893809610984 --- /dev/null +++ b/src/inference/xxh/xxh3.h @@ -0,0 +1,2704 @@ +/* + * xxHash - Extremely Fast Hash algorithm + * Development source file for `xxh3` + * Copyright (C) 2019-2020 Yann Collet + * + * BSD 2-Clause License (https://www.opensource.org/licenses/bsd-license.php) + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * You can contact the author at: + * - xxHash homepage: https://www.xxhash.com + * - xxHash source repository: https://github.com/Cyan4973/xxHash + */ + +/* + * Note: This file is separated for development purposes. + * It will be integrated into `xxhash.h` when development stage is completed. + * + * Credit: most of the work on vectorial and asm variants comes from @easyaspi314 + */ + +#ifndef XXH3_H_1397135465 +#define XXH3_H_1397135465 + +/* === Dependencies === */ +#ifndef XXHASH_H_5627135585666179 +/* special: when including `xxh3.h` directly, turn on XXH_INLINE_ALL */ +# undef XXH_INLINE_ALL /* avoid redefinition */ +# define XXH_INLINE_ALL +#endif +#include "xxhash.h" + + +/* === Compiler specifics === */ + +#if defined (__STDC_VERSION__) && __STDC_VERSION__ >= 199901L /* >= C99 */ +# define XXH_RESTRICT restrict +#else +/* Note: it might be useful to define __restrict or __restrict__ for some C++ compilers */ +# define XXH_RESTRICT /* disable */ +#endif + +#if (defined(__GNUC__) && (__GNUC__ >= 3)) \ + || (defined(__INTEL_COMPILER) && (__INTEL_COMPILER >= 800)) \ + || defined(__clang__) +# define XXH_likely(x) __builtin_expect(x, 1) +# define XXH_unlikely(x) __builtin_expect(x, 0) +#else +# define XXH_likely(x) (x) +# define XXH_unlikely(x) (x) +#endif + +#if defined(__GNUC__) +# if defined(__AVX2__) +# include +# elif defined(__SSE2__) +# include +# elif defined(__ARM_NEON__) || defined(__ARM_NEON) +# define inline __inline__ /* clang bug */ +# include +# undef inline +# endif +#elif defined(_MSC_VER) +# include +#endif + +/* + * One goal of XXH3 is to make it fast on both 32-bit and 64-bit, while + * remaining a true 64-bit/128-bit hash function. + * + * This is done by prioritizing a subset of 64-bit operations that can be + * emulated without too many steps on the average 32-bit machine. + * + * For example, these two lines seem similar, and run equally fast on 64-bit: + * + * xxh_u64 x; + * x ^= (x >> 47); // good + * x ^= (x >> 13); // bad + * + * However, to a 32-bit machine, there is a major difference. + * + * x ^= (x >> 47) looks like this: + * + * x.lo ^= (x.hi >> (47 - 32)); + * + * while x ^= (x >> 13) looks like this: + * + * // note: funnel shifts are not usually cheap. + * x.lo ^= (x.lo >> 13) | (x.hi << (32 - 13)); + * x.hi ^= (x.hi >> 13); + * + * The first one is significantly faster than the second, simply because the + * shift is larger than 32. This means: + * - All the bits we need are in the upper 32 bits, so we can ignore the lower + * 32 bits in the shift. + * - The shift result will always fit in the lower 32 bits, and therefore, + * we can ignore the upper 32 bits in the xor. + * + * Thanks to this optimization, XXH3 only requires these features to be efficient: + * + * - Usable unaligned access + * - A 32-bit or 64-bit ALU + * - If 32-bit, a decent ADC instruction + * - A 32 or 64-bit multiply with a 64-bit result + * - For the 128-bit variant, a decent byteswap helps short inputs. + * + * The first two are already required by XXH32, and almost all 32-bit and 64-bit + * platforms which can run XXH32 can run XXH3 efficiently. + * + * Thumb-1, the classic 16-bit only subset of ARM's instruction set, is one + * notable exception. + * + * First of all, Thumb-1 lacks support for the UMULL instruction which + * performs the important long multiply. This means numerous __aeabi_lmul + * calls. + * + * Second of all, the 8 functional registers are just not enough. + * Setup for __aeabi_lmul, byteshift loads, pointers, and all arithmetic need + * Lo registers, and this shuffling results in thousands more MOVs than A32. + * + * A32 and T32 don't have this limitation. They can access all 14 registers, + * do a 32->64 multiply with UMULL, and the flexible operand allowing free + * shifts is helpful, too. + * + * Therefore, we do a quick sanity check. + * + * If compiling Thumb-1 for a target which supports ARM instructions, we will + * emit a warning, as it is not a "sane" platform to compile for. + * + * Usually, if this happens, it is because of an accident and you probably need + * to specify -march, as you likely meant to compile for a newer architecture. + */ +#if defined(__thumb__) && !defined(__thumb2__) && defined(__ARM_ARCH_ISA_ARM) +# warning "XXH3 is highly inefficient without ARM or Thumb-2." +#endif + +/* ========================================== + * Vectorization detection + * ========================================== */ +#define XXH_SCALAR 0 /* Portable scalar version */ +#define XXH_SSE2 1 /* SSE2 for Pentium 4 and all x86_64 */ +#define XXH_AVX2 2 /* AVX2 for Haswell and Bulldozer */ +#define XXH_AVX512 3 /* AVX512 for Skylake and Icelake */ +#define XXH_NEON 4 /* NEON for most ARMv7-A and all AArch64 */ +#define XXH_VSX 5 /* VSX and ZVector for POWER8/z13 */ + +#ifndef XXH_VECTOR /* can be defined on command line */ +# if defined(__AVX512F__) +# define XXH_VECTOR XXH_AVX512 +# elif defined(__AVX2__) +# define XXH_VECTOR XXH_AVX2 +# elif defined(__SSE2__) || defined(_M_AMD64) || defined(_M_X64) || (defined(_M_IX86_FP) && (_M_IX86_FP == 2)) +# define XXH_VECTOR XXH_SSE2 +# elif defined(__GNUC__) /* msvc support maybe later */ \ + && (defined(__ARM_NEON__) || defined(__ARM_NEON)) \ + && (defined(__LITTLE_ENDIAN__) /* We only support little endian NEON */ \ + || (defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)) +# define XXH_VECTOR XXH_NEON +# elif (defined(__PPC64__) && defined(__POWER8_VECTOR__)) \ + || (defined(__s390x__) && defined(__VEC__)) \ + && defined(__GNUC__) /* TODO: IBM XL */ +# define XXH_VECTOR XXH_VSX +# else +# define XXH_VECTOR XXH_SCALAR +# endif +#endif + +/* + * Controls the alignment of the accumulator, + * for compatibility with aligned vector loads, which are usually faster. + */ +#ifndef XXH_ACC_ALIGN +# if defined(XXH_X86DISPATCH) +# define XXH_ACC_ALIGN 64 /* for compatibility with avx512 */ +# elif XXH_VECTOR == XXH_SCALAR /* scalar */ +# define XXH_ACC_ALIGN 8 +# elif XXH_VECTOR == XXH_SSE2 /* sse2 */ +# define XXH_ACC_ALIGN 16 +# elif XXH_VECTOR == XXH_AVX2 /* avx2 */ +# define XXH_ACC_ALIGN 32 +# elif XXH_VECTOR == XXH_NEON /* neon */ +# define XXH_ACC_ALIGN 16 +# elif XXH_VECTOR == XXH_VSX /* vsx */ +# define XXH_ACC_ALIGN 16 +# elif XXH_VECTOR == XXH_AVX512 /* avx512 */ +# define XXH_ACC_ALIGN 64 +# endif +#endif + +#if defined(XXH_X86DISPATCH) || XXH_VECTOR == XXH_SSE2 \ + || XXH_VECTOR == XXH_AVX2 || XXH_VECTOR == XXH_AVX512 +# define XXH_SEC_ALIGN XXH_ACC_ALIGN +#else +# define XXH_SEC_ALIGN 8 +#endif + +/* + * UGLY HACK: + * GCC usually generates the best code with -O3 for xxHash. + * + * However, when targeting AVX2, it is overzealous in its unrolling resulting + * in code roughly 3/4 the speed of Clang. + * + * There are other issues, such as GCC splitting _mm256_loadu_si256 into + * _mm_loadu_si128 + _mm256_inserti128_si256. This is an optimization which + * only applies to Sandy and Ivy Bridge... which don't even support AVX2. + * + * That is why when compiling the AVX2 version, it is recommended to use either + * -O2 -mavx2 -march=haswell + * or + * -O2 -mavx2 -mno-avx256-split-unaligned-load + * for decent performance, or to use Clang instead. + * + * Fortunately, we can control the first one with a pragma that forces GCC into + * -O2, but the other one we can't control without "failed to inline always + * inline function due to target mismatch" warnings. + */ +#if XXH_VECTOR == XXH_AVX2 /* AVX2 */ \ + && defined(__GNUC__) && !defined(__clang__) /* GCC, not Clang */ \ + && defined(__OPTIMIZE__) && !defined(__OPTIMIZE_SIZE__) /* respect -O0 and -Os */ +# pragma GCC push_options +# pragma GCC optimize("-O2") +#endif + + +#if XXH_VECTOR == XXH_NEON +/* + * NEON's setup for vmlal_u32 is a little more complicated than it is on + * SSE2, AVX2, and VSX. + * + * While PMULUDQ and VMULEUW both perform a mask, VMLAL.U32 performs an upcast. + * + * To do the same operation, the 128-bit 'Q' register needs to be split into + * two 64-bit 'D' registers, performing this operation:: + * + * [ a | b ] + * | '---------. .--------' | + * | x | + * | .---------' '--------. | + * [ a & 0xFFFFFFFF | b & 0xFFFFFFFF ],[ a >> 32 | b >> 32 ] + * + * Due to significant changes in aarch64, the fastest method for aarch64 is + * completely different than the fastest method for ARMv7-A. + * + * ARMv7-A treats D registers as unions overlaying Q registers, so modifying + * D11 will modify the high half of Q5. This is similar to how modifying AH + * will only affect bits 8-15 of AX on x86. + * + * VZIP takes two registers, and puts even lanes in one register and odd lanes + * in the other. + * + * On ARMv7-A, this strangely modifies both parameters in place instead of + * taking the usual 3-operand form. + * + * Therefore, if we want to do this, we can simply use a D-form VZIP.32 on the + * lower and upper halves of the Q register to end up with the high and low + * halves where we want - all in one instruction. + * + * vzip.32 d10, d11 @ d10 = { d10[0], d11[0] }; d11 = { d10[1], d11[1] } + * + * Unfortunately we need inline assembly for this: Instructions modifying two + * registers at once is not possible in GCC or Clang's IR, and they have to + * create a copy. + * + * aarch64 requires a different approach. + * + * In order to make it easier to write a decent compiler for aarch64, many + * quirks were removed, such as conditional execution. + * + * NEON was also affected by this. + * + * aarch64 cannot access the high bits of a Q-form register, and writes to a + * D-form register zero the high bits, similar to how writes to W-form scalar + * registers (or DWORD registers on x86_64) work. + * + * The formerly free vget_high intrinsics now require a vext (with a few + * exceptions) + * + * Additionally, VZIP was replaced by ZIP1 and ZIP2, which are the equivalent + * of PUNPCKL* and PUNPCKH* in SSE, respectively, in order to only modify one + * operand. + * + * The equivalent of the VZIP.32 on the lower and upper halves would be this + * mess: + * + * ext v2.4s, v0.4s, v0.4s, #2 // v2 = { v0[2], v0[3], v0[0], v0[1] } + * zip1 v1.2s, v0.2s, v2.2s // v1 = { v0[0], v2[0] } + * zip2 v0.2s, v0.2s, v1.2s // v0 = { v0[1], v2[1] } + * + * Instead, we use a literal downcast, vmovn_u64 (XTN), and vshrn_n_u64 (SHRN): + * + * shrn v1.2s, v0.2d, #32 // v1 = (uint32x2_t)(v0 >> 32); + * xtn v0.2s, v0.2d // v0 = (uint32x2_t)(v0 & 0xFFFFFFFF); + * + * This is available on ARMv7-A, but is less efficient than a single VZIP.32. + */ + +/* + * Function-like macro: + * void XXH_SPLIT_IN_PLACE(uint64x2_t &in, uint32x2_t &outLo, uint32x2_t &outHi) + * { + * outLo = (uint32x2_t)(in & 0xFFFFFFFF); + * outHi = (uint32x2_t)(in >> 32); + * in = UNDEFINED; + * } + */ +# if !defined(XXH_NO_VZIP_HACK) /* define to disable */ \ + && defined(__GNUC__) \ + && !defined(__aarch64__) && !defined(__arm64__) +# define XXH_SPLIT_IN_PLACE(in, outLo, outHi) \ + do { \ + /* Undocumented GCC/Clang operand modifier: %e0 = lower D half, %f0 = upper D half */ \ + /* https://github.com/gcc-mirror/gcc/blob/38cf91e5/gcc/config/arm/arm.c#L22486 */ \ + /* https://github.com/llvm-mirror/llvm/blob/2c4ca683/lib/Target/ARM/ARMAsmPrinter.cpp#L399 */ \ + __asm__("vzip.32 %e0, %f0" : "+w" (in)); \ + (outLo) = vget_low_u32 (vreinterpretq_u32_u64(in)); \ + (outHi) = vget_high_u32(vreinterpretq_u32_u64(in)); \ + } while (0) +# else +# define XXH_SPLIT_IN_PLACE(in, outLo, outHi) \ + do { \ + (outLo) = vmovn_u64 (in); \ + (outHi) = vshrn_n_u64 ((in), 32); \ + } while (0) +# endif +#endif /* XXH_VECTOR == XXH_NEON */ + +/* + * VSX and Z Vector helpers. + * + * This is very messy, and any pull requests to clean this up are welcome. + * + * There are a lot of problems with supporting VSX and s390x, due to + * inconsistent intrinsics, spotty coverage, and multiple endiannesses. + */ +#if XXH_VECTOR == XXH_VSX +# if defined(__s390x__) +# include +# else +# include +# endif + +# undef vector /* Undo the pollution */ + +typedef __vector unsigned long long xxh_u64x2; +typedef __vector unsigned char xxh_u8x16; +typedef __vector unsigned xxh_u32x4; + +# ifndef XXH_VSX_BE +# if defined(__BIG_ENDIAN__) \ + || (defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__) +# define XXH_VSX_BE 1 +# elif defined(__VEC_ELEMENT_REG_ORDER__) && __VEC_ELEMENT_REG_ORDER__ == __ORDER_BIG_ENDIAN__ +# warning "-maltivec=be is not recommended. Please use native endianness." +# define XXH_VSX_BE 1 +# else +# define XXH_VSX_BE 0 +# endif +# endif /* !defined(XXH_VSX_BE) */ + +# if XXH_VSX_BE +/* A wrapper for POWER9's vec_revb. */ +# if defined(__POWER9_VECTOR__) || (defined(__clang__) && defined(__s390x__)) +# define XXH_vec_revb vec_revb +# else +XXH_FORCE_INLINE xxh_u64x2 XXH_vec_revb(xxh_u64x2 val) +{ + xxh_u8x16 const vByteSwap = { 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01, 0x00, + 0x0F, 0x0E, 0x0D, 0x0C, 0x0B, 0x0A, 0x09, 0x08 }; + return vec_perm(val, val, vByteSwap); +} +# endif +# endif /* XXH_VSX_BE */ + +/* + * Performs an unaligned load and byte swaps it on big endian. + */ +XXH_FORCE_INLINE xxh_u64x2 XXH_vec_loadu(const void *ptr) +{ + xxh_u64x2 ret; + memcpy(&ret, ptr, sizeof(xxh_u64x2)); +# if XXH_VSX_BE + ret = XXH_vec_revb(ret); +# endif + return ret; +} + +/* + * vec_mulo and vec_mule are very problematic intrinsics on PowerPC + * + * These intrinsics weren't added until GCC 8, despite existing for a while, + * and they are endian dependent. Also, their meaning swap depending on version. + * */ +# if defined(__s390x__) + /* s390x is always big endian, no issue on this platform */ +# define XXH_vec_mulo vec_mulo +# define XXH_vec_mule vec_mule +# elif defined(__clang__) && XXH_HAS_BUILTIN(__builtin_altivec_vmuleuw) +/* Clang has a better way to control this, we can just use the builtin which doesn't swap. */ +# define XXH_vec_mulo __builtin_altivec_vmulouw +# define XXH_vec_mule __builtin_altivec_vmuleuw +# else +/* gcc needs inline assembly */ +/* Adapted from https://github.com/google/highwayhash/blob/master/highwayhash/hh_vsx.h. */ +XXH_FORCE_INLINE xxh_u64x2 XXH_vec_mulo(xxh_u32x4 a, xxh_u32x4 b) +{ + xxh_u64x2 result; + __asm__("vmulouw %0, %1, %2" : "=v" (result) : "v" (a), "v" (b)); + return result; +} +XXH_FORCE_INLINE xxh_u64x2 XXH_vec_mule(xxh_u32x4 a, xxh_u32x4 b) +{ + xxh_u64x2 result; + __asm__("vmuleuw %0, %1, %2" : "=v" (result) : "v" (a), "v" (b)); + return result; +} +# endif /* XXH_vec_mulo, XXH_vec_mule */ +#endif /* XXH_VECTOR == XXH_VSX */ + + +/* prefetch + * can be disabled, by declaring XXH_NO_PREFETCH build macro */ +#if defined(XXH_NO_PREFETCH) +# define XXH_PREFETCH(ptr) (void)(ptr) /* disabled */ +#else +# if defined(_MSC_VER) && (defined(_M_X64) || defined(_M_I86)) /* _mm_prefetch() is not defined outside of x86/x64 */ +# include /* https://msdn.microsoft.com/fr-fr/library/84szxsww(v=vs.90).aspx */ +# define XXH_PREFETCH(ptr) _mm_prefetch((const char*)(ptr), _MM_HINT_T0) +# elif defined(__GNUC__) && ( (__GNUC__ >= 4) || ( (__GNUC__ == 3) && (__GNUC_MINOR__ >= 1) ) ) +# define XXH_PREFETCH(ptr) __builtin_prefetch((ptr), 0 /* rw==read */, 3 /* locality */) +# else +# define XXH_PREFETCH(ptr) (void)(ptr) /* disabled */ +# endif +#endif /* XXH_NO_PREFETCH */ + + +/* ========================================== + * XXH3 default settings + * ========================================== */ + +#define XXH_SECRET_DEFAULT_SIZE 192 /* minimum XXH3_SECRET_SIZE_MIN */ + +#if (XXH_SECRET_DEFAULT_SIZE < XXH3_SECRET_SIZE_MIN) +# error "default keyset is not large enough" +#endif + +/* Pseudorandom secret taken directly from FARSH */ +XXH_ALIGN(64) static const xxh_u8 XXH3_kSecret[XXH_SECRET_DEFAULT_SIZE] = { + 0xb8, 0xfe, 0x6c, 0x39, 0x23, 0xa4, 0x4b, 0xbe, 0x7c, 0x01, 0x81, 0x2c, 0xf7, 0x21, 0xad, 0x1c, + 0xde, 0xd4, 0x6d, 0xe9, 0x83, 0x90, 0x97, 0xdb, 0x72, 0x40, 0xa4, 0xa4, 0xb7, 0xb3, 0x67, 0x1f, + 0xcb, 0x79, 0xe6, 0x4e, 0xcc, 0xc0, 0xe5, 0x78, 0x82, 0x5a, 0xd0, 0x7d, 0xcc, 0xff, 0x72, 0x21, + 0xb8, 0x08, 0x46, 0x74, 0xf7, 0x43, 0x24, 0x8e, 0xe0, 0x35, 0x90, 0xe6, 0x81, 0x3a, 0x26, 0x4c, + 0x3c, 0x28, 0x52, 0xbb, 0x91, 0xc3, 0x00, 0xcb, 0x88, 0xd0, 0x65, 0x8b, 0x1b, 0x53, 0x2e, 0xa3, + 0x71, 0x64, 0x48, 0x97, 0xa2, 0x0d, 0xf9, 0x4e, 0x38, 0x19, 0xef, 0x46, 0xa9, 0xde, 0xac, 0xd8, + 0xa8, 0xfa, 0x76, 0x3f, 0xe3, 0x9c, 0x34, 0x3f, 0xf9, 0xdc, 0xbb, 0xc7, 0xc7, 0x0b, 0x4f, 0x1d, + 0x8a, 0x51, 0xe0, 0x4b, 0xcd, 0xb4, 0x59, 0x31, 0xc8, 0x9f, 0x7e, 0xc9, 0xd9, 0x78, 0x73, 0x64, + 0xea, 0xc5, 0xac, 0x83, 0x34, 0xd3, 0xeb, 0xc3, 0xc5, 0x81, 0xa0, 0xff, 0xfa, 0x13, 0x63, 0xeb, + 0x17, 0x0d, 0xdd, 0x51, 0xb7, 0xf0, 0xda, 0x49, 0xd3, 0x16, 0x55, 0x26, 0x29, 0xd4, 0x68, 0x9e, + 0x2b, 0x16, 0xbe, 0x58, 0x7d, 0x47, 0xa1, 0xfc, 0x8f, 0xf8, 0xb8, 0xd1, 0x7a, 0xd0, 0x31, 0xce, + 0x45, 0xcb, 0x3a, 0x8f, 0x95, 0x16, 0x04, 0x28, 0xaf, 0xd7, 0xfb, 0xca, 0xbb, 0x4b, 0x40, 0x7e, +}; + + +#ifdef XXH_OLD_NAMES +# define kSecret XXH3_kSecret +#endif + +/* + * Calculates a 32-bit to 64-bit long multiply. + * + * Wraps __emulu on MSVC x86 because it tends to call __allmul when it doesn't + * need to (but it shouldn't need to anyways, it is about 7 instructions to do + * a 64x64 multiply...). Since we know that this will _always_ emit MULL, we + * use that instead of the normal method. + * + * If you are compiling for platforms like Thumb-1 and don't have a better option, + * you may also want to write your own long multiply routine here. + * + * XXH_FORCE_INLINE xxh_u64 XXH_mult32to64(xxh_u64 x, xxh_u64 y) + * { + * return (x & 0xFFFFFFFF) * (y & 0xFFFFFFFF); + * } + */ +#if defined(_MSC_VER) && defined(_M_IX86) +# include +# define XXH_mult32to64(x, y) __emulu((unsigned)(x), (unsigned)(y)) +#else +/* + * Downcast + upcast is usually better than masking on older compilers like + * GCC 4.2 (especially 32-bit ones), all without affecting newer compilers. + * + * The other method, (x & 0xFFFFFFFF) * (y & 0xFFFFFFFF), will AND both operands + * and perform a full 64x64 multiply -- entirely redundant on 32-bit. + */ +# define XXH_mult32to64(x, y) ((xxh_u64)(xxh_u32)(x) * (xxh_u64)(xxh_u32)(y)) +#endif + +/* + * Calculates a 64->128-bit long multiply. + * + * Uses __uint128_t and _umul128 if available, otherwise uses a scalar version. + */ +static XXH128_hash_t +XXH_mult64to128(xxh_u64 lhs, xxh_u64 rhs) +{ + /* + * GCC/Clang __uint128_t method. + * + * On most 64-bit targets, GCC and Clang define a __uint128_t type. + * This is usually the best way as it usually uses a native long 64-bit + * multiply, such as MULQ on x86_64 or MUL + UMULH on aarch64. + * + * Usually. + * + * Despite being a 32-bit platform, Clang (and emscripten) define this type + * despite not having the arithmetic for it. This results in a laggy + * compiler builtin call which calculates a full 128-bit multiply. + * In that case it is best to use the portable one. + * https://github.com/Cyan4973/xxHash/issues/211#issuecomment-515575677 + */ +#if defined(__GNUC__) && !defined(__wasm__) \ + && defined(__SIZEOF_INT128__) \ + || (defined(_INTEGRAL_MAX_BITS) && _INTEGRAL_MAX_BITS >= 128) + + __uint128_t const product = (__uint128_t)lhs * (__uint128_t)rhs; + XXH128_hash_t r128; + r128.low64 = (xxh_u64)(product); + r128.high64 = (xxh_u64)(product >> 64); + return r128; + + /* + * MSVC for x64's _umul128 method. + * + * xxh_u64 _umul128(xxh_u64 Multiplier, xxh_u64 Multiplicand, xxh_u64 *HighProduct); + * + * This compiles to single operand MUL on x64. + */ +#elif defined(_M_X64) || defined(_M_IA64) + +#ifndef _MSC_VER +# pragma intrinsic(_umul128) +#endif + xxh_u64 product_high; + xxh_u64 const product_low = _umul128(lhs, rhs, &product_high); + XXH128_hash_t r128; + r128.low64 = product_low; + r128.high64 = product_high; + return r128; + +#else + /* + * Portable scalar method. Optimized for 32-bit and 64-bit ALUs. + * + * This is a fast and simple grade school multiply, which is shown below + * with base 10 arithmetic instead of base 0x100000000. + * + * 9 3 // D2 lhs = 93 + * x 7 5 // D2 rhs = 75 + * ---------- + * 1 5 // D2 lo_lo = (93 % 10) * (75 % 10) = 15 + * 4 5 | // D2 hi_lo = (93 / 10) * (75 % 10) = 45 + * 2 1 | // D2 lo_hi = (93 % 10) * (75 / 10) = 21 + * + 6 3 | | // D2 hi_hi = (93 / 10) * (75 / 10) = 63 + * --------- + * 2 7 | // D2 cross = (15 / 10) + (45 % 10) + 21 = 27 + * + 6 7 | | // D2 upper = (27 / 10) + (45 / 10) + 63 = 67 + * --------- + * 6 9 7 5 // D4 res = (27 * 10) + (15 % 10) + (67 * 100) = 6975 + * + * The reasons for adding the products like this are: + * 1. It avoids manual carry tracking. Just like how + * (9 * 9) + 9 + 9 = 99, the same applies with this for UINT64_MAX. + * This avoids a lot of complexity. + * + * 2. It hints for, and on Clang, compiles to, the powerful UMAAL + * instruction available in ARM's Digital Signal Processing extension + * in 32-bit ARMv6 and later, which is shown below: + * + * void UMAAL(xxh_u32 *RdLo, xxh_u32 *RdHi, xxh_u32 Rn, xxh_u32 Rm) + * { + * xxh_u64 product = (xxh_u64)*RdLo * (xxh_u64)*RdHi + Rn + Rm; + * *RdLo = (xxh_u32)(product & 0xFFFFFFFF); + * *RdHi = (xxh_u32)(product >> 32); + * } + * + * This instruction was designed for efficient long multiplication, and + * allows this to be calculated in only 4 instructions at speeds + * comparable to some 64-bit ALUs. + * + * 3. It isn't terrible on other platforms. Usually this will be a couple + * of 32-bit ADD/ADCs. + */ + + /* First calculate all of the cross products. */ + xxh_u64 const lo_lo = XXH_mult32to64(lhs & 0xFFFFFFFF, rhs & 0xFFFFFFFF); + xxh_u64 const hi_lo = XXH_mult32to64(lhs >> 32, rhs & 0xFFFFFFFF); + xxh_u64 const lo_hi = XXH_mult32to64(lhs & 0xFFFFFFFF, rhs >> 32); + xxh_u64 const hi_hi = XXH_mult32to64(lhs >> 32, rhs >> 32); + + /* Now add the products together. These will never overflow. */ + xxh_u64 const cross = (lo_lo >> 32) + (hi_lo & 0xFFFFFFFF) + lo_hi; + xxh_u64 const upper = (hi_lo >> 32) + (cross >> 32) + hi_hi; + xxh_u64 const lower = (cross << 32) | (lo_lo & 0xFFFFFFFF); + + XXH128_hash_t r128; + r128.low64 = lower; + r128.high64 = upper; + return r128; +#endif +} + +/* + * Does a 64-bit to 128-bit multiply, then XOR folds it. + * + * The reason for the separate function is to prevent passing too many structs + * around by value. This will hopefully inline the multiply, but we don't force it. + */ +static xxh_u64 +XXH3_mul128_fold64(xxh_u64 lhs, xxh_u64 rhs) +{ + XXH128_hash_t product = XXH_mult64to128(lhs, rhs); + return product.low64 ^ product.high64; +} + +/* Seems to produce slightly better code on GCC for some reason. */ +XXH_FORCE_INLINE xxh_u64 XXH_xorshift64(xxh_u64 v64, int shift) +{ + XXH_ASSERT(0 <= shift && shift < 64); + return v64 ^ (v64 >> shift); +} + +/* + * This is a fast avalanche stage, + * suitable when input bits are already partially mixed + */ +static XXH64_hash_t XXH3_avalanche(xxh_u64 h64) +{ + h64 = XXH_xorshift64(h64, 37); + h64 *= 0x165667919E3779F9ULL; + h64 = XXH_xorshift64(h64, 32); + return h64; +} + +/* + * This is a stronger avalanche, + * inspired by Pelle Evensen's rrmxmx + * preferable when input has not been previously mixed + */ +static XXH64_hash_t XXH3_rrmxmx(xxh_u64 h64, xxh_u64 len) +{ + /* this mix is inspired by Pelle Evensen's rrmxmx */ + h64 ^= XXH_rotl64(h64, 49) ^ XXH_rotl64(h64, 24); + h64 *= 0x9FB21C651E98DF25ULL; + h64 ^= (h64 >> 35) + len ; + h64 *= 0x9FB21C651E98DF25ULL; + return XXH_xorshift64(h64, 28); +} + + +/* ========================================== + * Short keys + * ========================================== + * One of the shortcomings of XXH32 and XXH64 was that their performance was + * sub-optimal on short lengths. It used an iterative algorithm which strongly + * favored lengths that were a multiple of 4 or 8. + * + * Instead of iterating over individual inputs, we use a set of single shot + * functions which piece together a range of lengths and operate in constant time. + * + * Additionally, the number of multiplies has been significantly reduced. This + * reduces latency, especially when emulating 64-bit multiplies on 32-bit. + * + * Depending on the platform, this may or may not be faster than XXH32, but it + * is almost guaranteed to be faster than XXH64. + */ + +/* + * At very short lengths, there isn't enough input to fully hide secrets, or use + * the entire secret. + * + * There is also only a limited amount of mixing we can do before significantly + * impacting performance. + * + * Therefore, we use different sections of the secret and always mix two secret + * samples with an XOR. This should have no effect on performance on the + * seedless or withSeed variants because everything _should_ be constant folded + * by modern compilers. + * + * The XOR mixing hides individual parts of the secret and increases entropy. + * + * This adds an extra layer of strength for custom secrets. + */ +XXH_FORCE_INLINE XXH64_hash_t +XXH3_len_1to3_64b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed) +{ + XXH_ASSERT(input != NULL); + XXH_ASSERT(1 <= len && len <= 3); + XXH_ASSERT(secret != NULL); + /* + * len = 1: combined = { input[0], 0x01, input[0], input[0] } + * len = 2: combined = { input[1], 0x02, input[0], input[1] } + * len = 3: combined = { input[2], 0x03, input[0], input[1] } + */ + { xxh_u8 const c1 = input[0]; + xxh_u8 const c2 = input[len >> 1]; + xxh_u8 const c3 = input[len - 1]; + xxh_u32 const combined = ((xxh_u32)c1 << 16) | ((xxh_u32)c2 << 24) + | ((xxh_u32)c3 << 0) | ((xxh_u32)len << 8); + xxh_u64 const bitflip = (XXH_readLE32(secret) ^ XXH_readLE32(secret+4)) + seed; + xxh_u64 const keyed = (xxh_u64)combined ^ bitflip; + return XXH64_avalanche(keyed); + } +} + +XXH_FORCE_INLINE XXH64_hash_t +XXH3_len_4to8_64b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed) +{ + XXH_ASSERT(input != NULL); + XXH_ASSERT(secret != NULL); + XXH_ASSERT(4 <= len && len < 8); + seed ^= (xxh_u64)XXH_swap32((xxh_u32)seed) << 32; + { xxh_u32 const input1 = XXH_readLE32(input); + xxh_u32 const input2 = XXH_readLE32(input + len - 4); + xxh_u64 const bitflip = (XXH_readLE64(secret+8) ^ XXH_readLE64(secret+16)) - seed; + xxh_u64 const input64 = input2 + (((xxh_u64)input1) << 32); + xxh_u64 const keyed = input64 ^ bitflip; + return XXH3_rrmxmx(keyed, len); + } +} + +XXH_FORCE_INLINE XXH64_hash_t +XXH3_len_9to16_64b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed) +{ + XXH_ASSERT(input != NULL); + XXH_ASSERT(secret != NULL); + XXH_ASSERT(8 <= len && len <= 16); + { xxh_u64 const bitflip1 = (XXH_readLE64(secret+24) ^ XXH_readLE64(secret+32)) + seed; + xxh_u64 const bitflip2 = (XXH_readLE64(secret+40) ^ XXH_readLE64(secret+48)) - seed; + xxh_u64 const input_lo = XXH_readLE64(input) ^ bitflip1; + xxh_u64 const input_hi = XXH_readLE64(input + len - 8) ^ bitflip2; + xxh_u64 const acc = len + + XXH_swap64(input_lo) + input_hi + + XXH3_mul128_fold64(input_lo, input_hi); + return XXH3_avalanche(acc); + } +} + +XXH_FORCE_INLINE XXH64_hash_t +XXH3_len_0to16_64b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed) +{ + XXH_ASSERT(len <= 16); + { if (XXH_likely(len > 8)) return XXH3_len_9to16_64b(input, len, secret, seed); + if (XXH_likely(len >= 4)) return XXH3_len_4to8_64b(input, len, secret, seed); + if (len) return XXH3_len_1to3_64b(input, len, secret, seed); + return XXH64_avalanche(seed ^ (XXH_readLE64(secret+56) ^ XXH_readLE64(secret+64))); + } +} + +/* + * DISCLAIMER: There are known *seed-dependent* multicollisions here due to + * multiplication by zero, affecting hashes of lengths 17 to 240. + * + * However, they are very unlikely. + * + * Keep this in mind when using the unseeded XXH3_64bits() variant: As with all + * unseeded non-cryptographic hashes, it does not attempt to defend itself + * against specially crafted inputs, only random inputs. + * + * Compared to classic UMAC where a 1 in 2^31 chance of 4 consecutive bytes + * cancelling out the secret is taken an arbitrary number of times (addressed + * in XXH3_accumulate_512), this collision is very unlikely with random inputs + * and/or proper seeding: + * + * This only has a 1 in 2^63 chance of 8 consecutive bytes cancelling out, in a + * function that is only called up to 16 times per hash with up to 240 bytes of + * input. + * + * This is not too bad for a non-cryptographic hash function, especially with + * only 64 bit outputs. + * + * The 128-bit variant (which trades some speed for strength) is NOT affected + * by this, although it is always a good idea to use a proper seed if you care + * about strength. + */ +XXH_FORCE_INLINE xxh_u64 XXH3_mix16B(const xxh_u8* XXH_RESTRICT input, + const xxh_u8* XXH_RESTRICT secret, xxh_u64 seed64) +{ +#if defined(__GNUC__) && !defined(__clang__) /* GCC, not Clang */ \ + && defined(__i386__) && defined(__SSE2__) /* x86 + SSE2 */ \ + && !defined(XXH_ENABLE_AUTOVECTORIZE) /* Define to disable like XXH32 hack */ + /* + * UGLY HACK: + * GCC for x86 tends to autovectorize the 128-bit multiply, resulting in + * slower code. + * + * By forcing seed64 into a register, we disrupt the cost model and + * cause it to scalarize. See `XXH32_round()` + * + * FIXME: Clang's output is still _much_ faster -- On an AMD Ryzen 3600, + * XXH3_64bits @ len=240 runs at 4.6 GB/s with Clang 9, but 3.3 GB/s on + * GCC 9.2, despite both emitting scalar code. + * + * GCC generates much better scalar code than Clang for the rest of XXH3, + * which is why finding a more optimal codepath is an interest. + */ + __asm__ ("" : "+r" (seed64)); +#endif + { xxh_u64 const input_lo = XXH_readLE64(input); + xxh_u64 const input_hi = XXH_readLE64(input+8); + return XXH3_mul128_fold64( + input_lo ^ (XXH_readLE64(secret) + seed64), + input_hi ^ (XXH_readLE64(secret+8) - seed64) + ); + } +} + +/* For mid range keys, XXH3 uses a Mum-hash variant. */ +XXH_FORCE_INLINE XXH64_hash_t +XXH3_len_17to128_64b(const xxh_u8* XXH_RESTRICT input, size_t len, + const xxh_u8* XXH_RESTRICT secret, size_t secretSize, + XXH64_hash_t seed) +{ + XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN); (void)secretSize; + XXH_ASSERT(16 < len && len <= 128); + + { xxh_u64 acc = len * XXH_PRIME64_1; + if (len > 32) { + if (len > 64) { + if (len > 96) { + acc += XXH3_mix16B(input+48, secret+96, seed); + acc += XXH3_mix16B(input+len-64, secret+112, seed); + } + acc += XXH3_mix16B(input+32, secret+64, seed); + acc += XXH3_mix16B(input+len-48, secret+80, seed); + } + acc += XXH3_mix16B(input+16, secret+32, seed); + acc += XXH3_mix16B(input+len-32, secret+48, seed); + } + acc += XXH3_mix16B(input+0, secret+0, seed); + acc += XXH3_mix16B(input+len-16, secret+16, seed); + + return XXH3_avalanche(acc); + } +} + +#define XXH3_MIDSIZE_MAX 240 + +XXH_NO_INLINE XXH64_hash_t +XXH3_len_129to240_64b(const xxh_u8* XXH_RESTRICT input, size_t len, + const xxh_u8* XXH_RESTRICT secret, size_t secretSize, + XXH64_hash_t seed) +{ + XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN); (void)secretSize; + XXH_ASSERT(128 < len && len <= XXH3_MIDSIZE_MAX); + + #define XXH3_MIDSIZE_STARTOFFSET 3 + #define XXH3_MIDSIZE_LASTOFFSET 17 + + { xxh_u64 acc = len * XXH_PRIME64_1; + int const nbRounds = (int)len / 16; + int i; + for (i=0; i<8; i++) { + acc += XXH3_mix16B(input+(16*i), secret+(16*i), seed); + } + acc = XXH3_avalanche(acc); + XXH_ASSERT(nbRounds >= 8); +#if defined(__clang__) /* Clang */ \ + && (defined(__ARM_NEON) || defined(__ARM_NEON__)) /* NEON */ \ + && !defined(XXH_ENABLE_AUTOVECTORIZE) /* Define to disable */ + /* + * UGLY HACK: + * Clang for ARMv7-A tries to vectorize this loop, similar to GCC x86. + * In everywhere else, it uses scalar code. + * + * For 64->128-bit multiplies, even if the NEON was 100% optimal, it + * would still be slower than UMAAL (see XXH_mult64to128). + * + * Unfortunately, Clang doesn't handle the long multiplies properly and + * converts them to the nonexistent "vmulq_u64" intrinsic, which is then + * scalarized into an ugly mess of VMOV.32 instructions. + * + * This mess is difficult to avoid without turning autovectorization + * off completely, but they are usually relatively minor and/or not + * worth it to fix. + * + * This loop is the easiest to fix, as unlike XXH32, this pragma + * _actually works_ because it is a loop vectorization instead of an + * SLP vectorization. + */ + #pragma clang loop vectorize(disable) +#endif + for (i=8 ; i < nbRounds; i++) { + acc += XXH3_mix16B(input+(16*i), secret+(16*(i-8)) + XXH3_MIDSIZE_STARTOFFSET, seed); + } + /* last bytes */ + acc += XXH3_mix16B(input + len - 16, secret + XXH3_SECRET_SIZE_MIN - XXH3_MIDSIZE_LASTOFFSET, seed); + return XXH3_avalanche(acc); + } +} + + +/* ======= Long Keys ======= */ + +#define XXH_STRIPE_LEN 64 +#define XXH_SECRET_CONSUME_RATE 8 /* nb of secret bytes consumed at each accumulation */ +#define XXH_ACC_NB (XXH_STRIPE_LEN / sizeof(xxh_u64)) + +#ifdef XXH_OLD_NAMES +# define STRIPE_LEN XXH_STRIPE_LEN +# define ACC_NB XXH_ACC_NB +#endif + +XXH_FORCE_INLINE void XXH_writeLE64(void* dst, xxh_u64 v64) +{ + if (!XXH_CPU_LITTLE_ENDIAN) v64 = XXH_swap64(v64); + memcpy(dst, &v64, sizeof(v64)); +} + +/* Several intrinsic functions below are supposed to accept __int64 as argument, + * as documented in https://software.intel.com/sites/landingpage/IntrinsicsGuide/ . + * However, several environments do not define __int64 type, + * requiring a workaround. + */ +#if !defined (__VMS) \ + && (defined (__cplusplus) \ + || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */) ) + typedef int64_t xxh_i64; +#else + /* the following type must have a width of 64-bit */ + typedef long long xxh_i64; +#endif + +/* + * XXH3_accumulate_512 is the tightest loop for long inputs, and it is the most optimized. + * + * It is a hardened version of UMAC, based off of FARSH's implementation. + * + * This was chosen because it adapts quite well to 32-bit, 64-bit, and SIMD + * implementations, and it is ridiculously fast. + * + * We harden it by mixing the original input to the accumulators as well as the product. + * + * This means that in the (relatively likely) case of a multiply by zero, the + * original input is preserved. + * + * On 128-bit inputs, we swap 64-bit pairs when we add the input to improve + * cross-pollination, as otherwise the upper and lower halves would be + * essentially independent. + * + * This doesn't matter on 64-bit hashes since they all get merged together in + * the end, so we skip the extra step. + * + * Both XXH3_64bits and XXH3_128bits use this subroutine. + */ + +#if (XXH_VECTOR == XXH_AVX512) || defined(XXH_X86DISPATCH) + +#ifndef XXH_TARGET_AVX512 +# define XXH_TARGET_AVX512 /* disable attribute target */ +#endif + +XXH_FORCE_INLINE XXH_TARGET_AVX512 void +XXH3_accumulate_512_avx512(void* XXH_RESTRICT acc, + const void* XXH_RESTRICT input, + const void* XXH_RESTRICT secret) +{ + XXH_ALIGN(64) __m512i* const xacc = (__m512i *) acc; + XXH_ASSERT((((size_t)acc) & 63) == 0); + XXH_STATIC_ASSERT(XXH_STRIPE_LEN == sizeof(__m512i)); + + { + /* data_vec = input[0]; */ + __m512i const data_vec = _mm512_loadu_si512 (input); + /* key_vec = secret[0]; */ + __m512i const key_vec = _mm512_loadu_si512 (secret); + /* data_key = data_vec ^ key_vec; */ + __m512i const data_key = _mm512_xor_si512 (data_vec, key_vec); + /* data_key_lo = data_key >> 32; */ + __m512i const data_key_lo = _mm512_shuffle_epi32 (data_key, (_MM_PERM_ENUM)_MM_SHUFFLE(0, 3, 0, 1)); + /* product = (data_key & 0xffffffff) * (data_key_lo & 0xffffffff); */ + __m512i const product = _mm512_mul_epu32 (data_key, data_key_lo); + /* xacc[0] += swap(data_vec); */ + __m512i const data_swap = _mm512_shuffle_epi32(data_vec, (_MM_PERM_ENUM)_MM_SHUFFLE(1, 0, 3, 2)); + __m512i const sum = _mm512_add_epi64(*xacc, data_swap); + /* xacc[0] += product; */ + *xacc = _mm512_add_epi64(product, sum); + } +} + +/* + * XXH3_scrambleAcc: Scrambles the accumulators to improve mixing. + * + * Multiplication isn't perfect, as explained by Google in HighwayHash: + * + * // Multiplication mixes/scrambles bytes 0-7 of the 64-bit result to + * // varying degrees. In descending order of goodness, bytes + * // 3 4 2 5 1 6 0 7 have quality 228 224 164 160 100 96 36 32. + * // As expected, the upper and lower bytes are much worse. + * + * Source: https://github.com/google/highwayhash/blob/0aaf66b/highwayhash/hh_avx2.h#L291 + * + * Since our algorithm uses a pseudorandom secret to add some variance into the + * mix, we don't need to (or want to) mix as often or as much as HighwayHash does. + * + * This isn't as tight as XXH3_accumulate, but still written in SIMD to avoid + * extraction. + * + * Both XXH3_64bits and XXH3_128bits use this subroutine. + */ + +XXH_FORCE_INLINE XXH_TARGET_AVX512 void +XXH3_scrambleAcc_avx512(void* XXH_RESTRICT acc, const void* XXH_RESTRICT secret) +{ + XXH_ASSERT((((size_t)acc) & 63) == 0); + XXH_STATIC_ASSERT(XXH_STRIPE_LEN == sizeof(__m512i)); + { XXH_ALIGN(64) __m512i* const xacc = (__m512i*) acc; + const __m512i prime32 = _mm512_set1_epi32((int)XXH_PRIME32_1); + + /* xacc[0] ^= (xacc[0] >> 47) */ + __m512i const acc_vec = *xacc; + __m512i const shifted = _mm512_srli_epi64 (acc_vec, 47); + __m512i const data_vec = _mm512_xor_si512 (acc_vec, shifted); + /* xacc[0] ^= secret; */ + __m512i const key_vec = _mm512_loadu_si512 (secret); + __m512i const data_key = _mm512_xor_si512 (data_vec, key_vec); + + /* xacc[0] *= XXH_PRIME32_1; */ + __m512i const data_key_hi = _mm512_shuffle_epi32 (data_key, (_MM_PERM_ENUM)_MM_SHUFFLE(0, 3, 0, 1)); + __m512i const prod_lo = _mm512_mul_epu32 (data_key, prime32); + __m512i const prod_hi = _mm512_mul_epu32 (data_key_hi, prime32); + *xacc = _mm512_add_epi64(prod_lo, _mm512_slli_epi64(prod_hi, 32)); + } +} + +XXH_FORCE_INLINE XXH_TARGET_AVX512 void +XXH3_initCustomSecret_avx512(void* XXH_RESTRICT customSecret, xxh_u64 seed64) +{ + XXH_STATIC_ASSERT((XXH_SECRET_DEFAULT_SIZE & 63) == 0); + XXH_STATIC_ASSERT(XXH_SEC_ALIGN == 64); + XXH_ASSERT(((size_t)customSecret & 63) == 0); + (void)(&XXH_writeLE64); + { int const nbRounds = XXH_SECRET_DEFAULT_SIZE / sizeof(__m512i); + __m512i const seed = _mm512_mask_set1_epi64(_mm512_set1_epi64((xxh_i64)seed64), 0xAA, -(xxh_i64)seed64); + + XXH_ALIGN(64) const __m512i* const src = (const __m512i*) XXH3_kSecret; + XXH_ALIGN(64) __m512i* const dest = ( __m512i*) customSecret; + int i; + for (i=0; i < nbRounds; ++i) { + /* GCC has a bug, _mm512_stream_load_si512 accepts 'void*', not 'void const*', + * this will warn "discards ‘const’ qualifier". */ + union { + XXH_ALIGN(64) const __m512i* cp; + XXH_ALIGN(64) void* p; + } remote_const_void; + remote_const_void.cp = src + i; + dest[i] = _mm512_add_epi64(_mm512_stream_load_si512(remote_const_void.p), seed); + } } +} + +#endif + +#if (XXH_VECTOR == XXH_AVX2) || defined(XXH_X86DISPATCH) + +#ifndef XXH_TARGET_AVX2 +# define XXH_TARGET_AVX2 /* disable attribute target */ +#endif + +XXH_FORCE_INLINE XXH_TARGET_AVX2 void +XXH3_accumulate_512_avx2( void* XXH_RESTRICT acc, + const void* XXH_RESTRICT input, + const void* XXH_RESTRICT secret) +{ + XXH_ASSERT((((size_t)acc) & 31) == 0); + { XXH_ALIGN(32) __m256i* const xacc = (__m256i *) acc; + /* Unaligned. This is mainly for pointer arithmetic, and because + * _mm256_loadu_si256 requires a const __m256i * pointer for some reason. */ + const __m256i* const xinput = (const __m256i *) input; + /* Unaligned. This is mainly for pointer arithmetic, and because + * _mm256_loadu_si256 requires a const __m256i * pointer for some reason. */ + const __m256i* const xsecret = (const __m256i *) secret; + + size_t i; + for (i=0; i < XXH_STRIPE_LEN/sizeof(__m256i); i++) { + /* data_vec = xinput[i]; */ + __m256i const data_vec = _mm256_loadu_si256 (xinput+i); + /* key_vec = xsecret[i]; */ + __m256i const key_vec = _mm256_loadu_si256 (xsecret+i); + /* data_key = data_vec ^ key_vec; */ + __m256i const data_key = _mm256_xor_si256 (data_vec, key_vec); + /* data_key_lo = data_key >> 32; */ + __m256i const data_key_lo = _mm256_shuffle_epi32 (data_key, _MM_SHUFFLE(0, 3, 0, 1)); + /* product = (data_key & 0xffffffff) * (data_key_lo & 0xffffffff); */ + __m256i const product = _mm256_mul_epu32 (data_key, data_key_lo); + /* xacc[i] += swap(data_vec); */ + __m256i const data_swap = _mm256_shuffle_epi32(data_vec, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i const sum = _mm256_add_epi64(xacc[i], data_swap); + /* xacc[i] += product; */ + xacc[i] = _mm256_add_epi64(product, sum); + } } +} + +XXH_FORCE_INLINE XXH_TARGET_AVX2 void +XXH3_scrambleAcc_avx2(void* XXH_RESTRICT acc, const void* XXH_RESTRICT secret) +{ + XXH_ASSERT((((size_t)acc) & 31) == 0); + { XXH_ALIGN(32) __m256i* const xacc = (__m256i*) acc; + /* Unaligned. This is mainly for pointer arithmetic, and because + * _mm256_loadu_si256 requires a const __m256i * pointer for some reason. */ + const __m256i* const xsecret = (const __m256i *) secret; + const __m256i prime32 = _mm256_set1_epi32((int)XXH_PRIME32_1); + + size_t i; + for (i=0; i < XXH_STRIPE_LEN/sizeof(__m256i); i++) { + /* xacc[i] ^= (xacc[i] >> 47) */ + __m256i const acc_vec = xacc[i]; + __m256i const shifted = _mm256_srli_epi64 (acc_vec, 47); + __m256i const data_vec = _mm256_xor_si256 (acc_vec, shifted); + /* xacc[i] ^= xsecret; */ + __m256i const key_vec = _mm256_loadu_si256 (xsecret+i); + __m256i const data_key = _mm256_xor_si256 (data_vec, key_vec); + + /* xacc[i] *= XXH_PRIME32_1; */ + __m256i const data_key_hi = _mm256_shuffle_epi32 (data_key, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i const prod_lo = _mm256_mul_epu32 (data_key, prime32); + __m256i const prod_hi = _mm256_mul_epu32 (data_key_hi, prime32); + xacc[i] = _mm256_add_epi64(prod_lo, _mm256_slli_epi64(prod_hi, 32)); + } + } +} + +XXH_FORCE_INLINE XXH_TARGET_AVX2 void XXH3_initCustomSecret_avx2(void* XXH_RESTRICT customSecret, xxh_u64 seed64) +{ + XXH_STATIC_ASSERT((XXH_SECRET_DEFAULT_SIZE & 31) == 0); + XXH_STATIC_ASSERT((XXH_SECRET_DEFAULT_SIZE / sizeof(__m256i)) == 6); + XXH_STATIC_ASSERT(XXH_SEC_ALIGN <= 64); + (void)(&XXH_writeLE64); + XXH_PREFETCH(customSecret); + { __m256i const seed = _mm256_set_epi64x(-(xxh_i64)seed64, (xxh_i64)seed64, -(xxh_i64)seed64, (xxh_i64)seed64); + + XXH_ALIGN(64) const __m256i* const src = (const __m256i*) XXH3_kSecret; + XXH_ALIGN(64) __m256i* dest = ( __m256i*) customSecret; + +# if defined(__GNUC__) || defined(__clang__) + /* + * On GCC & Clang, marking 'dest' as modified will cause the compiler: + * - do not extract the secret from sse registers in the internal loop + * - use less common registers, and avoid pushing these reg into stack + * The asm hack causes Clang to assume that XXH3_kSecretPtr aliases with + * customSecret, and on aarch64, this prevented LDP from merging two + * loads together for free. Putting the loads together before the stores + * properly generates LDP. + */ + __asm__("" : "+r" (dest)); +# endif + + /* GCC -O2 need unroll loop manually */ + dest[0] = _mm256_add_epi64(_mm256_stream_load_si256(src+0), seed); + dest[1] = _mm256_add_epi64(_mm256_stream_load_si256(src+1), seed); + dest[2] = _mm256_add_epi64(_mm256_stream_load_si256(src+2), seed); + dest[3] = _mm256_add_epi64(_mm256_stream_load_si256(src+3), seed); + dest[4] = _mm256_add_epi64(_mm256_stream_load_si256(src+4), seed); + dest[5] = _mm256_add_epi64(_mm256_stream_load_si256(src+5), seed); + } +} + +#endif + +#if (XXH_VECTOR == XXH_SSE2) || defined(XXH_X86DISPATCH) + +#ifndef XXH_TARGET_SSE2 +# define XXH_TARGET_SSE2 /* disable attribute target */ +#endif + +XXH_FORCE_INLINE XXH_TARGET_SSE2 void +XXH3_accumulate_512_sse2( void* XXH_RESTRICT acc, + const void* XXH_RESTRICT input, + const void* XXH_RESTRICT secret) +{ + /* SSE2 is just a half-scale version of the AVX2 version. */ + XXH_ASSERT((((size_t)acc) & 15) == 0); + { XXH_ALIGN(16) __m128i* const xacc = (__m128i *) acc; + /* Unaligned. This is mainly for pointer arithmetic, and because + * _mm_loadu_si128 requires a const __m128i * pointer for some reason. */ + const __m128i* const xinput = (const __m128i *) input; + /* Unaligned. This is mainly for pointer arithmetic, and because + * _mm_loadu_si128 requires a const __m128i * pointer for some reason. */ + const __m128i* const xsecret = (const __m128i *) secret; + + size_t i; + for (i=0; i < XXH_STRIPE_LEN/sizeof(__m128i); i++) { + /* data_vec = xinput[i]; */ + __m128i const data_vec = _mm_loadu_si128 (xinput+i); + /* key_vec = xsecret[i]; */ + __m128i const key_vec = _mm_loadu_si128 (xsecret+i); + /* data_key = data_vec ^ key_vec; */ + __m128i const data_key = _mm_xor_si128 (data_vec, key_vec); + /* data_key_lo = data_key >> 32; */ + __m128i const data_key_lo = _mm_shuffle_epi32 (data_key, _MM_SHUFFLE(0, 3, 0, 1)); + /* product = (data_key & 0xffffffff) * (data_key_lo & 0xffffffff); */ + __m128i const product = _mm_mul_epu32 (data_key, data_key_lo); + /* xacc[i] += swap(data_vec); */ + __m128i const data_swap = _mm_shuffle_epi32(data_vec, _MM_SHUFFLE(1,0,3,2)); + __m128i const sum = _mm_add_epi64(xacc[i], data_swap); + /* xacc[i] += product; */ + xacc[i] = _mm_add_epi64(product, sum); + } } +} + +XXH_FORCE_INLINE XXH_TARGET_SSE2 void +XXH3_scrambleAcc_sse2(void* XXH_RESTRICT acc, const void* XXH_RESTRICT secret) +{ + XXH_ASSERT((((size_t)acc) & 15) == 0); + { XXH_ALIGN(16) __m128i* const xacc = (__m128i*) acc; + /* Unaligned. This is mainly for pointer arithmetic, and because + * _mm_loadu_si128 requires a const __m128i * pointer for some reason. */ + const __m128i* const xsecret = (const __m128i *) secret; + const __m128i prime32 = _mm_set1_epi32((int)XXH_PRIME32_1); + + size_t i; + for (i=0; i < XXH_STRIPE_LEN/sizeof(__m128i); i++) { + /* xacc[i] ^= (xacc[i] >> 47) */ + __m128i const acc_vec = xacc[i]; + __m128i const shifted = _mm_srli_epi64 (acc_vec, 47); + __m128i const data_vec = _mm_xor_si128 (acc_vec, shifted); + /* xacc[i] ^= xsecret[i]; */ + __m128i const key_vec = _mm_loadu_si128 (xsecret+i); + __m128i const data_key = _mm_xor_si128 (data_vec, key_vec); + + /* xacc[i] *= XXH_PRIME32_1; */ + __m128i const data_key_hi = _mm_shuffle_epi32 (data_key, _MM_SHUFFLE(0, 3, 0, 1)); + __m128i const prod_lo = _mm_mul_epu32 (data_key, prime32); + __m128i const prod_hi = _mm_mul_epu32 (data_key_hi, prime32); + xacc[i] = _mm_add_epi64(prod_lo, _mm_slli_epi64(prod_hi, 32)); + } + } +} + +XXH_FORCE_INLINE XXH_TARGET_SSE2 void XXH3_initCustomSecret_sse2(void* XXH_RESTRICT customSecret, xxh_u64 seed64) +{ + XXH_STATIC_ASSERT((XXH_SECRET_DEFAULT_SIZE & 15) == 0); + (void)(&XXH_writeLE64); + { int const nbRounds = XXH_SECRET_DEFAULT_SIZE / sizeof(__m128i); + +# if defined(_MSC_VER) && defined(_M_IX86) && _MSC_VER < 1900 + // MSVC 32bit mode does not support _mm_set_epi64x before 2015 + XXH_ALIGN(16) const xxh_i64 seed64x2[2] = { (xxh_i64)seed64, -(xxh_i64)seed64 }; + __m128i const seed = _mm_load_si128((__m128i const*)seed64x2); +# else + __m128i const seed = _mm_set_epi64x(-(xxh_i64)seed64, (xxh_i64)seed64); +# endif + int i; + + XXH_ALIGN(64) const float* const src = (float const*) XXH3_kSecret; + XXH_ALIGN(XXH_SEC_ALIGN) __m128i* dest = (__m128i*) customSecret; +# if defined(__GNUC__) || defined(__clang__) + /* + * On GCC & Clang, marking 'dest' as modified will cause the compiler: + * - do not extract the secret from sse registers in the internal loop + * - use less common registers, and avoid pushing these reg into stack + */ + __asm__("" : "+r" (dest)); +# endif + + for (i=0; i < nbRounds; ++i) { + dest[i] = _mm_add_epi64(_mm_castps_si128(_mm_load_ps(src+i*4)), seed); + } } +} + +#endif + +#if (XXH_VECTOR == XXH_NEON) + +XXH_FORCE_INLINE void +XXH3_accumulate_512_neon( void* XXH_RESTRICT acc, + const void* XXH_RESTRICT input, + const void* XXH_RESTRICT secret) +{ + XXH_ASSERT((((size_t)acc) & 15) == 0); + { + XXH_ALIGN(16) uint64x2_t* const xacc = (uint64x2_t *) acc; + /* We don't use a uint32x4_t pointer because it causes bus errors on ARMv7. */ + uint8_t const* const xinput = (const uint8_t *) input; + uint8_t const* const xsecret = (const uint8_t *) secret; + + size_t i; + for (i=0; i < XXH_STRIPE_LEN / sizeof(uint64x2_t); i++) { + /* data_vec = xinput[i]; */ + uint8x16_t data_vec = vld1q_u8(xinput + (i * 16)); + /* key_vec = xsecret[i]; */ + uint8x16_t key_vec = vld1q_u8(xsecret + (i * 16)); + uint64x2_t data_key; + uint32x2_t data_key_lo, data_key_hi; + /* xacc[i] += swap(data_vec); */ + uint64x2_t const data64 = vreinterpretq_u64_u8(data_vec); + uint64x2_t const swapped = vextq_u64(data64, data64, 1); + xacc[i] = vaddq_u64 (xacc[i], swapped); + /* data_key = data_vec ^ key_vec; */ + data_key = vreinterpretq_u64_u8(veorq_u8(data_vec, key_vec)); + /* data_key_lo = (uint32x2_t) (data_key & 0xFFFFFFFF); + * data_key_hi = (uint32x2_t) (data_key >> 32); + * data_key = UNDEFINED; */ + XXH_SPLIT_IN_PLACE(data_key, data_key_lo, data_key_hi); + /* xacc[i] += (uint64x2_t) data_key_lo * (uint64x2_t) data_key_hi; */ + xacc[i] = vmlal_u32 (xacc[i], data_key_lo, data_key_hi); + + } + } +} + +XXH_FORCE_INLINE void +XXH3_scrambleAcc_neon(void* XXH_RESTRICT acc, const void* XXH_RESTRICT secret) +{ + XXH_ASSERT((((size_t)acc) & 15) == 0); + + { uint64x2_t* xacc = (uint64x2_t*) acc; + uint8_t const* xsecret = (uint8_t const*) secret; + uint32x2_t prime = vdup_n_u32 (XXH_PRIME32_1); + + size_t i; + for (i=0; i < XXH_STRIPE_LEN/sizeof(uint64x2_t); i++) { + /* xacc[i] ^= (xacc[i] >> 47); */ + uint64x2_t acc_vec = xacc[i]; + uint64x2_t shifted = vshrq_n_u64 (acc_vec, 47); + uint64x2_t data_vec = veorq_u64 (acc_vec, shifted); + + /* xacc[i] ^= xsecret[i]; */ + uint8x16_t key_vec = vld1q_u8(xsecret + (i * 16)); + uint64x2_t data_key = veorq_u64(data_vec, vreinterpretq_u64_u8(key_vec)); + + /* xacc[i] *= XXH_PRIME32_1 */ + uint32x2_t data_key_lo, data_key_hi; + /* data_key_lo = (uint32x2_t) (xacc[i] & 0xFFFFFFFF); + * data_key_hi = (uint32x2_t) (xacc[i] >> 32); + * xacc[i] = UNDEFINED; */ + XXH_SPLIT_IN_PLACE(data_key, data_key_lo, data_key_hi); + { /* + * prod_hi = (data_key >> 32) * XXH_PRIME32_1; + * + * Avoid vmul_u32 + vshll_n_u32 since Clang 6 and 7 will + * incorrectly "optimize" this: + * tmp = vmul_u32(vmovn_u64(a), vmovn_u64(b)); + * shifted = vshll_n_u32(tmp, 32); + * to this: + * tmp = "vmulq_u64"(a, b); // no such thing! + * shifted = vshlq_n_u64(tmp, 32); + * + * However, unlike SSE, Clang lacks a 64-bit multiply routine + * for NEON, and it scalarizes two 64-bit multiplies instead. + * + * vmull_u32 has the same timing as vmul_u32, and it avoids + * this bug completely. + * See https://bugs.llvm.org/show_bug.cgi?id=39967 + */ + uint64x2_t prod_hi = vmull_u32 (data_key_hi, prime); + /* xacc[i] = prod_hi << 32; */ + xacc[i] = vshlq_n_u64(prod_hi, 32); + /* xacc[i] += (prod_hi & 0xFFFFFFFF) * XXH_PRIME32_1; */ + xacc[i] = vmlal_u32(xacc[i], data_key_lo, prime); + } + } } +} + +#endif + +#if (XXH_VECTOR == XXH_VSX) + +XXH_FORCE_INLINE void +XXH3_accumulate_512_vsx( void* XXH_RESTRICT acc, + const void* XXH_RESTRICT input, + const void* XXH_RESTRICT secret) +{ + xxh_u64x2* const xacc = (xxh_u64x2*) acc; /* presumed aligned */ + xxh_u64x2 const* const xinput = (xxh_u64x2 const*) input; /* no alignment restriction */ + xxh_u64x2 const* const xsecret = (xxh_u64x2 const*) secret; /* no alignment restriction */ + xxh_u64x2 const v32 = { 32, 32 }; + size_t i; + for (i = 0; i < XXH_STRIPE_LEN / sizeof(xxh_u64x2); i++) { + /* data_vec = xinput[i]; */ + xxh_u64x2 const data_vec = XXH_vec_loadu(xinput + i); + /* key_vec = xsecret[i]; */ + xxh_u64x2 const key_vec = XXH_vec_loadu(xsecret + i); + xxh_u64x2 const data_key = data_vec ^ key_vec; + /* shuffled = (data_key << 32) | (data_key >> 32); */ + xxh_u32x4 const shuffled = (xxh_u32x4)vec_rl(data_key, v32); + /* product = ((xxh_u64x2)data_key & 0xFFFFFFFF) * ((xxh_u64x2)shuffled & 0xFFFFFFFF); */ + xxh_u64x2 const product = XXH_vec_mulo((xxh_u32x4)data_key, shuffled); + xacc[i] += product; + + /* swap high and low halves */ +#ifdef __s390x__ + xxh_u64x2 const data_swapped = vec_permi(data_vec, data_vec, 2); +#else + xxh_u64x2 const data_swapped = vec_xxpermdi(data_vec, data_vec, 2); +#endif + xacc[i] += data_swapped; + } +} + +XXH_FORCE_INLINE void +XXH3_scrambleAcc_vsx(void* XXH_RESTRICT acc, const void* XXH_RESTRICT secret) +{ + XXH_ASSERT((((size_t)acc) & 15) == 0); + + { xxh_u64x2* const xacc = (xxh_u64x2*) acc; + const xxh_u64x2* const xsecret = (const xxh_u64x2*) secret; + /* constants */ + xxh_u64x2 const v32 = { 32, 32 }; + xxh_u64x2 const v47 = { 47, 47 }; + xxh_u32x4 const prime = { XXH_PRIME32_1, XXH_PRIME32_1, XXH_PRIME32_1, XXH_PRIME32_1 }; + size_t i; + for (i = 0; i < XXH_STRIPE_LEN / sizeof(xxh_u64x2); i++) { + /* xacc[i] ^= (xacc[i] >> 47); */ + xxh_u64x2 const acc_vec = xacc[i]; + xxh_u64x2 const data_vec = acc_vec ^ (acc_vec >> v47); + + /* xacc[i] ^= xsecret[i]; */ + xxh_u64x2 const key_vec = XXH_vec_loadu(xsecret + i); + xxh_u64x2 const data_key = data_vec ^ key_vec; + + /* xacc[i] *= XXH_PRIME32_1 */ + /* prod_lo = ((xxh_u64x2)data_key & 0xFFFFFFFF) * ((xxh_u64x2)prime & 0xFFFFFFFF); */ + xxh_u64x2 const prod_even = XXH_vec_mule((xxh_u32x4)data_key, prime); + /* prod_hi = ((xxh_u64x2)data_key >> 32) * ((xxh_u64x2)prime >> 32); */ + xxh_u64x2 const prod_odd = XXH_vec_mulo((xxh_u32x4)data_key, prime); + xacc[i] = prod_odd + (prod_even << v32); + } } +} + +#endif + +/* scalar variants - universal */ + +XXH_FORCE_INLINE void +XXH3_accumulate_512_scalar(void* XXH_RESTRICT acc, + const void* XXH_RESTRICT input, + const void* XXH_RESTRICT secret) +{ + XXH_ALIGN(XXH_ACC_ALIGN) xxh_u64* const xacc = (xxh_u64*) acc; /* presumed aligned */ + const xxh_u8* const xinput = (const xxh_u8*) input; /* no alignment restriction */ + const xxh_u8* const xsecret = (const xxh_u8*) secret; /* no alignment restriction */ + size_t i; + XXH_ASSERT(((size_t)acc & (XXH_ACC_ALIGN-1)) == 0); + for (i=0; i < XXH_ACC_NB; i++) { + xxh_u64 const data_val = XXH_readLE64(xinput + 8*i); + xxh_u64 const data_key = data_val ^ XXH_readLE64(xsecret + i*8); + xacc[i ^ 1] += data_val; /* swap adjacent lanes */ + xacc[i] += XXH_mult32to64(data_key & 0xFFFFFFFF, data_key >> 32); + } +} + +XXH_FORCE_INLINE void +XXH3_scrambleAcc_scalar(void* XXH_RESTRICT acc, const void* XXH_RESTRICT secret) +{ + XXH_ALIGN(XXH_ACC_ALIGN) xxh_u64* const xacc = (xxh_u64*) acc; /* presumed aligned */ + const xxh_u8* const xsecret = (const xxh_u8*) secret; /* no alignment restriction */ + size_t i; + XXH_ASSERT((((size_t)acc) & (XXH_ACC_ALIGN-1)) == 0); + for (i=0; i < XXH_ACC_NB; i++) { + xxh_u64 const key64 = XXH_readLE64(xsecret + 8*i); + xxh_u64 acc64 = xacc[i]; + acc64 = XXH_xorshift64(acc64, 47); + acc64 ^= key64; + acc64 *= XXH_PRIME32_1; + xacc[i] = acc64; + } +} + +XXH_FORCE_INLINE void +XXH3_initCustomSecret_scalar(void* XXH_RESTRICT customSecret, xxh_u64 seed64) +{ + /* + * We need a separate pointer for the hack below, + * which requires a non-const pointer. + * Any decent compiler will optimize this out otherwise. + */ + const xxh_u8* kSecretPtr = XXH3_kSecret; + XXH_STATIC_ASSERT((XXH_SECRET_DEFAULT_SIZE & 15) == 0); + +#if defined(__clang__) && defined(__aarch64__) + /* + * UGLY HACK: + * Clang generates a bunch of MOV/MOVK pairs for aarch64, and they are + * placed sequentially, in order, at the top of the unrolled loop. + * + * While MOVK is great for generating constants (2 cycles for a 64-bit + * constant compared to 4 cycles for LDR), long MOVK chains stall the + * integer pipelines: + * I L S + * MOVK + * MOVK + * MOVK + * MOVK + * ADD + * SUB STR + * STR + * By forcing loads from memory (as the asm line causes Clang to assume + * that XXH3_kSecretPtr has been changed), the pipelines are used more + * efficiently: + * I L S + * LDR + * ADD LDR + * SUB STR + * STR + * XXH3_64bits_withSeed, len == 256, Snapdragon 835 + * without hack: 2654.4 MB/s + * with hack: 3202.9 MB/s + */ + __asm__("" : "+r" (kSecretPtr)); +#endif + /* + * Note: in debug mode, this overrides the asm optimization + * and Clang will emit MOVK chains again. + */ + XXH_ASSERT(kSecretPtr == XXH3_kSecret); + + { int const nbRounds = XXH_SECRET_DEFAULT_SIZE / 16; + int i; + for (i=0; i < nbRounds; i++) { + /* + * The asm hack causes Clang to assume that kSecretPtr aliases with + * customSecret, and on aarch64, this prevented LDP from merging two + * loads together for free. Putting the loads together before the stores + * properly generates LDP. + */ + xxh_u64 lo = XXH_readLE64(kSecretPtr + 16*i) + seed64; + xxh_u64 hi = XXH_readLE64(kSecretPtr + 16*i + 8) - seed64; + XXH_writeLE64((xxh_u8*)customSecret + 16*i, lo); + XXH_writeLE64((xxh_u8*)customSecret + 16*i + 8, hi); + } } +} + + +typedef void (*XXH3_f_accumulate_512)(void* XXH_RESTRICT, const void*, const void*); +typedef void (*XXH3_f_scrambleAcc)(void* XXH_RESTRICT, const void*); +typedef void (*XXH3_f_initCustomSecret)(void* XXH_RESTRICT, xxh_u64); + + +#if (XXH_VECTOR == XXH_AVX512) + +#define XXH3_accumulate_512 XXH3_accumulate_512_avx512 +#define XXH3_scrambleAcc XXH3_scrambleAcc_avx512 +#define XXH3_initCustomSecret XXH3_initCustomSecret_avx512 + +#elif (XXH_VECTOR == XXH_AVX2) + +#define XXH3_accumulate_512 XXH3_accumulate_512_avx2 +#define XXH3_scrambleAcc XXH3_scrambleAcc_avx2 +#define XXH3_initCustomSecret XXH3_initCustomSecret_avx2 + +#elif (XXH_VECTOR == XXH_SSE2) + +#define XXH3_accumulate_512 XXH3_accumulate_512_sse2 +#define XXH3_scrambleAcc XXH3_scrambleAcc_sse2 +#define XXH3_initCustomSecret XXH3_initCustomSecret_sse2 + +#elif (XXH_VECTOR == XXH_NEON) + +#define XXH3_accumulate_512 XXH3_accumulate_512_neon +#define XXH3_scrambleAcc XXH3_scrambleAcc_neon +#define XXH3_initCustomSecret XXH3_initCustomSecret_scalar + +#elif (XXH_VECTOR == XXH_VSX) + +#define XXH3_accumulate_512 XXH3_accumulate_512_vsx +#define XXH3_scrambleAcc XXH3_scrambleAcc_vsx +#define XXH3_initCustomSecret XXH3_initCustomSecret_scalar + +#else /* scalar */ + +#define XXH3_accumulate_512 XXH3_accumulate_512_scalar +#define XXH3_scrambleAcc XXH3_scrambleAcc_scalar +#define XXH3_initCustomSecret XXH3_initCustomSecret_scalar + +#endif + + + +#ifndef XXH_PREFETCH_DIST +# ifdef __clang__ +# define XXH_PREFETCH_DIST 320 +# else +# if (XXH_VECTOR == XXH_AVX512) +# define XXH_PREFETCH_DIST 512 +# else +# define XXH_PREFETCH_DIST 384 +# endif +# endif /* __clang__ */ +#endif /* XXH_PREFETCH_DIST */ + +/* + * XXH3_accumulate() + * Loops over XXH3_accumulate_512(). + * Assumption: nbStripes will not overflow the secret size + */ +XXH_FORCE_INLINE void +XXH3_accumulate( xxh_u64* XXH_RESTRICT acc, + const xxh_u8* XXH_RESTRICT input, + const xxh_u8* XXH_RESTRICT secret, + size_t nbStripes, + XXH3_f_accumulate_512 f_acc512) +{ + size_t n; + for (n = 0; n < nbStripes; n++ ) { + const xxh_u8* const in = input + n*XXH_STRIPE_LEN; + XXH_PREFETCH(in + XXH_PREFETCH_DIST); + f_acc512(acc, + in, + secret + n*XXH_SECRET_CONSUME_RATE); + } +} + +XXH_FORCE_INLINE void +XXH3_hashLong_internal_loop(xxh_u64* XXH_RESTRICT acc, + const xxh_u8* XXH_RESTRICT input, size_t len, + const xxh_u8* XXH_RESTRICT secret, size_t secretSize, + XXH3_f_accumulate_512 f_acc512, + XXH3_f_scrambleAcc f_scramble) +{ + size_t const nbStripesPerBlock = (secretSize - XXH_STRIPE_LEN) / XXH_SECRET_CONSUME_RATE; + size_t const block_len = XXH_STRIPE_LEN * nbStripesPerBlock; + size_t const nb_blocks = (len - 1) / block_len; + + size_t n; + + XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN); + + for (n = 0; n < nb_blocks; n++) { + XXH3_accumulate(acc, input + n*block_len, secret, nbStripesPerBlock, f_acc512); + f_scramble(acc, secret + secretSize - XXH_STRIPE_LEN); + } + + /* last partial block */ + XXH_ASSERT(len > XXH_STRIPE_LEN); + { size_t const nbStripes = ((len - 1) - (block_len * nb_blocks)) / XXH_STRIPE_LEN; + XXH_ASSERT(nbStripes <= (secretSize / XXH_SECRET_CONSUME_RATE)); + XXH3_accumulate(acc, input + nb_blocks*block_len, secret, nbStripes, f_acc512); + + /* last stripe */ + { const xxh_u8* const p = input + len - XXH_STRIPE_LEN; +#define XXH_SECRET_LASTACC_START 7 /* not aligned on 8, last secret is different from acc & scrambler */ + f_acc512(acc, p, secret + secretSize - XXH_STRIPE_LEN - XXH_SECRET_LASTACC_START); + } } +} + +XXH_FORCE_INLINE xxh_u64 +XXH3_mix2Accs(const xxh_u64* XXH_RESTRICT acc, const xxh_u8* XXH_RESTRICT secret) +{ + return XXH3_mul128_fold64( + acc[0] ^ XXH_readLE64(secret), + acc[1] ^ XXH_readLE64(secret+8) ); +} + +static XXH64_hash_t +XXH3_mergeAccs(const xxh_u64* XXH_RESTRICT acc, const xxh_u8* XXH_RESTRICT secret, xxh_u64 start) +{ + xxh_u64 result64 = start; + size_t i = 0; + + for (i = 0; i < 4; i++) { + result64 += XXH3_mix2Accs(acc+2*i, secret + 16*i); +#if defined(__clang__) /* Clang */ \ + && (defined(__arm__) || defined(__thumb__)) /* ARMv7 */ \ + && (defined(__ARM_NEON) || defined(__ARM_NEON__)) /* NEON */ \ + && !defined(XXH_ENABLE_AUTOVECTORIZE) /* Define to disable */ + /* + * UGLY HACK: + * Prevent autovectorization on Clang ARMv7-a. Exact same problem as + * the one in XXH3_len_129to240_64b. Speeds up shorter keys > 240b. + * XXH3_64bits, len == 256, Snapdragon 835: + * without hack: 2063.7 MB/s + * with hack: 2560.7 MB/s + */ + __asm__("" : "+r" (result64)); +#endif + } + + return XXH3_avalanche(result64); +} + +#define XXH3_INIT_ACC { XXH_PRIME32_3, XXH_PRIME64_1, XXH_PRIME64_2, XXH_PRIME64_3, \ + XXH_PRIME64_4, XXH_PRIME32_2, XXH_PRIME64_5, XXH_PRIME32_1 } + +XXH_FORCE_INLINE XXH64_hash_t +XXH3_hashLong_64b_internal(const void* XXH_RESTRICT input, size_t len, + const void* XXH_RESTRICT secret, size_t secretSize, + XXH3_f_accumulate_512 f_acc512, + XXH3_f_scrambleAcc f_scramble) +{ + XXH_ALIGN(XXH_ACC_ALIGN) xxh_u64 acc[XXH_ACC_NB] = XXH3_INIT_ACC; + + XXH3_hashLong_internal_loop(acc, (const xxh_u8*)input, len, (const xxh_u8*)secret, secretSize, f_acc512, f_scramble); + + /* converge into final hash */ + XXH_STATIC_ASSERT(sizeof(acc) == 64); + /* do not align on 8, so that the secret is different from the accumulator */ +#define XXH_SECRET_MERGEACCS_START 11 + XXH_ASSERT(secretSize >= sizeof(acc) + XXH_SECRET_MERGEACCS_START); + return XXH3_mergeAccs(acc, (const xxh_u8*)secret + XXH_SECRET_MERGEACCS_START, (xxh_u64)len * XXH_PRIME64_1); +} + +/* + * It's important for performance that XXH3_hashLong is not inlined. + */ +XXH_NO_INLINE XXH64_hash_t +XXH3_hashLong_64b_withSecret(const void* XXH_RESTRICT input, size_t len, + XXH64_hash_t seed64, const xxh_u8* XXH_RESTRICT secret, size_t secretLen) +{ + (void)seed64; + return XXH3_hashLong_64b_internal(input, len, secret, secretLen, XXH3_accumulate_512, XXH3_scrambleAcc); +} + +/* + * It's important for performance that XXH3_hashLong is not inlined. + * Since the function is not inlined, the compiler may not be able to understand that, + * in some scenarios, its `secret` argument is actually a compile time constant. + * This variant enforces that the compiler can detect that, + * and uses this opportunity to streamline the generated code for better performance. + */ +XXH_NO_INLINE XXH64_hash_t +XXH3_hashLong_64b_default(const void* XXH_RESTRICT input, size_t len, + XXH64_hash_t seed64, const xxh_u8* XXH_RESTRICT secret, size_t secretLen) +{ + (void)seed64; (void)secret; (void)secretLen; + return XXH3_hashLong_64b_internal(input, len, XXH3_kSecret, sizeof(XXH3_kSecret), XXH3_accumulate_512, XXH3_scrambleAcc); +} + +/* + * XXH3_hashLong_64b_withSeed(): + * Generate a custom key based on alteration of default XXH3_kSecret with the seed, + * and then use this key for long mode hashing. + * + * This operation is decently fast but nonetheless costs a little bit of time. + * Try to avoid it whenever possible (typically when seed==0). + * + * It's important for performance that XXH3_hashLong is not inlined. Not sure + * why (uop cache maybe?), but the difference is large and easily measurable. + */ +XXH_FORCE_INLINE XXH64_hash_t +XXH3_hashLong_64b_withSeed_internal(const void* input, size_t len, + XXH64_hash_t seed, + XXH3_f_accumulate_512 f_acc512, + XXH3_f_scrambleAcc f_scramble, + XXH3_f_initCustomSecret f_initSec) +{ + if (seed == 0) + return XXH3_hashLong_64b_internal(input, len, + XXH3_kSecret, sizeof(XXH3_kSecret), + f_acc512, f_scramble); + { XXH_ALIGN(XXH_SEC_ALIGN) xxh_u8 secret[XXH_SECRET_DEFAULT_SIZE]; + f_initSec(secret, seed); + return XXH3_hashLong_64b_internal(input, len, secret, sizeof(secret), + f_acc512, f_scramble); + } +} + +/* + * It's important for performance that XXH3_hashLong is not inlined. + */ +XXH_NO_INLINE XXH64_hash_t +XXH3_hashLong_64b_withSeed(const void* input, size_t len, + XXH64_hash_t seed, const xxh_u8* secret, size_t secretLen) +{ + (void)secret; (void)secretLen; + return XXH3_hashLong_64b_withSeed_internal(input, len, seed, + XXH3_accumulate_512, XXH3_scrambleAcc, XXH3_initCustomSecret); +} + + +typedef XXH64_hash_t (*XXH3_hashLong64_f)(const void* XXH_RESTRICT, size_t, + XXH64_hash_t, const xxh_u8* XXH_RESTRICT, size_t); + +XXH_FORCE_INLINE XXH64_hash_t +XXH3_64bits_internal(const void* XXH_RESTRICT input, size_t len, + XXH64_hash_t seed64, const void* XXH_RESTRICT secret, size_t secretLen, + XXH3_hashLong64_f f_hashLong) +{ + XXH_ASSERT(secretLen >= XXH3_SECRET_SIZE_MIN); + /* + * If an action is to be taken if `secretLen` condition is not respected, + * it should be done here. + * For now, it's a contract pre-condition. + * Adding a check and a branch here would cost performance at every hash. + * Also, note that function signature doesn't offer room to return an error. + */ + if (len <= 16) + return XXH3_len_0to16_64b((const xxh_u8*)input, len, (const xxh_u8*)secret, seed64); + if (len <= 128) + return XXH3_len_17to128_64b((const xxh_u8*)input, len, (const xxh_u8*)secret, secretLen, seed64); + if (len <= XXH3_MIDSIZE_MAX) + return XXH3_len_129to240_64b((const xxh_u8*)input, len, (const xxh_u8*)secret, secretLen, seed64); + return f_hashLong(input, len, seed64, (const xxh_u8*)secret, secretLen); +} + + +/* === Public entry point === */ + +XXH_PUBLIC_API XXH64_hash_t XXH3_64bits(const void* input, size_t len) +{ + return XXH3_64bits_internal(input, len, 0, XXH3_kSecret, sizeof(XXH3_kSecret), XXH3_hashLong_64b_default); +} + +XXH_PUBLIC_API XXH64_hash_t +XXH3_64bits_withSecret(const void* input, size_t len, const void* secret, size_t secretSize) +{ + return XXH3_64bits_internal(input, len, 0, secret, secretSize, XXH3_hashLong_64b_withSecret); +} + +XXH_PUBLIC_API XXH64_hash_t +XXH3_64bits_withSeed(const void* input, size_t len, XXH64_hash_t seed) +{ + return XXH3_64bits_internal(input, len, seed, XXH3_kSecret, sizeof(XXH3_kSecret), XXH3_hashLong_64b_withSeed); +} + + +/* === XXH3 streaming === */ + +/* + * Malloc's a pointer that is always aligned to align. + * + * This must be freed with `XXH_alignedFree()`. + * + * malloc typically guarantees 16 byte alignment on 64-bit systems and 8 byte + * alignment on 32-bit. This isn't enough for the 32 byte aligned loads in AVX2 + * or on 32-bit, the 16 byte aligned loads in SSE2 and NEON. + * + * This underalignment previously caused a rather obvious crash which went + * completely unnoticed due to XXH3_createState() not actually being tested. + * Credit to RedSpah for noticing this bug. + * + * The alignment is done manually: Functions like posix_memalign or _mm_malloc + * are avoided: To maintain portability, we would have to write a fallback + * like this anyways, and besides, testing for the existence of library + * functions without relying on external build tools is impossible. + * + * The method is simple: Overallocate, manually align, and store the offset + * to the original behind the returned pointer. + * + * Align must be a power of 2 and 8 <= align <= 128. + */ +static void* XXH_alignedMalloc(size_t s, size_t align) +{ + XXH_ASSERT(align <= 128 && align >= 8); /* range check */ + XXH_ASSERT((align & (align-1)) == 0); /* power of 2 */ + XXH_ASSERT(s != 0 && s < (s + align)); /* empty/overflow */ + { /* Overallocate to make room for manual realignment and an offset byte */ + xxh_u8* base = (xxh_u8*)XXH_malloc(s + align); + if (base != NULL) { + /* + * Get the offset needed to align this pointer. + * + * Even if the returned pointer is aligned, there will always be + * at least one byte to store the offset to the original pointer. + */ + size_t offset = align - ((size_t)base & (align - 1)); /* base % align */ + /* Add the offset for the now-aligned pointer */ + xxh_u8* ptr = base + offset; + + XXH_ASSERT((size_t)ptr % align == 0); + + /* Store the offset immediately before the returned pointer. */ + ptr[-1] = (xxh_u8)offset; + return ptr; + } + return NULL; + } +} +/* + * Frees an aligned pointer allocated by XXH_alignedMalloc(). Don't pass + * normal malloc'd pointers, XXH_alignedMalloc has a specific data layout. + */ +static void XXH_alignedFree(void* p) +{ + if (p != NULL) { + xxh_u8* ptr = (xxh_u8*)p; + /* Get the offset byte we added in XXH_malloc. */ + xxh_u8 offset = ptr[-1]; + /* Free the original malloc'd pointer */ + xxh_u8* base = ptr - offset; + XXH_free(base); + } +} +XXH_PUBLIC_API XXH3_state_t* XXH3_createState(void) +{ + return (XXH3_state_t*)XXH_alignedMalloc(sizeof(XXH3_state_t), 64); +} + +XXH_PUBLIC_API XXH_errorcode XXH3_freeState(XXH3_state_t* statePtr) +{ + XXH_alignedFree(statePtr); + return XXH_OK; +} + +XXH_PUBLIC_API void +XXH3_copyState(XXH3_state_t* dst_state, const XXH3_state_t* src_state) +{ + memcpy(dst_state, src_state, sizeof(*dst_state)); +} + +static void +XXH3_64bits_reset_internal(XXH3_state_t* statePtr, + XXH64_hash_t seed, + const xxh_u8* secret, size_t secretSize) +{ + XXH_ASSERT(statePtr != NULL); + memset(statePtr, 0, sizeof(*statePtr)); + statePtr->acc[0] = XXH_PRIME32_3; + statePtr->acc[1] = XXH_PRIME64_1; + statePtr->acc[2] = XXH_PRIME64_2; + statePtr->acc[3] = XXH_PRIME64_3; + statePtr->acc[4] = XXH_PRIME64_4; + statePtr->acc[5] = XXH_PRIME32_2; + statePtr->acc[6] = XXH_PRIME64_5; + statePtr->acc[7] = XXH_PRIME32_1; + statePtr->seed = seed; + XXH_ASSERT(secret != NULL); + statePtr->extSecret = secret; + XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN); + statePtr->secretLimit = secretSize - XXH_STRIPE_LEN; + statePtr->nbStripesPerBlock = statePtr->secretLimit / XXH_SECRET_CONSUME_RATE; +} + +XXH_PUBLIC_API XXH_errorcode +XXH3_64bits_reset(XXH3_state_t* statePtr) +{ + if (statePtr == NULL) return XXH_ERROR; + XXH3_64bits_reset_internal(statePtr, 0, XXH3_kSecret, XXH_SECRET_DEFAULT_SIZE); + return XXH_OK; +} + +XXH_PUBLIC_API XXH_errorcode +XXH3_64bits_reset_withSecret(XXH3_state_t* statePtr, const void* secret, size_t secretSize) +{ + if (statePtr == NULL) return XXH_ERROR; + XXH3_64bits_reset_internal(statePtr, 0, (const xxh_u8*)secret, secretSize); + if (secret == NULL) return XXH_ERROR; + if (secretSize < XXH3_SECRET_SIZE_MIN) return XXH_ERROR; + return XXH_OK; +} + +XXH_PUBLIC_API XXH_errorcode +XXH3_64bits_reset_withSeed(XXH3_state_t* statePtr, XXH64_hash_t seed) +{ + if (statePtr == NULL) return XXH_ERROR; + XXH3_64bits_reset_internal(statePtr, seed, XXH3_kSecret, XXH_SECRET_DEFAULT_SIZE); + XXH3_initCustomSecret(statePtr->customSecret, seed); + statePtr->extSecret = NULL; + return XXH_OK; +} + +/* Note : when XXH3_consumeStripes() is invoked, + * there must be a guarantee that at least one more byte must be consumed from input + * so that the function can blindly consume all stripes using the "normal" secret segment */ +XXH_FORCE_INLINE void +XXH3_consumeStripes(xxh_u64* XXH_RESTRICT acc, + size_t* XXH_RESTRICT nbStripesSoFarPtr, size_t nbStripesPerBlock, + const xxh_u8* XXH_RESTRICT input, size_t nbStripes, + const xxh_u8* XXH_RESTRICT secret, size_t secretLimit, + XXH3_f_accumulate_512 f_acc512, + XXH3_f_scrambleAcc f_scramble) +{ + XXH_ASSERT(nbStripes <= nbStripesPerBlock); /* can handle max 1 scramble per invocation */ + XXH_ASSERT(*nbStripesSoFarPtr < nbStripesPerBlock); + if (nbStripesPerBlock - *nbStripesSoFarPtr <= nbStripes) { + /* need a scrambling operation */ + size_t const nbStripesToEndofBlock = nbStripesPerBlock - *nbStripesSoFarPtr; + size_t const nbStripesAfterBlock = nbStripes - nbStripesToEndofBlock; + XXH3_accumulate(acc, input, secret + nbStripesSoFarPtr[0] * XXH_SECRET_CONSUME_RATE, nbStripesToEndofBlock, f_acc512); + f_scramble(acc, secret + secretLimit); + XXH3_accumulate(acc, input + nbStripesToEndofBlock * XXH_STRIPE_LEN, secret, nbStripesAfterBlock, f_acc512); + *nbStripesSoFarPtr = nbStripesAfterBlock; + } else { + XXH3_accumulate(acc, input, secret + nbStripesSoFarPtr[0] * XXH_SECRET_CONSUME_RATE, nbStripes, f_acc512); + *nbStripesSoFarPtr += nbStripes; + } +} + +/* + * Both XXH3_64bits_update and XXH3_128bits_update use this routine. + */ +XXH_FORCE_INLINE XXH_errorcode +XXH3_update(XXH3_state_t* state, + const xxh_u8* input, size_t len, + XXH3_f_accumulate_512 f_acc512, + XXH3_f_scrambleAcc f_scramble) +{ + if (input==NULL) +#if defined(XXH_ACCEPT_NULL_INPUT_POINTER) && (XXH_ACCEPT_NULL_INPUT_POINTER>=1) + return XXH_OK; +#else + return XXH_ERROR; +#endif + + { const xxh_u8* const bEnd = input + len; + const unsigned char* const secret = (state->extSecret == NULL) ? state->customSecret : state->extSecret; + + state->totalLen += len; + + if (state->bufferedSize + len <= XXH3_INTERNALBUFFER_SIZE) { /* fill in tmp buffer */ + XXH_memcpy(state->buffer + state->bufferedSize, input, len); + state->bufferedSize += (XXH32_hash_t)len; + return XXH_OK; + } + /* total input is now > XXH3_INTERNALBUFFER_SIZE */ + + #define XXH3_INTERNALBUFFER_STRIPES (XXH3_INTERNALBUFFER_SIZE / XXH_STRIPE_LEN) + XXH_STATIC_ASSERT(XXH3_INTERNALBUFFER_SIZE % XXH_STRIPE_LEN == 0); /* clean multiple */ + + /* + * Internal buffer is partially filled (always, except at beginning) + * Complete it, then consume it. + */ + if (state->bufferedSize) { + size_t const loadSize = XXH3_INTERNALBUFFER_SIZE - state->bufferedSize; + XXH_memcpy(state->buffer + state->bufferedSize, input, loadSize); + input += loadSize; + XXH3_consumeStripes(state->acc, + &state->nbStripesSoFar, state->nbStripesPerBlock, + state->buffer, XXH3_INTERNALBUFFER_STRIPES, + secret, state->secretLimit, + f_acc512, f_scramble); + state->bufferedSize = 0; + } + XXH_ASSERT(input < bEnd); + + /* Consume input by a multiple of internal buffer size */ + if (input+XXH3_INTERNALBUFFER_SIZE < bEnd) { + const xxh_u8* const limit = bEnd - XXH3_INTERNALBUFFER_SIZE; + do { + XXH3_consumeStripes(state->acc, + &state->nbStripesSoFar, state->nbStripesPerBlock, + input, XXH3_INTERNALBUFFER_STRIPES, + secret, state->secretLimit, + f_acc512, f_scramble); + input += XXH3_INTERNALBUFFER_SIZE; + } while (inputbuffer + sizeof(state->buffer) - XXH_STRIPE_LEN, input - XXH_STRIPE_LEN, XXH_STRIPE_LEN); + } + XXH_ASSERT(input < bEnd); + + /* Some remaining input (always) : buffer it */ + XXH_memcpy(state->buffer, input, (size_t)(bEnd-input)); + state->bufferedSize = (XXH32_hash_t)(bEnd-input); + } + + return XXH_OK; +} + +XXH_PUBLIC_API XXH_errorcode +XXH3_64bits_update(XXH3_state_t* state, const void* input, size_t len) +{ + return XXH3_update(state, (const xxh_u8*)input, len, + XXH3_accumulate_512, XXH3_scrambleAcc); +} + + +XXH_FORCE_INLINE void +XXH3_digest_long (XXH64_hash_t* acc, + const XXH3_state_t* state, + const unsigned char* secret) +{ + /* + * Digest on a local copy. This way, the state remains unaltered, and it can + * continue ingesting more input afterwards. + */ + memcpy(acc, state->acc, sizeof(state->acc)); + if (state->bufferedSize >= XXH_STRIPE_LEN) { + size_t const nbStripes = (state->bufferedSize - 1) / XXH_STRIPE_LEN; + size_t nbStripesSoFar = state->nbStripesSoFar; + XXH3_consumeStripes(acc, + &nbStripesSoFar, state->nbStripesPerBlock, + state->buffer, nbStripes, + secret, state->secretLimit, + XXH3_accumulate_512, XXH3_scrambleAcc); + /* last stripe */ + XXH3_accumulate_512(acc, + state->buffer + state->bufferedSize - XXH_STRIPE_LEN, + secret + state->secretLimit - XXH_SECRET_LASTACC_START); + } else { /* bufferedSize < XXH_STRIPE_LEN */ + xxh_u8 lastStripe[XXH_STRIPE_LEN]; + size_t const catchupSize = XXH_STRIPE_LEN - state->bufferedSize; + XXH_ASSERT(state->bufferedSize > 0); /* there is always some input buffered */ + memcpy(lastStripe, state->buffer + sizeof(state->buffer) - catchupSize, catchupSize); + memcpy(lastStripe + catchupSize, state->buffer, state->bufferedSize); + XXH3_accumulate_512(acc, + lastStripe, + secret + state->secretLimit - XXH_SECRET_LASTACC_START); + } +} + +XXH_PUBLIC_API XXH64_hash_t XXH3_64bits_digest (const XXH3_state_t* state) +{ + const unsigned char* const secret = (state->extSecret == NULL) ? state->customSecret : state->extSecret; + if (state->totalLen > XXH3_MIDSIZE_MAX) { + XXH_ALIGN(XXH_ACC_ALIGN) XXH64_hash_t acc[XXH_ACC_NB]; + XXH3_digest_long(acc, state, secret); + return XXH3_mergeAccs(acc, + secret + XXH_SECRET_MERGEACCS_START, + (xxh_u64)state->totalLen * XXH_PRIME64_1); + } + /* totalLen <= XXH3_MIDSIZE_MAX: digesting a short input */ + if (state->seed) + return XXH3_64bits_withSeed(state->buffer, (size_t)state->totalLen, state->seed); + return XXH3_64bits_withSecret(state->buffer, (size_t)(state->totalLen), + secret, state->secretLimit + XXH_STRIPE_LEN); +} + + +#define XXH_MIN(x, y) (((x) > (y)) ? (y) : (x)) + +XXH_PUBLIC_API void +XXH3_generateSecret(void* secretBuffer, const void* customSeed, size_t customSeedSize) +{ + XXH_ASSERT(secretBuffer != NULL); + if (customSeedSize == 0) { + memcpy(secretBuffer, XXH3_kSecret, XXH_SECRET_DEFAULT_SIZE); + return; + } + XXH_ASSERT(customSeed != NULL); + + { size_t const segmentSize = sizeof(XXH128_hash_t); + size_t const nbSegments = XXH_SECRET_DEFAULT_SIZE / segmentSize; + XXH128_canonical_t scrambler; + XXH64_hash_t seeds[12]; + size_t segnb; + XXH_ASSERT(nbSegments == 12); + XXH_ASSERT(segmentSize * nbSegments == XXH_SECRET_DEFAULT_SIZE); /* exact multiple */ + XXH128_canonicalFromHash(&scrambler, XXH128(customSeed, customSeedSize, 0)); + + /* + * Copy customSeed to seeds[], truncating or repeating as necessary. + */ + { size_t toFill = XXH_MIN(customSeedSize, sizeof(seeds)); + size_t filled = toFill; + memcpy(seeds, customSeed, toFill); + while (filled < sizeof(seeds)) { + toFill = XXH_MIN(filled, sizeof(seeds) - filled); + memcpy((char*)seeds + filled, seeds, toFill); + filled += toFill; + } } + + /* generate secret */ + memcpy(secretBuffer, &scrambler, sizeof(scrambler)); + for (segnb=1; segnb < nbSegments; segnb++) { + size_t const segmentStart = segnb * segmentSize; + XXH128_canonical_t segment; + XXH128_canonicalFromHash(&segment, + XXH128(&scrambler, sizeof(scrambler), XXH_readLE64(seeds + segnb) + segnb) ); + memcpy((char*)secretBuffer + segmentStart, &segment, sizeof(segment)); + } } +} + + +/* ========================================== + * XXH3 128 bits (a.k.a XXH128) + * ========================================== + * XXH3's 128-bit variant has better mixing and strength than the 64-bit variant, + * even without counting the significantly larger output size. + * + * For example, extra steps are taken to avoid the seed-dependent collisions + * in 17-240 byte inputs (See XXH3_mix16B and XXH128_mix32B). + * + * This strength naturally comes at the cost of some speed, especially on short + * lengths. Note that longer hashes are about as fast as the 64-bit version + * due to it using only a slight modification of the 64-bit loop. + * + * XXH128 is also more oriented towards 64-bit machines. It is still extremely + * fast for a _128-bit_ hash on 32-bit (it usually clears XXH64). + */ + +XXH_FORCE_INLINE XXH128_hash_t +XXH3_len_1to3_128b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed) +{ + /* A doubled version of 1to3_64b with different constants. */ + XXH_ASSERT(input != NULL); + XXH_ASSERT(1 <= len && len <= 3); + XXH_ASSERT(secret != NULL); + /* + * len = 1: combinedl = { input[0], 0x01, input[0], input[0] } + * len = 2: combinedl = { input[1], 0x02, input[0], input[1] } + * len = 3: combinedl = { input[2], 0x03, input[0], input[1] } + */ + { xxh_u8 const c1 = input[0]; + xxh_u8 const c2 = input[len >> 1]; + xxh_u8 const c3 = input[len - 1]; + xxh_u32 const combinedl = ((xxh_u32)c1 <<16) | ((xxh_u32)c2 << 24) + | ((xxh_u32)c3 << 0) | ((xxh_u32)len << 8); + xxh_u32 const combinedh = XXH_rotl32(XXH_swap32(combinedl), 13); + xxh_u64 const bitflipl = (XXH_readLE32(secret) ^ XXH_readLE32(secret+4)) + seed; + xxh_u64 const bitfliph = (XXH_readLE32(secret+8) ^ XXH_readLE32(secret+12)) - seed; + xxh_u64 const keyed_lo = (xxh_u64)combinedl ^ bitflipl; + xxh_u64 const keyed_hi = (xxh_u64)combinedh ^ bitfliph; + XXH128_hash_t h128; + h128.low64 = XXH64_avalanche(keyed_lo); + h128.high64 = XXH64_avalanche(keyed_hi); + return h128; + } +} + +XXH_FORCE_INLINE XXH128_hash_t +XXH3_len_4to8_128b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed) +{ + XXH_ASSERT(input != NULL); + XXH_ASSERT(secret != NULL); + XXH_ASSERT(4 <= len && len <= 8); + seed ^= (xxh_u64)XXH_swap32((xxh_u32)seed) << 32; + { xxh_u32 const input_lo = XXH_readLE32(input); + xxh_u32 const input_hi = XXH_readLE32(input + len - 4); + xxh_u64 const input_64 = input_lo + ((xxh_u64)input_hi << 32); + xxh_u64 const bitflip = (XXH_readLE64(secret+16) ^ XXH_readLE64(secret+24)) + seed; + xxh_u64 const keyed = input_64 ^ bitflip; + + /* Shift len to the left to ensure it is even, this avoids even multiplies. */ + XXH128_hash_t m128 = XXH_mult64to128(keyed, XXH_PRIME64_1 + (len << 2)); + + m128.high64 += (m128.low64 << 1); + m128.low64 ^= (m128.high64 >> 3); + + m128.low64 = XXH_xorshift64(m128.low64, 35); + m128.low64 *= 0x9FB21C651E98DF25ULL; + m128.low64 = XXH_xorshift64(m128.low64, 28); + m128.high64 = XXH3_avalanche(m128.high64); + return m128; + } +} + +XXH_FORCE_INLINE XXH128_hash_t +XXH3_len_9to16_128b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed) +{ + XXH_ASSERT(input != NULL); + XXH_ASSERT(secret != NULL); + XXH_ASSERT(9 <= len && len <= 16); + { xxh_u64 const bitflipl = (XXH_readLE64(secret+32) ^ XXH_readLE64(secret+40)) - seed; + xxh_u64 const bitfliph = (XXH_readLE64(secret+48) ^ XXH_readLE64(secret+56)) + seed; + xxh_u64 const input_lo = XXH_readLE64(input); + xxh_u64 input_hi = XXH_readLE64(input + len - 8); + XXH128_hash_t m128 = XXH_mult64to128(input_lo ^ input_hi ^ bitflipl, XXH_PRIME64_1); + /* + * Put len in the middle of m128 to ensure that the length gets mixed to + * both the low and high bits in the 128x64 multiply below. + */ + m128.low64 += (xxh_u64)(len - 1) << 54; + input_hi ^= bitfliph; + /* + * Add the high 32 bits of input_hi to the high 32 bits of m128, then + * add the long product of the low 32 bits of input_hi and XXH_PRIME32_2 to + * the high 64 bits of m128. + * + * The best approach to this operation is different on 32-bit and 64-bit. + */ + if (sizeof(void *) < sizeof(xxh_u64)) { /* 32-bit */ + /* + * 32-bit optimized version, which is more readable. + * + * On 32-bit, it removes an ADC and delays a dependency between the two + * halves of m128.high64, but it generates an extra mask on 64-bit. + */ + m128.high64 += (input_hi & 0xFFFFFFFF00000000ULL) + XXH_mult32to64((xxh_u32)input_hi, XXH_PRIME32_2); + } else { + /* + * 64-bit optimized (albeit more confusing) version. + * + * Uses some properties of addition and multiplication to remove the mask: + * + * Let: + * a = input_hi.lo = (input_hi & 0x00000000FFFFFFFF) + * b = input_hi.hi = (input_hi & 0xFFFFFFFF00000000) + * c = XXH_PRIME32_2 + * + * a + (b * c) + * Inverse Property: x + y - x == y + * a + (b * (1 + c - 1)) + * Distributive Property: x * (y + z) == (x * y) + (x * z) + * a + (b * 1) + (b * (c - 1)) + * Identity Property: x * 1 == x + * a + b + (b * (c - 1)) + * + * Substitute a, b, and c: + * input_hi.hi + input_hi.lo + ((xxh_u64)input_hi.lo * (XXH_PRIME32_2 - 1)) + * + * Since input_hi.hi + input_hi.lo == input_hi, we get this: + * input_hi + ((xxh_u64)input_hi.lo * (XXH_PRIME32_2 - 1)) + */ + m128.high64 += input_hi + XXH_mult32to64((xxh_u32)input_hi, XXH_PRIME32_2 - 1); + } + /* m128 ^= XXH_swap64(m128 >> 64); */ + m128.low64 ^= XXH_swap64(m128.high64); + + { /* 128x64 multiply: h128 = m128 * XXH_PRIME64_2; */ + XXH128_hash_t h128 = XXH_mult64to128(m128.low64, XXH_PRIME64_2); + h128.high64 += m128.high64 * XXH_PRIME64_2; + + h128.low64 = XXH3_avalanche(h128.low64); + h128.high64 = XXH3_avalanche(h128.high64); + return h128; + } } +} + +/* + * Assumption: `secret` size is >= XXH3_SECRET_SIZE_MIN + */ +XXH_FORCE_INLINE XXH128_hash_t +XXH3_len_0to16_128b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed) +{ + XXH_ASSERT(len <= 16); + { if (len > 8) return XXH3_len_9to16_128b(input, len, secret, seed); + if (len >= 4) return XXH3_len_4to8_128b(input, len, secret, seed); + if (len) return XXH3_len_1to3_128b(input, len, secret, seed); + { XXH128_hash_t h128; + xxh_u64 const bitflipl = XXH_readLE64(secret+64) ^ XXH_readLE64(secret+72); + xxh_u64 const bitfliph = XXH_readLE64(secret+80) ^ XXH_readLE64(secret+88); + h128.low64 = XXH64_avalanche(seed ^ bitflipl); + h128.high64 = XXH64_avalanche( seed ^ bitfliph); + return h128; + } } +} + +/* + * A bit slower than XXH3_mix16B, but handles multiply by zero better. + */ +XXH_FORCE_INLINE XXH128_hash_t +XXH128_mix32B(XXH128_hash_t acc, const xxh_u8* input_1, const xxh_u8* input_2, + const xxh_u8* secret, XXH64_hash_t seed) +{ + acc.low64 += XXH3_mix16B (input_1, secret+0, seed); + acc.low64 ^= XXH_readLE64(input_2) + XXH_readLE64(input_2 + 8); + acc.high64 += XXH3_mix16B (input_2, secret+16, seed); + acc.high64 ^= XXH_readLE64(input_1) + XXH_readLE64(input_1 + 8); + return acc; +} + + +XXH_FORCE_INLINE XXH128_hash_t +XXH3_len_17to128_128b(const xxh_u8* XXH_RESTRICT input, size_t len, + const xxh_u8* XXH_RESTRICT secret, size_t secretSize, + XXH64_hash_t seed) +{ + XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN); (void)secretSize; + XXH_ASSERT(16 < len && len <= 128); + + { XXH128_hash_t acc; + acc.low64 = len * XXH_PRIME64_1; + acc.high64 = 0; + if (len > 32) { + if (len > 64) { + if (len > 96) { + acc = XXH128_mix32B(acc, input+48, input+len-64, secret+96, seed); + } + acc = XXH128_mix32B(acc, input+32, input+len-48, secret+64, seed); + } + acc = XXH128_mix32B(acc, input+16, input+len-32, secret+32, seed); + } + acc = XXH128_mix32B(acc, input, input+len-16, secret, seed); + { XXH128_hash_t h128; + h128.low64 = acc.low64 + acc.high64; + h128.high64 = (acc.low64 * XXH_PRIME64_1) + + (acc.high64 * XXH_PRIME64_4) + + ((len - seed) * XXH_PRIME64_2); + h128.low64 = XXH3_avalanche(h128.low64); + h128.high64 = (XXH64_hash_t)0 - XXH3_avalanche(h128.high64); + return h128; + } + } +} + +XXH_NO_INLINE XXH128_hash_t +XXH3_len_129to240_128b(const xxh_u8* XXH_RESTRICT input, size_t len, + const xxh_u8* XXH_RESTRICT secret, size_t secretSize, + XXH64_hash_t seed) +{ + XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN); (void)secretSize; + XXH_ASSERT(128 < len && len <= XXH3_MIDSIZE_MAX); + + { XXH128_hash_t acc; + int const nbRounds = (int)len / 32; + int i; + acc.low64 = len * XXH_PRIME64_1; + acc.high64 = 0; + for (i=0; i<4; i++) { + acc = XXH128_mix32B(acc, + input + (32 * i), + input + (32 * i) + 16, + secret + (32 * i), + seed); + } + acc.low64 = XXH3_avalanche(acc.low64); + acc.high64 = XXH3_avalanche(acc.high64); + XXH_ASSERT(nbRounds >= 4); + for (i=4 ; i < nbRounds; i++) { + acc = XXH128_mix32B(acc, + input + (32 * i), + input + (32 * i) + 16, + secret + XXH3_MIDSIZE_STARTOFFSET + (32 * (i - 4)), + seed); + } + /* last bytes */ + acc = XXH128_mix32B(acc, + input + len - 16, + input + len - 32, + secret + XXH3_SECRET_SIZE_MIN - XXH3_MIDSIZE_LASTOFFSET - 16, + 0ULL - seed); + + { XXH128_hash_t h128; + h128.low64 = acc.low64 + acc.high64; + h128.high64 = (acc.low64 * XXH_PRIME64_1) + + (acc.high64 * XXH_PRIME64_4) + + ((len - seed) * XXH_PRIME64_2); + h128.low64 = XXH3_avalanche(h128.low64); + h128.high64 = (XXH64_hash_t)0 - XXH3_avalanche(h128.high64); + return h128; + } + } +} + +XXH_FORCE_INLINE XXH128_hash_t +XXH3_hashLong_128b_internal(const void* XXH_RESTRICT input, size_t len, + const xxh_u8* XXH_RESTRICT secret, size_t secretSize, + XXH3_f_accumulate_512 f_acc512, + XXH3_f_scrambleAcc f_scramble) +{ + XXH_ALIGN(XXH_ACC_ALIGN) xxh_u64 acc[XXH_ACC_NB] = XXH3_INIT_ACC; + + XXH3_hashLong_internal_loop(acc, (const xxh_u8*)input, len, secret, secretSize, f_acc512, f_scramble); + + /* converge into final hash */ + XXH_STATIC_ASSERT(sizeof(acc) == 64); + XXH_ASSERT(secretSize >= sizeof(acc) + XXH_SECRET_MERGEACCS_START); + { XXH128_hash_t h128; + h128.low64 = XXH3_mergeAccs(acc, + secret + XXH_SECRET_MERGEACCS_START, + (xxh_u64)len * XXH_PRIME64_1); + h128.high64 = XXH3_mergeAccs(acc, + secret + secretSize + - sizeof(acc) - XXH_SECRET_MERGEACCS_START, + ~((xxh_u64)len * XXH_PRIME64_2)); + return h128; + } +} + +/* + * It's important for performance that XXH3_hashLong is not inlined. + */ +XXH_NO_INLINE XXH128_hash_t +XXH3_hashLong_128b_default(const void* XXH_RESTRICT input, size_t len, + XXH64_hash_t seed64, + const void* XXH_RESTRICT secret, size_t secretLen) +{ + (void)seed64; (void)secret; (void)secretLen; + return XXH3_hashLong_128b_internal(input, len, XXH3_kSecret, sizeof(XXH3_kSecret), + XXH3_accumulate_512, XXH3_scrambleAcc); +} + +/* + * It's important for performance that XXH3_hashLong is not inlined. + */ +XXH_NO_INLINE XXH128_hash_t +XXH3_hashLong_128b_withSecret(const void* XXH_RESTRICT input, size_t len, + XXH64_hash_t seed64, + const void* XXH_RESTRICT secret, size_t secretLen) +{ + (void)seed64; + return XXH3_hashLong_128b_internal(input, len, (const xxh_u8*)secret, secretLen, + XXH3_accumulate_512, XXH3_scrambleAcc); +} + +XXH_FORCE_INLINE XXH128_hash_t +XXH3_hashLong_128b_withSeed_internal(const void* XXH_RESTRICT input, size_t len, + XXH64_hash_t seed64, + XXH3_f_accumulate_512 f_acc512, + XXH3_f_scrambleAcc f_scramble, + XXH3_f_initCustomSecret f_initSec) +{ + if (seed64 == 0) + return XXH3_hashLong_128b_internal(input, len, + XXH3_kSecret, sizeof(XXH3_kSecret), + f_acc512, f_scramble); + { XXH_ALIGN(XXH_SEC_ALIGN) xxh_u8 secret[XXH_SECRET_DEFAULT_SIZE]; + f_initSec(secret, seed64); + return XXH3_hashLong_128b_internal(input, len, (const xxh_u8*)secret, sizeof(secret), + f_acc512, f_scramble); + } +} + +/* + * It's important for performance that XXH3_hashLong is not inlined. + */ +XXH_NO_INLINE XXH128_hash_t +XXH3_hashLong_128b_withSeed(const void* input, size_t len, + XXH64_hash_t seed64, const void* XXH_RESTRICT secret, size_t secretLen) +{ + (void)secret; (void)secretLen; + return XXH3_hashLong_128b_withSeed_internal(input, len, seed64, + XXH3_accumulate_512, XXH3_scrambleAcc, XXH3_initCustomSecret); +} + +typedef XXH128_hash_t (*XXH3_hashLong128_f)(const void* XXH_RESTRICT, size_t, + XXH64_hash_t, const void* XXH_RESTRICT, size_t); + +XXH_FORCE_INLINE XXH128_hash_t +XXH3_128bits_internal(const void* input, size_t len, + XXH64_hash_t seed64, const void* XXH_RESTRICT secret, size_t secretLen, + XXH3_hashLong128_f f_hl128) +{ + XXH_ASSERT(secretLen >= XXH3_SECRET_SIZE_MIN); + /* + * If an action is to be taken if `secret` conditions are not respected, + * it should be done here. + * For now, it's a contract pre-condition. + * Adding a check and a branch here would cost performance at every hash. + */ + if (len <= 16) + return XXH3_len_0to16_128b((const xxh_u8*)input, len, (const xxh_u8*)secret, seed64); + if (len <= 128) + return XXH3_len_17to128_128b((const xxh_u8*)input, len, (const xxh_u8*)secret, secretLen, seed64); + if (len <= XXH3_MIDSIZE_MAX) + return XXH3_len_129to240_128b((const xxh_u8*)input, len, (const xxh_u8*)secret, secretLen, seed64); + return f_hl128(input, len, seed64, secret, secretLen); +} + + +/* === Public XXH128 API === */ + +XXH_PUBLIC_API XXH128_hash_t XXH3_128bits(const void* input, size_t len) +{ + return XXH3_128bits_internal(input, len, 0, + XXH3_kSecret, sizeof(XXH3_kSecret), + XXH3_hashLong_128b_default); +} + +XXH_PUBLIC_API XXH128_hash_t +XXH3_128bits_withSecret(const void* input, size_t len, const void* secret, size_t secretSize) +{ + return XXH3_128bits_internal(input, len, 0, + (const xxh_u8*)secret, secretSize, + XXH3_hashLong_128b_withSecret); +} + +XXH_PUBLIC_API XXH128_hash_t +XXH3_128bits_withSeed(const void* input, size_t len, XXH64_hash_t seed) +{ + return XXH3_128bits_internal(input, len, seed, + XXH3_kSecret, sizeof(XXH3_kSecret), + XXH3_hashLong_128b_withSeed); +} + +XXH_PUBLIC_API XXH128_hash_t +XXH128(const void* input, size_t len, XXH64_hash_t seed) +{ + return XXH3_128bits_withSeed(input, len, seed); +} + + +/* === XXH3 128-bit streaming === */ + +/* + * All the functions are actually the same as for 64-bit streaming variant. + * The only difference is the finalizatiom routine. + */ + +static void +XXH3_128bits_reset_internal(XXH3_state_t* statePtr, + XXH64_hash_t seed, + const xxh_u8* secret, size_t secretSize) +{ + XXH3_64bits_reset_internal(statePtr, seed, secret, secretSize); +} + +XXH_PUBLIC_API XXH_errorcode +XXH3_128bits_reset(XXH3_state_t* statePtr) +{ + if (statePtr == NULL) return XXH_ERROR; + XXH3_128bits_reset_internal(statePtr, 0, XXH3_kSecret, XXH_SECRET_DEFAULT_SIZE); + return XXH_OK; +} + +XXH_PUBLIC_API XXH_errorcode +XXH3_128bits_reset_withSecret(XXH3_state_t* statePtr, const void* secret, size_t secretSize) +{ + if (statePtr == NULL) return XXH_ERROR; + XXH3_128bits_reset_internal(statePtr, 0, (const xxh_u8*)secret, secretSize); + if (secret == NULL) return XXH_ERROR; + if (secretSize < XXH3_SECRET_SIZE_MIN) return XXH_ERROR; + return XXH_OK; +} + +XXH_PUBLIC_API XXH_errorcode +XXH3_128bits_reset_withSeed(XXH3_state_t* statePtr, XXH64_hash_t seed) +{ + if (statePtr == NULL) return XXH_ERROR; + XXH3_128bits_reset_internal(statePtr, seed, XXH3_kSecret, XXH_SECRET_DEFAULT_SIZE); + XXH3_initCustomSecret(statePtr->customSecret, seed); + statePtr->extSecret = NULL; + return XXH_OK; +} + +XXH_PUBLIC_API XXH_errorcode +XXH3_128bits_update(XXH3_state_t* state, const void* input, size_t len) +{ + return XXH3_update(state, (const xxh_u8*)input, len, + XXH3_accumulate_512, XXH3_scrambleAcc); +} + +XXH_PUBLIC_API XXH128_hash_t XXH3_128bits_digest (const XXH3_state_t* state) +{ + const unsigned char* const secret = (state->extSecret == NULL) ? state->customSecret : state->extSecret; + if (state->totalLen > XXH3_MIDSIZE_MAX) { + XXH_ALIGN(XXH_ACC_ALIGN) XXH64_hash_t acc[XXH_ACC_NB]; + XXH3_digest_long(acc, state, secret); + XXH_ASSERT(state->secretLimit + XXH_STRIPE_LEN >= sizeof(acc) + XXH_SECRET_MERGEACCS_START); + { XXH128_hash_t h128; + h128.low64 = XXH3_mergeAccs(acc, + secret + XXH_SECRET_MERGEACCS_START, + (xxh_u64)state->totalLen * XXH_PRIME64_1); + h128.high64 = XXH3_mergeAccs(acc, + secret + state->secretLimit + XXH_STRIPE_LEN + - sizeof(acc) - XXH_SECRET_MERGEACCS_START, + ~((xxh_u64)state->totalLen * XXH_PRIME64_2)); + return h128; + } + } + /* len <= XXH3_MIDSIZE_MAX : short code */ + if (state->seed) + return XXH3_128bits_withSeed(state->buffer, (size_t)state->totalLen, state->seed); + return XXH3_128bits_withSecret(state->buffer, (size_t)(state->totalLen), + secret, state->secretLimit + XXH_STRIPE_LEN); +} + +/* 128-bit utility functions */ + +#include /* memcmp, memcpy */ + +/* return : 1 is equal, 0 if different */ +XXH_PUBLIC_API int XXH128_isEqual(XXH128_hash_t h1, XXH128_hash_t h2) +{ + /* note : XXH128_hash_t is compact, it has no padding byte */ + return !(memcmp(&h1, &h2, sizeof(h1))); +} + +/* This prototype is compatible with stdlib's qsort(). + * return : >0 if *h128_1 > *h128_2 + * <0 if *h128_1 < *h128_2 + * =0 if *h128_1 == *h128_2 */ +XXH_PUBLIC_API int XXH128_cmp(const void* h128_1, const void* h128_2) +{ + XXH128_hash_t const h1 = *(const XXH128_hash_t*)h128_1; + XXH128_hash_t const h2 = *(const XXH128_hash_t*)h128_2; + int const hcmp = (h1.high64 > h2.high64) - (h2.high64 > h1.high64); + /* note : bets that, in most cases, hash values are different */ + if (hcmp) return hcmp; + return (h1.low64 > h2.low64) - (h2.low64 > h1.low64); +} + + +/*====== Canonical representation ======*/ +XXH_PUBLIC_API void +XXH128_canonicalFromHash(XXH128_canonical_t* dst, XXH128_hash_t hash) +{ + XXH_STATIC_ASSERT(sizeof(XXH128_canonical_t) == sizeof(XXH128_hash_t)); + if (XXH_CPU_LITTLE_ENDIAN) { + hash.high64 = XXH_swap64(hash.high64); + hash.low64 = XXH_swap64(hash.low64); + } + memcpy(dst, &hash.high64, sizeof(hash.high64)); + memcpy((char*)dst + sizeof(hash.high64), &hash.low64, sizeof(hash.low64)); +} + +XXH_PUBLIC_API XXH128_hash_t +XXH128_hashFromCanonical(const XXH128_canonical_t* src) +{ + XXH128_hash_t h; + h.high64 = XXH_readBE64(src); + h.low64 = XXH_readBE64(src->digest + 8); + return h; +} + +/* Pop our optimization override from above */ +#if XXH_VECTOR == XXH_AVX2 /* AVX2 */ \ + && defined(__GNUC__) && !defined(__clang__) /* GCC, not Clang */ \ + && defined(__OPTIMIZE__) && !defined(__OPTIMIZE_SIZE__) /* respect -O0 and -Os */ +# pragma GCC pop_options +#endif + +#endif /* XXH3_H_1397135465 */ \ No newline at end of file diff --git a/src/inference/xxh/xxhash.c b/src/inference/xxh/xxhash.c new file mode 100644 index 0000000000000000000000000000000000000000..728fd686986ad0d0aadad4ac29270c36c52c767c --- /dev/null +++ b/src/inference/xxh/xxhash.c @@ -0,0 +1,43 @@ +/* + * xxHash - Extremely Fast Hash algorithm + * Copyright (C) 2012-2020 Yann Collet + * + * BSD 2-Clause License (https://www.opensource.org/licenses/bsd-license.php) + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * You can contact the author at: + * - xxHash homepage: https://www.xxhash.com + * - xxHash source repository: https://github.com/Cyan4973/xxHash + */ + + +/* + * xxhash.c instantiates functions defined in xxhash.h + */ + +#define XXH_STATIC_LINKING_ONLY /* access advanced declarations */ +#define XXH_IMPLEMENTATION /* access definitions */ + +#include "xxhash.h" \ No newline at end of file diff --git a/src/inference/xxh/xxhash.h b/src/inference/xxh/xxhash.h new file mode 100644 index 0000000000000000000000000000000000000000..fc4e52d3fe6f095c02e202c213b76e728816f13c --- /dev/null +++ b/src/inference/xxh/xxhash.h @@ -0,0 +1,2076 @@ +/* + * xxHash - Extremely Fast Hash algorithm + * Header File + * Copyright (C) 2012-2020 Yann Collet + * + * BSD 2-Clause License (https://www.opensource.org/licenses/bsd-license.php) + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * You can contact the author at: + * - xxHash homepage: https://www.xxhash.com + * - xxHash source repository: https://github.com/Cyan4973/xxHash + */ + +/* TODO: update */ +/* Notice extracted from xxHash homepage: + +xxHash is an extremely fast hash algorithm, running at RAM speed limits. +It also successfully passes all tests from the SMHasher suite. + +Comparison (single thread, Windows Seven 32 bits, using SMHasher on a Core 2 Duo @3GHz) + +Name Speed Q.Score Author +xxHash 5.4 GB/s 10 +CrapWow 3.2 GB/s 2 Andrew +MumurHash 3a 2.7 GB/s 10 Austin Appleby +SpookyHash 2.0 GB/s 10 Bob Jenkins +SBox 1.4 GB/s 9 Bret Mulvey +Lookup3 1.2 GB/s 9 Bob Jenkins +SuperFastHash 1.2 GB/s 1 Paul Hsieh +CityHash64 1.05 GB/s 10 Pike & Alakuijala +FNV 0.55 GB/s 5 Fowler, Noll, Vo +CRC32 0.43 GB/s 9 +MD5-32 0.33 GB/s 10 Ronald L. Rivest +SHA1-32 0.28 GB/s 10 + +Q.Score is a measure of quality of the hash function. +It depends on successfully passing SMHasher test set. +10 is a perfect score. + +Note: SMHasher's CRC32 implementation is not the fastest one. +Other speed-oriented implementations can be faster, +especially in combination with PCLMUL instruction: +https://fastcompression.blogspot.com/2019/03/presenting-xxh3.html?showComment=1552696407071#c3490092340461170735 + +A 64-bit version, named XXH64, is available since r35. +It offers much better speed, but for 64-bit applications only. +Name Speed on 64 bits Speed on 32 bits +XXH64 13.8 GB/s 1.9 GB/s +XXH32 6.8 GB/s 6.0 GB/s +*/ + +#if defined (__cplusplus) +extern "C" { +#endif + +/* **************************** + * INLINE mode + ******************************/ +/*! + * XXH_INLINE_ALL (and XXH_PRIVATE_API) + * Use these build macros to inline xxhash into the target unit. + * Inlining improves performance on small inputs, especially when the length is + * expressed as a compile-time constant: + * + * https://fastcompression.blogspot.com/2018/03/xxhash-for-small-keys-impressive-power.html + * + * It also keeps xxHash symbols private to the unit, so they are not exported. + * + * Usage: + * #define XXH_INLINE_ALL + * #include "xxhash.h" + * + * Do not compile and link xxhash.o as a separate object, as it is not useful. + */ +#if (defined(XXH_INLINE_ALL) || defined(XXH_PRIVATE_API)) \ + && !defined(XXH_INLINE_ALL_31684351384) + /* this section should be traversed only once */ +# define XXH_INLINE_ALL_31684351384 + /* give access to the advanced API, required to compile implementations */ +# undef XXH_STATIC_LINKING_ONLY /* avoid macro redef */ +# define XXH_STATIC_LINKING_ONLY + /* make all functions private */ +# undef XXH_PUBLIC_API +# if defined(__GNUC__) +# define XXH_PUBLIC_API static __inline __attribute__((unused)) +# elif defined (__cplusplus) || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */) +# define XXH_PUBLIC_API static inline +# elif defined(_MSC_VER) +# define XXH_PUBLIC_API static __inline +# else + /* note: this version may generate warnings for unused static functions */ +# define XXH_PUBLIC_API static +# endif + + /* + * This part deals with the special case where a unit wants to inline xxHash, + * but "xxhash.h" has previously been included without XXH_INLINE_ALL, such + * as part of some previously included *.h header file. + * Without further action, the new include would just be ignored, + * and functions would effectively _not_ be inlined (silent failure). + * The following macros solve this situation by prefixing all inlined names, + * avoiding naming collision with previous inclusions. + */ +# ifdef XXH_NAMESPACE +# error "XXH_INLINE_ALL with XXH_NAMESPACE is not supported" + /* + * Note: Alternative: #undef all symbols (it's a pretty large list). + * Without #error: it compiles, but functions are actually not inlined. + */ +# endif +# define XXH_NAMESPACE XXH_INLINE_ + /* + * Some identifiers (enums, type names) are not symbols, but they must + * still be renamed to avoid redeclaration. + * Alternative solution: do not redeclare them. + * However, this requires some #ifdefs, and is a more dispersed action. + * Meanwhile, renaming can be achieved in a single block + */ +# define XXH_IPREF(Id) XXH_INLINE_ ## Id +# define XXH_OK XXH_IPREF(XXH_OK) +# define XXH_ERROR XXH_IPREF(XXH_ERROR) +# define XXH_errorcode XXH_IPREF(XXH_errorcode) +# define XXH32_canonical_t XXH_IPREF(XXH32_canonical_t) +# define XXH64_canonical_t XXH_IPREF(XXH64_canonical_t) +# define XXH128_canonical_t XXH_IPREF(XXH128_canonical_t) +# define XXH32_state_s XXH_IPREF(XXH32_state_s) +# define XXH32_state_t XXH_IPREF(XXH32_state_t) +# define XXH64_state_s XXH_IPREF(XXH64_state_s) +# define XXH64_state_t XXH_IPREF(XXH64_state_t) +# define XXH3_state_s XXH_IPREF(XXH3_state_s) +# define XXH3_state_t XXH_IPREF(XXH3_state_t) +# define XXH128_hash_t XXH_IPREF(XXH128_hash_t) + /* Ensure the header is parsed again, even if it was previously included */ +# undef XXHASH_H_5627135585666179 +# undef XXHASH_H_STATIC_13879238742 +#endif /* XXH_INLINE_ALL || XXH_PRIVATE_API */ + + + +/* **************************************************************** + * Stable API + *****************************************************************/ +#ifndef XXHASH_H_5627135585666179 +#define XXHASH_H_5627135585666179 1 + +/* specific declaration modes for Windows */ +#if !defined(XXH_INLINE_ALL) && !defined(XXH_PRIVATE_API) +# if defined(WIN32) && defined(_MSC_VER) && (defined(XXH_IMPORT) || defined(XXH_EXPORT)) +# ifdef XXH_EXPORT +# define XXH_PUBLIC_API __declspec(dllexport) +# elif XXH_IMPORT +# define XXH_PUBLIC_API __declspec(dllimport) +# endif +# else +# define XXH_PUBLIC_API /* do nothing */ +# endif +#endif + +/*! + * XXH_NAMESPACE, aka Namespace Emulation: + * + * If you want to include _and expose_ xxHash functions from within your own + * library, but also want to avoid symbol collisions with other libraries which + * may also include xxHash, you can use XXH_NAMESPACE to automatically prefix + * any public symbol from xxhash library with the value of XXH_NAMESPACE + * (therefore, avoid empty or numeric values). + * + * Note that no change is required within the calling program as long as it + * includes `xxhash.h`: Regular symbol names will be automatically translated + * by this header. + */ +#ifdef XXH_NAMESPACE +# define XXH_CAT(A,B) A##B +# define XXH_NAME2(A,B) XXH_CAT(A,B) +# define XXH_versionNumber XXH_NAME2(XXH_NAMESPACE, XXH_versionNumber) +# define XXH32 XXH_NAME2(XXH_NAMESPACE, XXH32) +# define XXH32_createState XXH_NAME2(XXH_NAMESPACE, XXH32_createState) +# define XXH32_freeState XXH_NAME2(XXH_NAMESPACE, XXH32_freeState) +# define XXH32_reset XXH_NAME2(XXH_NAMESPACE, XXH32_reset) +# define XXH32_update XXH_NAME2(XXH_NAMESPACE, XXH32_update) +# define XXH32_digest XXH_NAME2(XXH_NAMESPACE, XXH32_digest) +# define XXH32_copyState XXH_NAME2(XXH_NAMESPACE, XXH32_copyState) +# define XXH32_canonicalFromHash XXH_NAME2(XXH_NAMESPACE, XXH32_canonicalFromHash) +# define XXH32_hashFromCanonical XXH_NAME2(XXH_NAMESPACE, XXH32_hashFromCanonical) +# define XXH64 XXH_NAME2(XXH_NAMESPACE, XXH64) +# define XXH64_createState XXH_NAME2(XXH_NAMESPACE, XXH64_createState) +# define XXH64_freeState XXH_NAME2(XXH_NAMESPACE, XXH64_freeState) +# define XXH64_reset XXH_NAME2(XXH_NAMESPACE, XXH64_reset) +# define XXH64_update XXH_NAME2(XXH_NAMESPACE, XXH64_update) +# define XXH64_digest XXH_NAME2(XXH_NAMESPACE, XXH64_digest) +# define XXH64_copyState XXH_NAME2(XXH_NAMESPACE, XXH64_copyState) +# define XXH64_canonicalFromHash XXH_NAME2(XXH_NAMESPACE, XXH64_canonicalFromHash) +# define XXH64_hashFromCanonical XXH_NAME2(XXH_NAMESPACE, XXH64_hashFromCanonical) +#endif + + +/* ************************************* +* Version +***************************************/ +#define XXH_VERSION_MAJOR 0 +#define XXH_VERSION_MINOR 7 +#define XXH_VERSION_RELEASE 4 +#define XXH_VERSION_NUMBER (XXH_VERSION_MAJOR *100*100 + XXH_VERSION_MINOR *100 + XXH_VERSION_RELEASE) +XXH_PUBLIC_API unsigned XXH_versionNumber (void); + + +/* **************************** +* Definitions +******************************/ +#include /* size_t */ +typedef enum { XXH_OK=0, XXH_ERROR } XXH_errorcode; + + +/*-********************************************************************** +* 32-bit hash +************************************************************************/ +#if !defined (__VMS) \ + && (defined (__cplusplus) \ + || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */) ) +# include + typedef uint32_t XXH32_hash_t; +#else +# include +# if UINT_MAX == 0xFFFFFFFFUL + typedef unsigned int XXH32_hash_t; +# else +# if ULONG_MAX == 0xFFFFFFFFUL + typedef unsigned long XXH32_hash_t; +# else +# error "unsupported platform: need a 32-bit type" +# endif +# endif +#endif + +/*! + * XXH32(): + * Calculate the 32-bit hash of sequence "length" bytes stored at memory address "input". + * The memory between input & input+length must be valid (allocated and read-accessible). + * "seed" can be used to alter the result predictably. + * Speed on Core 2 Duo @ 3 GHz (single thread, SMHasher benchmark): 5.4 GB/s + * + * Note: XXH3 provides competitive speed for both 32-bit and 64-bit systems, + * and offers true 64/128 bit hash results. It provides a superior level of + * dispersion, and greatly reduces the risks of collisions. + */ +XXH_PUBLIC_API XXH32_hash_t XXH32 (const void* input, size_t length, XXH32_hash_t seed); + +/******* Streaming *******/ + +/* + * Streaming functions generate the xxHash value from an incrememtal input. + * This method is slower than single-call functions, due to state management. + * For small inputs, prefer `XXH32()` and `XXH64()`, which are better optimized. + * + * An XXH state must first be allocated using `XXH*_createState()`. + * + * Start a new hash by initializing the state with a seed using `XXH*_reset()`. + * + * Then, feed the hash state by calling `XXH*_update()` as many times as necessary. + * + * The function returns an error code, with 0 meaning OK, and any other value + * meaning there is an error. + * + * Finally, a hash value can be produced anytime, by using `XXH*_digest()`. + * This function returns the nn-bits hash as an int or long long. + * + * It's still possible to continue inserting input into the hash state after a + * digest, and generate new hash values later on by invoking `XXH*_digest()`. + * + * When done, release the state using `XXH*_freeState()`. + */ + +typedef struct XXH32_state_s XXH32_state_t; /* incomplete type */ +XXH_PUBLIC_API XXH32_state_t* XXH32_createState(void); +XXH_PUBLIC_API XXH_errorcode XXH32_freeState(XXH32_state_t* statePtr); +XXH_PUBLIC_API void XXH32_copyState(XXH32_state_t* dst_state, const XXH32_state_t* src_state); + +XXH_PUBLIC_API XXH_errorcode XXH32_reset (XXH32_state_t* statePtr, XXH32_hash_t seed); +XXH_PUBLIC_API XXH_errorcode XXH32_update (XXH32_state_t* statePtr, const void* input, size_t length); +XXH_PUBLIC_API XXH32_hash_t XXH32_digest (const XXH32_state_t* statePtr); + +/******* Canonical representation *******/ + +/* + * The default return values from XXH functions are unsigned 32 and 64 bit + * integers. + * This the simplest and fastest format for further post-processing. + * + * However, this leaves open the question of what is the order on the byte level, + * since little and big endian conventions will store the same number differently. + * + * The canonical representation settles this issue by mandating big-endian + * convention, the same convention as human-readable numbers (large digits first). + * + * When writing hash values to storage, sending them over a network, or printing + * them, it's highly recommended to use the canonical representation to ensure + * portability across a wider range of systems, present and future. + * + * The following functions allow transformation of hash values to and from + * canonical format. + */ + +typedef struct { unsigned char digest[4]; } XXH32_canonical_t; +XXH_PUBLIC_API void XXH32_canonicalFromHash(XXH32_canonical_t* dst, XXH32_hash_t hash); +XXH_PUBLIC_API XXH32_hash_t XXH32_hashFromCanonical(const XXH32_canonical_t* src); + + +#ifndef XXH_NO_LONG_LONG +/*-********************************************************************** +* 64-bit hash +************************************************************************/ +#if !defined (__VMS) \ + && (defined (__cplusplus) \ + || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */) ) +# include + typedef uint64_t XXH64_hash_t; +#else + /* the following type must have a width of 64-bit */ + typedef unsigned long long XXH64_hash_t; +#endif + +/*! + * XXH64(): + * Returns the 64-bit hash of sequence of length @length stored at memory + * address @input. + * @seed can be used to alter the result predictably. + * + * This function usually runs faster on 64-bit systems, but slower on 32-bit + * systems (see benchmark). + * + * Note: XXH3 provides competitive speed for both 32-bit and 64-bit systems, + * and offers true 64/128 bit hash results. It provides a superior level of + * dispersion, and greatly reduces the risks of collisions. + */ +XXH_PUBLIC_API XXH64_hash_t XXH64 (const void* input, size_t length, XXH64_hash_t seed); + +/******* Streaming *******/ +typedef struct XXH64_state_s XXH64_state_t; /* incomplete type */ +XXH_PUBLIC_API XXH64_state_t* XXH64_createState(void); +XXH_PUBLIC_API XXH_errorcode XXH64_freeState(XXH64_state_t* statePtr); +XXH_PUBLIC_API void XXH64_copyState(XXH64_state_t* dst_state, const XXH64_state_t* src_state); + +XXH_PUBLIC_API XXH_errorcode XXH64_reset (XXH64_state_t* statePtr, XXH64_hash_t seed); +XXH_PUBLIC_API XXH_errorcode XXH64_update (XXH64_state_t* statePtr, const void* input, size_t length); +XXH_PUBLIC_API XXH64_hash_t XXH64_digest (const XXH64_state_t* statePtr); + +/******* Canonical representation *******/ +typedef struct { unsigned char digest[sizeof(XXH64_hash_t)]; } XXH64_canonical_t; +XXH_PUBLIC_API void XXH64_canonicalFromHash(XXH64_canonical_t* dst, XXH64_hash_t hash); +XXH_PUBLIC_API XXH64_hash_t XXH64_hashFromCanonical(const XXH64_canonical_t* src); + + +#endif /* XXH_NO_LONG_LONG */ + +#endif /* XXHASH_H_5627135585666179 */ + + + +#if defined(XXH_STATIC_LINKING_ONLY) && !defined(XXHASH_H_STATIC_13879238742) +#define XXHASH_H_STATIC_13879238742 +/* **************************************************************************** + * This section contains declarations which are not guaranteed to remain stable. + * They may change in future versions, becoming incompatible with a different + * version of the library. + * These declarations should only be used with static linking. + * Never use them in association with dynamic linking! + ***************************************************************************** */ + +/* + * These definitions are only present to allow static allocation of an XXH + * state, for example, on the stack or in a struct. + * Never **ever** access members directly. + */ + +struct XXH32_state_s { + XXH32_hash_t total_len_32; + XXH32_hash_t large_len; + XXH32_hash_t v1; + XXH32_hash_t v2; + XXH32_hash_t v3; + XXH32_hash_t v4; + XXH32_hash_t mem32[4]; + XXH32_hash_t memsize; + XXH32_hash_t reserved; /* never read nor write, might be removed in a future version */ +}; /* typedef'd to XXH32_state_t */ + + +#ifndef XXH_NO_LONG_LONG /* defined when there is no 64-bit support */ + +struct XXH64_state_s { + XXH64_hash_t total_len; + XXH64_hash_t v1; + XXH64_hash_t v2; + XXH64_hash_t v3; + XXH64_hash_t v4; + XXH64_hash_t mem64[4]; + XXH32_hash_t memsize; + XXH32_hash_t reserved32; /* required for padding anyway */ + XXH64_hash_t reserved64; /* never read nor write, might be removed in a future version */ +}; /* typedef'd to XXH64_state_t */ + + +/*-********************************************************************** +* XXH3 +* New experimental hash +************************************************************************/ + +/* ************************************************************************ + * XXH3 is a new hash algorithm featuring: + * - Improved speed for both small and large inputs + * - True 64-bit and 128-bit outputs + * - SIMD acceleration + * - Improved 32-bit viability + * + * Speed analysis methodology is explained here: + * + * https://fastcompression.blogspot.com/2019/03/presenting-xxh3.html + * + * In general, expect XXH3 to run about ~2x faster on large inputs and >3x + * faster on small ones compared to XXH64, though exact differences depend on + * the platform. + * + * The algorithm is portable: Like XXH32 and XXH64, it generates the same hash + * on all platforms. + * + * It benefits greatly from SIMD and 64-bit arithmetic, but does not require it. + * + * Almost all 32-bit and 64-bit targets that can run XXH32 smoothly can run + * XXH3 at competitive speeds, even if XXH64 runs slowly. Further details are + * explained in the implementation. + * + * Optimized implementations are provided for AVX512, AVX2, SSE2, NEON, POWER8, + * ZVector and scalar targets. This can be controlled with the XXH_VECTOR macro. + * + * XXH3 offers 2 variants, _64bits and _128bits. + * When only 64 bits are needed, prefer calling the _64bits variant, as it + * reduces the amount of mixing, resulting in faster speed on small inputs. + * + * It's also generally simpler to manipulate a scalar return type than a struct. + * + * The 128-bit version adds additional strength, but it is slightly slower. + * + * The XXH3 algorithm is still in development. + * The results it produces may still change in future versions. + * + * Results produced by v0.7.x are not comparable with results from v0.7.y. + * However, the API is completely stable, and it can safely be used for + * ephemeral data (local sessions). + * + * Avoid storing values in long-term storage until the algorithm is finalized. + * XXH3's return values will be officially finalized upon reaching v0.8.0. + * + * After which, return values of XXH3 and XXH128 will no longer change in + * future versions. + * + * The API supports one-shot hashing, streaming mode, and custom secrets. + */ + +#ifdef XXH_NAMESPACE +# define XXH3_64bits XXH_NAME2(XXH_NAMESPACE, XXH3_64bits) +# define XXH3_64bits_withSecret XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_withSecret) +# define XXH3_64bits_withSeed XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_withSeed) + +# define XXH3_createState XXH_NAME2(XXH_NAMESPACE, XXH3_createState) +# define XXH3_freeState XXH_NAME2(XXH_NAMESPACE, XXH3_freeState) +# define XXH3_copyState XXH_NAME2(XXH_NAMESPACE, XXH3_copyState) + +# define XXH3_64bits_reset XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_reset) +# define XXH3_64bits_reset_withSeed XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_reset_withSeed) +# define XXH3_64bits_reset_withSecret XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_reset_withSecret) +# define XXH3_64bits_update XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_update) +# define XXH3_64bits_digest XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_digest) + +# define XXH3_generateSecret XXH_NAME2(XXH_NAMESPACE, XXH3_generateSecret) +#endif + +/* XXH3_64bits(): + * default 64-bit variant, using default secret and default seed of 0. + * It's the fastest variant. */ +XXH_PUBLIC_API XXH64_hash_t XXH3_64bits(const void* data, size_t len); + +/* + * XXH3_64bits_withSeed(): + * This variant generates a custom secret on the fly based on the default + * secret, altered using the `seed` value. + * While this operation is decently fast, note that it's not completely free. + * Note: seed==0 produces the same results as XXH3_64bits(). + */ +XXH_PUBLIC_API XXH64_hash_t XXH3_64bits_withSeed(const void* data, size_t len, XXH64_hash_t seed); + +/* + * XXH3_64bits_withSecret(): + * It's possible to provide any blob of bytes as a "secret" to generate the hash. + * This makes it more difficult for an external actor to prepare an intentional collision. + * The main condition is that secretSize *must* be large enough (>= XXH3_SECRET_SIZE_MIN). + * However, the quality of the hash highly depends on the secret's entropy. + * Technically, the secret must look like a bunch of random bytes. + * Avoid "trivial" or structured data such as repeated sequences or a text document. + * Whenever unsure about the "randonmess" of the blob of bytes, + * consider relabelling it as a "custom seed" instead, + * and employ "XXH3_generateSecret()" (see below) + * to generate a high quality secret derived from this custom seed. + */ +#define XXH3_SECRET_SIZE_MIN 136 +XXH_PUBLIC_API XXH64_hash_t XXH3_64bits_withSecret(const void* data, size_t len, const void* secret, size_t secretSize); + + +/* streaming 64-bit */ + +#if defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L) /* C11+ */ +# include +# define XXH_ALIGN(n) alignas(n) +#elif defined(__GNUC__) +# define XXH_ALIGN(n) __attribute__ ((aligned(n))) +#elif defined(_MSC_VER) +# define XXH_ALIGN(n) __declspec(align(n)) +#else +# define XXH_ALIGN(n) /* disabled */ +#endif + +/* Old GCC versions only accept the attribute after the type in structures. */ +#if !(defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L)) /* C11+ */ \ + && defined(__GNUC__) +# define XXH_ALIGN_MEMBER(align, type) type XXH_ALIGN(align) +#else +# define XXH_ALIGN_MEMBER(align, type) XXH_ALIGN(align) type +#endif + +typedef struct XXH3_state_s XXH3_state_t; + +#define XXH3_INTERNALBUFFER_SIZE 256 +#define XXH3_SECRET_DEFAULT_SIZE 192 +struct XXH3_state_s { + XXH_ALIGN_MEMBER(64, XXH64_hash_t acc[8]); + /* used to store a custom secret generated from a seed */ + XXH_ALIGN_MEMBER(64, unsigned char customSecret[XXH3_SECRET_DEFAULT_SIZE]); + XXH_ALIGN_MEMBER(64, unsigned char buffer[XXH3_INTERNALBUFFER_SIZE]); + XXH32_hash_t bufferedSize; + XXH32_hash_t reserved32; + size_t nbStripesPerBlock; + size_t nbStripesSoFar; + size_t secretLimit; + XXH64_hash_t totalLen; + XXH64_hash_t seed; + XXH64_hash_t reserved64; + const unsigned char* extSecret; /* reference to external secret; + * if == NULL, use .customSecret instead */ + /* note: there may be some padding at the end due to alignment on 64 bytes */ +}; /* typedef'd to XXH3_state_t */ + +#undef XXH_ALIGN_MEMBER + +/* + * Streaming requires state maintenance. + * This operation costs memory and CPU. + * As a consequence, streaming is slower than one-shot hashing. + * For better performance, prefer one-shot functions whenever possible. + */ +XXH_PUBLIC_API XXH3_state_t* XXH3_createState(void); +XXH_PUBLIC_API XXH_errorcode XXH3_freeState(XXH3_state_t* statePtr); +XXH_PUBLIC_API void XXH3_copyState(XXH3_state_t* dst_state, const XXH3_state_t* src_state); + + +/* + * XXH3_64bits_reset(): + * Initialize with the default parameters. + * The result will be equivalent to `XXH3_64bits()`. + */ +XXH_PUBLIC_API XXH_errorcode XXH3_64bits_reset(XXH3_state_t* statePtr); +/* + * XXH3_64bits_reset_withSeed(): + * Generate a custom secret from `seed`, and store it into `statePtr`. + * digest will be equivalent to `XXH3_64bits_withSeed()`. + */ +XXH_PUBLIC_API XXH_errorcode XXH3_64bits_reset_withSeed(XXH3_state_t* statePtr, XXH64_hash_t seed); +/* + * XXH3_64bits_reset_withSecret(): + * `secret` is referenced, it _must outlive_ the hash streaming session. + * Similar to one-shot API, `secretSize` must be >= `XXH3_SECRET_SIZE_MIN`, + * and the quality of the hash depends on secret's entropy, + * meaning that the secret should look like a bunch of random bytes. + * When in doubt about the randomness of a candidate `secret`, + * consider employing `XXH3_generateSecret()` instead (see below). + */ +XXH_PUBLIC_API XXH_errorcode XXH3_64bits_reset_withSecret(XXH3_state_t* statePtr, const void* secret, size_t secretSize); + +XXH_PUBLIC_API XXH_errorcode XXH3_64bits_update (XXH3_state_t* statePtr, const void* input, size_t length); +XXH_PUBLIC_API XXH64_hash_t XXH3_64bits_digest (const XXH3_state_t* statePtr); + + +/* 128-bit */ + +#ifdef XXH_NAMESPACE +# define XXH128 XXH_NAME2(XXH_NAMESPACE, XXH128) +# define XXH3_128bits XXH_NAME2(XXH_NAMESPACE, XXH3_128bits) +# define XXH3_128bits_withSeed XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_withSeed) +# define XXH3_128bits_withSecret XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_withSecret) + +# define XXH3_128bits_reset XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_reset) +# define XXH3_128bits_reset_withSeed XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_reset_withSeed) +# define XXH3_128bits_reset_withSecret XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_reset_withSecret) +# define XXH3_128bits_update XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_update) +# define XXH3_128bits_digest XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_digest) + +# define XXH128_isEqual XXH_NAME2(XXH_NAMESPACE, XXH128_isEqual) +# define XXH128_cmp XXH_NAME2(XXH_NAMESPACE, XXH128_cmp) +# define XXH128_canonicalFromHash XXH_NAME2(XXH_NAMESPACE, XXH128_canonicalFromHash) +# define XXH128_hashFromCanonical XXH_NAME2(XXH_NAMESPACE, XXH128_hashFromCanonical) +#endif + +typedef struct { + XXH64_hash_t low64; + XXH64_hash_t high64; +} XXH128_hash_t; + +XXH_PUBLIC_API XXH128_hash_t XXH128(const void* data, size_t len, XXH64_hash_t seed); +XXH_PUBLIC_API XXH128_hash_t XXH3_128bits(const void* data, size_t len); +XXH_PUBLIC_API XXH128_hash_t XXH3_128bits_withSeed(const void* data, size_t len, XXH64_hash_t seed); /* == XXH128() */ +XXH_PUBLIC_API XXH128_hash_t XXH3_128bits_withSecret(const void* data, size_t len, const void* secret, size_t secretSize); + +XXH_PUBLIC_API XXH_errorcode XXH3_128bits_reset(XXH3_state_t* statePtr); +XXH_PUBLIC_API XXH_errorcode XXH3_128bits_reset_withSeed(XXH3_state_t* statePtr, XXH64_hash_t seed); +XXH_PUBLIC_API XXH_errorcode XXH3_128bits_reset_withSecret(XXH3_state_t* statePtr, const void* secret, size_t secretSize); + +XXH_PUBLIC_API XXH_errorcode XXH3_128bits_update (XXH3_state_t* statePtr, const void* input, size_t length); +XXH_PUBLIC_API XXH128_hash_t XXH3_128bits_digest (const XXH3_state_t* statePtr); + + +/* Note: For better performance, these functions can be inlined using XXH_INLINE_ALL */ + +/*! + * XXH128_isEqual(): + * Return: 1 if `h1` and `h2` are equal, 0 if they are not. + */ +XXH_PUBLIC_API int XXH128_isEqual(XXH128_hash_t h1, XXH128_hash_t h2); + +/*! + * XXH128_cmp(): + * + * This comparator is compatible with stdlib's `qsort()`/`bsearch()`. + * + * return: >0 if *h128_1 > *h128_2 + * =0 if *h128_1 == *h128_2 + * <0 if *h128_1 < *h128_2 + */ +XXH_PUBLIC_API int XXH128_cmp(const void* h128_1, const void* h128_2); + + +/******* Canonical representation *******/ +typedef struct { unsigned char digest[sizeof(XXH128_hash_t)]; } XXH128_canonical_t; +XXH_PUBLIC_API void XXH128_canonicalFromHash(XXH128_canonical_t* dst, XXH128_hash_t hash); +XXH_PUBLIC_API XXH128_hash_t XXH128_hashFromCanonical(const XXH128_canonical_t* src); + + +/* === Experimental API === */ +/* Symbols defined below must be considered tied to a specific library version. */ + +/* + * XXH3_generateSecret(): + * + * Derive a high-entropy secret from any user-defined content, named customSeed. + * The generated secret can be used in combination with `*_withSecret()` functions. + * The `_withSecret()` variants are useful to provide a higher level of protection than 64-bit seed, + * as it becomes much more difficult for an external actor to guess how to impact the calculation logic. + * + * The function accepts as input a custom seed of any length and any content, + * and derives from it a high-entropy secret of length XXH3_SECRET_DEFAULT_SIZE + * into an already allocated buffer secretBuffer. + * The generated secret is _always_ XXH_SECRET_DEFAULT_SIZE bytes long. + * + * The generated secret can then be used with any `*_withSecret()` variant. + * Functions `XXH3_128bits_withSecret()`, `XXH3_64bits_withSecret()`, + * `XXH3_128bits_reset_withSecret()` and `XXH3_64bits_reset_withSecret()` + * are part of this list. They all accept a `secret` parameter + * which must be very long for implementation reasons (>= XXH3_SECRET_SIZE_MIN) + * _and_ feature very high entropy (consist of random-looking bytes). + * These conditions can be a high bar to meet, so + * this function can be used to generate a secret of proper quality. + * + * customSeed can be anything. It can have any size, even small ones, + * and its content can be anything, even stupidly "low entropy" source such as a bunch of zeroes. + * The resulting `secret` will nonetheless provide all expected qualities. + * + * Supplying NULL as the customSeed copies the default secret into `secretBuffer`. + * When customSeedSize > 0, supplying NULL as customSeed is undefined behavior. + */ +XXH_PUBLIC_API void XXH3_generateSecret(void* secretBuffer, const void* customSeed, size_t customSeedSize); + + +#endif /* XXH_NO_LONG_LONG */ + + +#if defined(XXH_INLINE_ALL) || defined(XXH_PRIVATE_API) +# define XXH_IMPLEMENTATION +#endif + +#endif /* defined(XXH_STATIC_LINKING_ONLY) && !defined(XXHASH_H_STATIC_13879238742) */ + + +/* ======================================================================== */ +/* ======================================================================== */ +/* ======================================================================== */ + + +/*-********************************************************************** + * xxHash implementation + *-********************************************************************** + * xxHash's implementation used to be found in xxhash.c. + * + * However, code inlining requires the implementation to be visible to the + * compiler, usually within the header. + * + * As a workaround, xxhash.c used to be included within xxhash.h. This caused + * some issues with some build systems, especially ones which treat .c files + * as source files. + * + * Therefore, the implementation is now directly integrated within xxhash.h. + * Another small advantage is that xxhash.c is no longer needed in /include. + ************************************************************************/ + +#if ( defined(XXH_INLINE_ALL) || defined(XXH_PRIVATE_API) \ + || defined(XXH_IMPLEMENTATION) ) && !defined(XXH_IMPLEM_13a8737387) +# define XXH_IMPLEM_13a8737387 + +/* ************************************* +* Tuning parameters +***************************************/ +/*! + * XXH_FORCE_MEMORY_ACCESS: + * By default, access to unaligned memory is controlled by `memcpy()`, which is + * safe and portable. + * + * Unfortunately, on some target/compiler combinations, the generated assembly + * is sub-optimal. + * + * The below switch allow to select a different access method for improved + * performance. + * Method 0 (default): + * Use `memcpy()`. Safe and portable. + * Method 1: + * `__attribute__((packed))` statement. It depends on compiler extensions + * and is therefore not portable. + * This method is safe if your compiler supports it, and *generally* as + * fast or faster than `memcpy`. + * Method 2: + * Direct access via cast. This method doesn't depend on the compiler but + * violates the C standard. + * It can generate buggy code on targets which do not support unaligned + * memory accesses. + * But in some circumstances, it's the only known way to get the most + * performance (ie GCC + ARMv6) + * Method 3: + * Byteshift. This can generate the best code on old compilers which don't + * inline small `memcpy()` calls, and it might also be faster on big-endian + * systems which lack a native byteswap instruction. + * See https://stackoverflow.com/a/32095106/646947 for details. + * Prefer these methods in priority order (0 > 1 > 2 > 3) + */ +#ifndef XXH_FORCE_MEMORY_ACCESS /* can be defined externally, on command line for example */ +# if !defined(__clang__) && defined(__GNUC__) && defined(__ARM_FEATURE_UNALIGNED) && defined(__ARM_ARCH) && (__ARM_ARCH == 6) +# define XXH_FORCE_MEMORY_ACCESS 2 +# elif !defined(__clang__) && ((defined(__INTEL_COMPILER) && !defined(_WIN32)) || \ + (defined(__GNUC__) && (defined(__ARM_ARCH) && __ARM_ARCH >= 7))) +# define XXH_FORCE_MEMORY_ACCESS 1 +# endif +#endif + +/*! + * XXH_ACCEPT_NULL_INPUT_POINTER: + * If the input pointer is NULL, xxHash's default behavior is to dereference it, + * triggering a segfault. + * When this macro is enabled, xxHash actively checks the input for a null pointer. + * If it is, the result for null input pointers is the same as a zero-length input. + */ +#ifndef XXH_ACCEPT_NULL_INPUT_POINTER /* can be defined externally */ +# define XXH_ACCEPT_NULL_INPUT_POINTER 0 +#endif + +/*! + * XXH_FORCE_ALIGN_CHECK: + * This is an important performance trick + * for architectures without decent unaligned memory access performance. + * It checks for input alignment, and when conditions are met, + * uses a "fast path" employing direct 32-bit/64-bit read, + * resulting in _dramatically faster_ read speed. + * + * The check costs one initial branch per hash, which is generally negligible, but not zero. + * Moreover, it's not useful to generate binary for an additional code path + * if memory access uses same instruction for both aligned and unaligned adresses. + * + * In these cases, the alignment check can be removed by setting this macro to 0. + * Then the code will always use unaligned memory access. + * Align check is automatically disabled on x86, x64 & arm64, + * which are platforms known to offer good unaligned memory accesses performance. + * + * This option does not affect XXH3 (only XXH32 and XXH64). + */ +#ifndef XXH_FORCE_ALIGN_CHECK /* can be defined externally */ +# if defined(__i386) || defined(__x86_64__) || defined(__aarch64__) \ + || defined(_M_IX86) || defined(_M_X64) || defined(_M_ARM64) /* visual */ +# define XXH_FORCE_ALIGN_CHECK 0 +# else +# define XXH_FORCE_ALIGN_CHECK 1 +# endif +#endif + +/*! + * XXH_NO_INLINE_HINTS: + * + * By default, xxHash tries to force the compiler to inline almost all internal + * functions. + * + * This can usually improve performance due to reduced jumping and improved + * constant folding, but significantly increases the size of the binary which + * might not be favorable. + * + * Additionally, sometimes the forced inlining can be detrimental to performance, + * depending on the architecture. + * + * XXH_NO_INLINE_HINTS marks all internal functions as static, giving the + * compiler full control on whether to inline or not. + * + * When not optimizing (-O0), optimizing for size (-Os, -Oz), or using + * -fno-inline with GCC or Clang, this will automatically be defined. + */ +#ifndef XXH_NO_INLINE_HINTS +# if defined(__OPTIMIZE_SIZE__) /* -Os, -Oz */ \ + || defined(__NO_INLINE__) /* -O0, -fno-inline */ +# define XXH_NO_INLINE_HINTS 1 +# else +# define XXH_NO_INLINE_HINTS 0 +# endif +#endif + +/*! + * XXH_REROLL: + * Whether to reroll XXH32_finalize, and XXH64_finalize, + * instead of using an unrolled jump table/if statement loop. + * + * This is automatically defined on -Os/-Oz on GCC and Clang. + */ +#ifndef XXH_REROLL +# if defined(__OPTIMIZE_SIZE__) +# define XXH_REROLL 1 +# else +# define XXH_REROLL 0 +# endif +#endif + + +/* ************************************* +* Includes & Memory related functions +***************************************/ +/*! + * Modify the local functions below should you wish to use some other memory + * routines for malloc() and free() + */ +#include + +static void* XXH_malloc(size_t s) { return malloc(s); } +static void XXH_free(void* p) { free(p); } + +/*! and for memcpy() */ +#include +static void* XXH_memcpy(void* dest, const void* src, size_t size) +{ + return memcpy(dest,src,size); +} + +#include /* ULLONG_MAX */ + + +/* ************************************* +* Compiler Specific Options +***************************************/ +#ifdef _MSC_VER /* Visual Studio warning fix */ +# pragma warning(disable : 4127) /* disable: C4127: conditional expression is constant */ +#endif + +#if XXH_NO_INLINE_HINTS /* disable inlining hints */ +# if defined(__GNUC__) +# define XXH_FORCE_INLINE static __attribute__((unused)) +# else +# define XXH_FORCE_INLINE static +# endif +# define XXH_NO_INLINE static +/* enable inlining hints */ +#elif defined(_MSC_VER) /* Visual Studio */ +# define XXH_FORCE_INLINE static __forceinline +# define XXH_NO_INLINE static __declspec(noinline) +#elif defined(__GNUC__) +# define XXH_FORCE_INLINE static __inline__ __attribute__((always_inline, unused)) +# define XXH_NO_INLINE static __attribute__((noinline)) +#elif defined (__cplusplus) \ + || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L)) /* C99 */ +# define XXH_FORCE_INLINE static inline +# define XXH_NO_INLINE static +#else +# define XXH_FORCE_INLINE static +# define XXH_NO_INLINE static +#endif + + + +/* ************************************* +* Debug +***************************************/ +/* + * XXH_DEBUGLEVEL is expected to be defined externally, typically via the + * compiler's command line options. The value must be a number. + */ +#ifndef XXH_DEBUGLEVEL +# ifdef DEBUGLEVEL /* backwards compat */ +# define XXH_DEBUGLEVEL DEBUGLEVEL +# else +# define XXH_DEBUGLEVEL 0 +# endif +#endif + +#if (XXH_DEBUGLEVEL>=1) +# include /* note: can still be disabled with NDEBUG */ +# define XXH_ASSERT(c) assert(c) +#else +# define XXH_ASSERT(c) ((void)0) +#endif + +/* note: use after variable declarations */ +#define XXH_STATIC_ASSERT(c) do { enum { XXH_sa = 1/(int)(!!(c)) }; } while (0) + + +/* ************************************* +* Basic Types +***************************************/ +#if !defined (__VMS) \ + && (defined (__cplusplus) \ + || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */) ) +# include + typedef uint8_t xxh_u8; +#else + typedef unsigned char xxh_u8; +#endif +typedef XXH32_hash_t xxh_u32; + +#ifdef XXH_OLD_NAMES +# define BYTE xxh_u8 +# define U8 xxh_u8 +# define U32 xxh_u32 +#endif + +/* *** Memory access *** */ + +#if (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==3)) +/* + * Manual byteshift. Best for old compilers which don't inline memcpy. + * We actually directly use XXH_readLE32 and XXH_readBE32. + */ +#elif (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==2)) + +/* + * Force direct memory access. Only works on CPU which support unaligned memory + * access in hardware. + */ +static xxh_u32 XXH_read32(const void* memPtr) { return *(const xxh_u32*) memPtr; } + +#elif (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==1)) + +/* + * __pack instructions are safer but compiler specific, hence potentially + * problematic for some compilers. + * + * Currently only defined for GCC and ICC. + */ +#ifdef XXH_OLD_NAMES +typedef union { xxh_u32 u32; } __attribute__((packed)) unalign; +#endif +static xxh_u32 XXH_read32(const void* ptr) +{ + typedef union { xxh_u32 u32; } __attribute__((packed)) xxh_unalign; + return ((const xxh_unalign*)ptr)->u32; +} + +#else + +/* + * Portable and safe solution. Generally efficient. + * see: https://stackoverflow.com/a/32095106/646947 + */ +static xxh_u32 XXH_read32(const void* memPtr) +{ + xxh_u32 val; + memcpy(&val, memPtr, sizeof(val)); + return val; +} + +#endif /* XXH_FORCE_DIRECT_MEMORY_ACCESS */ + + +/* *** Endianess *** */ +typedef enum { XXH_bigEndian=0, XXH_littleEndian=1 } XXH_endianess; + +/*! + * XXH_CPU_LITTLE_ENDIAN: + * Defined to 1 if the target is little endian, or 0 if it is big endian. + * It can be defined externally, for example on the compiler command line. + * + * If it is not defined, a runtime check (which is usually constant folded) + * is used instead. + */ +#ifndef XXH_CPU_LITTLE_ENDIAN +/* + * Try to detect endianness automatically, to avoid the nonstandard behavior + * in `XXH_isLittleEndian()` + */ +# if defined(_WIN32) /* Windows is always little endian */ \ + || defined(__LITTLE_ENDIAN__) \ + || (defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__) +# define XXH_CPU_LITTLE_ENDIAN 1 +# elif defined(__BIG_ENDIAN__) \ + || (defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__) +# define XXH_CPU_LITTLE_ENDIAN 0 +# else +/* + * runtime test, presumed to simplify to a constant by compiler + */ +static int XXH_isLittleEndian(void) +{ + /* + * Portable and well-defined behavior. + * Don't use static: it is detrimental to performance. + */ + const union { xxh_u32 u; xxh_u8 c[4]; } one = { 1 }; + return one.c[0]; +} +# define XXH_CPU_LITTLE_ENDIAN XXH_isLittleEndian() +# endif +#endif + + + + +/* **************************************** +* Compiler-specific Functions and Macros +******************************************/ +#define XXH_GCC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__) + +#ifdef __has_builtin +# define XXH_HAS_BUILTIN(x) __has_builtin(x) +#else +# define XXH_HAS_BUILTIN(x) 0 +#endif + +#if !defined(NO_CLANG_BUILTIN) && XXH_HAS_BUILTIN(__builtin_rotateleft32) \ + && XXH_HAS_BUILTIN(__builtin_rotateleft64) +# define XXH_rotl32 __builtin_rotateleft32 +# define XXH_rotl64 __builtin_rotateleft64 +/* Note: although _rotl exists for minGW (GCC under windows), performance seems poor */ +#elif defined(_MSC_VER) +# define XXH_rotl32(x,r) _rotl(x,r) +# define XXH_rotl64(x,r) _rotl64(x,r) +#else +# define XXH_rotl32(x,r) (((x) << (r)) | ((x) >> (32 - (r)))) +# define XXH_rotl64(x,r) (((x) << (r)) | ((x) >> (64 - (r)))) +#endif + +#if defined(_MSC_VER) /* Visual Studio */ +# define XXH_swap32 _byteswap_ulong +#elif XXH_GCC_VERSION >= 403 +# define XXH_swap32 __builtin_bswap32 +#else +static xxh_u32 XXH_swap32 (xxh_u32 x) +{ + return ((x << 24) & 0xff000000 ) | + ((x << 8) & 0x00ff0000 ) | + ((x >> 8) & 0x0000ff00 ) | + ((x >> 24) & 0x000000ff ); +} +#endif + + +/* *************************** +* Memory reads +*****************************/ +typedef enum { XXH_aligned, XXH_unaligned } XXH_alignment; + +/* + * XXH_FORCE_MEMORY_ACCESS==3 is an endian-independent byteshift load. + * + * This is ideal for older compilers which don't inline memcpy. + */ +#if (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==3)) + +XXH_FORCE_INLINE xxh_u32 XXH_readLE32(const void* memPtr) +{ + const xxh_u8* bytePtr = (const xxh_u8 *)memPtr; + return bytePtr[0] + | ((xxh_u32)bytePtr[1] << 8) + | ((xxh_u32)bytePtr[2] << 16) + | ((xxh_u32)bytePtr[3] << 24); +} + +XXH_FORCE_INLINE xxh_u32 XXH_readBE32(const void* memPtr) +{ + const xxh_u8* bytePtr = (const xxh_u8 *)memPtr; + return bytePtr[3] + | ((xxh_u32)bytePtr[2] << 8) + | ((xxh_u32)bytePtr[1] << 16) + | ((xxh_u32)bytePtr[0] << 24); +} + +#else +XXH_FORCE_INLINE xxh_u32 XXH_readLE32(const void* ptr) +{ + return XXH_CPU_LITTLE_ENDIAN ? XXH_read32(ptr) : XXH_swap32(XXH_read32(ptr)); +} + +static xxh_u32 XXH_readBE32(const void* ptr) +{ + return XXH_CPU_LITTLE_ENDIAN ? XXH_swap32(XXH_read32(ptr)) : XXH_read32(ptr); +} +#endif + +XXH_FORCE_INLINE xxh_u32 +XXH_readLE32_align(const void* ptr, XXH_alignment align) +{ + if (align==XXH_unaligned) { + return XXH_readLE32(ptr); + } else { + return XXH_CPU_LITTLE_ENDIAN ? *(const xxh_u32*)ptr : XXH_swap32(*(const xxh_u32*)ptr); + } +} + + +/* ************************************* +* Misc +***************************************/ +XXH_PUBLIC_API unsigned XXH_versionNumber (void) { return XXH_VERSION_NUMBER; } + + +/* ******************************************************************* +* 32-bit hash functions +*********************************************************************/ +static const xxh_u32 XXH_PRIME32_1 = 0x9E3779B1U; /* 0b10011110001101110111100110110001 */ +static const xxh_u32 XXH_PRIME32_2 = 0x85EBCA77U; /* 0b10000101111010111100101001110111 */ +static const xxh_u32 XXH_PRIME32_3 = 0xC2B2AE3DU; /* 0b11000010101100101010111000111101 */ +static const xxh_u32 XXH_PRIME32_4 = 0x27D4EB2FU; /* 0b00100111110101001110101100101111 */ +static const xxh_u32 XXH_PRIME32_5 = 0x165667B1U; /* 0b00010110010101100110011110110001 */ + +#ifdef XXH_OLD_NAMES +# define PRIME32_1 XXH_PRIME32_1 +# define PRIME32_2 XXH_PRIME32_2 +# define PRIME32_3 XXH_PRIME32_3 +# define PRIME32_4 XXH_PRIME32_4 +# define PRIME32_5 XXH_PRIME32_5 +#endif + +static xxh_u32 XXH32_round(xxh_u32 acc, xxh_u32 input) +{ + acc += input * XXH_PRIME32_2; + acc = XXH_rotl32(acc, 13); + acc *= XXH_PRIME32_1; +#if defined(__GNUC__) && defined(__SSE4_1__) && !defined(XXH_ENABLE_AUTOVECTORIZE) + /* + * UGLY HACK: + * This inline assembly hack forces acc into a normal register. This is the + * only thing that prevents GCC and Clang from autovectorizing the XXH32 + * loop (pragmas and attributes don't work for some resason) without globally + * disabling SSE4.1. + * + * The reason we want to avoid vectorization is because despite working on + * 4 integers at a time, there are multiple factors slowing XXH32 down on + * SSE4: + * - There's a ridiculous amount of lag from pmulld (10 cycles of latency on + * newer chips!) making it slightly slower to multiply four integers at + * once compared to four integers independently. Even when pmulld was + * fastest, Sandy/Ivy Bridge, it is still not worth it to go into SSE + * just to multiply unless doing a long operation. + * + * - Four instructions are required to rotate, + * movqda tmp, v // not required with VEX encoding + * pslld tmp, 13 // tmp <<= 13 + * psrld v, 19 // x >>= 19 + * por v, tmp // x |= tmp + * compared to one for scalar: + * roll v, 13 // reliably fast across the board + * shldl v, v, 13 // Sandy Bridge and later prefer this for some reason + * + * - Instruction level parallelism is actually more beneficial here because + * the SIMD actually serializes this operation: While v1 is rotating, v2 + * can load data, while v3 can multiply. SSE forces them to operate + * together. + * + * How this hack works: + * __asm__("" // Declare an assembly block but don't declare any instructions + * : // However, as an Input/Output Operand, + * "+r" // constrain a read/write operand (+) as a general purpose register (r). + * (acc) // and set acc as the operand + * ); + * + * Because of the 'r', the compiler has promised that seed will be in a + * general purpose register and the '+' says that it will be 'read/write', + * so it has to assume it has changed. It is like volatile without all the + * loads and stores. + * + * Since the argument has to be in a normal register (not an SSE register), + * each time XXH32_round is called, it is impossible to vectorize. + */ + __asm__("" : "+r" (acc)); +#endif + return acc; +} + +/* mix all bits */ +static xxh_u32 XXH32_avalanche(xxh_u32 h32) +{ + h32 ^= h32 >> 15; + h32 *= XXH_PRIME32_2; + h32 ^= h32 >> 13; + h32 *= XXH_PRIME32_3; + h32 ^= h32 >> 16; + return(h32); +} + +#define XXH_get32bits(p) XXH_readLE32_align(p, align) + +static xxh_u32 +XXH32_finalize(xxh_u32 h32, const xxh_u8* ptr, size_t len, XXH_alignment align) +{ +#define XXH_PROCESS1 do { \ + h32 += (*ptr++) * XXH_PRIME32_5; \ + h32 = XXH_rotl32(h32, 11) * XXH_PRIME32_1; \ +} while (0) + +#define XXH_PROCESS4 do { \ + h32 += XXH_get32bits(ptr) * XXH_PRIME32_3; \ + ptr += 4; \ + h32 = XXH_rotl32(h32, 17) * XXH_PRIME32_4; \ +} while (0) + + /* Compact rerolled version */ + if (XXH_REROLL) { + len &= 15; + while (len >= 4) { + XXH_PROCESS4; + len -= 4; + } + while (len > 0) { + XXH_PROCESS1; + --len; + } + return XXH32_avalanche(h32); + } else { + switch(len&15) /* or switch(bEnd - p) */ { + case 12: XXH_PROCESS4; + /* fallthrough */ + case 8: XXH_PROCESS4; + /* fallthrough */ + case 4: XXH_PROCESS4; + return XXH32_avalanche(h32); + + case 13: XXH_PROCESS4; + /* fallthrough */ + case 9: XXH_PROCESS4; + /* fallthrough */ + case 5: XXH_PROCESS4; + XXH_PROCESS1; + return XXH32_avalanche(h32); + + case 14: XXH_PROCESS4; + /* fallthrough */ + case 10: XXH_PROCESS4; + /* fallthrough */ + case 6: XXH_PROCESS4; + XXH_PROCESS1; + XXH_PROCESS1; + return XXH32_avalanche(h32); + + case 15: XXH_PROCESS4; + /* fallthrough */ + case 11: XXH_PROCESS4; + /* fallthrough */ + case 7: XXH_PROCESS4; + /* fallthrough */ + case 3: XXH_PROCESS1; + /* fallthrough */ + case 2: XXH_PROCESS1; + /* fallthrough */ + case 1: XXH_PROCESS1; + /* fallthrough */ + case 0: return XXH32_avalanche(h32); + } + XXH_ASSERT(0); + return h32; /* reaching this point is deemed impossible */ + } +} + +#ifdef XXH_OLD_NAMES +# define PROCESS1 XXH_PROCESS1 +# define PROCESS4 XXH_PROCESS4 +#else +# undef XXH_PROCESS1 +# undef XXH_PROCESS4 +#endif + +XXH_FORCE_INLINE xxh_u32 +XXH32_endian_align(const xxh_u8* input, size_t len, xxh_u32 seed, XXH_alignment align) +{ + const xxh_u8* bEnd = input + len; + xxh_u32 h32; + +#if defined(XXH_ACCEPT_NULL_INPUT_POINTER) && (XXH_ACCEPT_NULL_INPUT_POINTER>=1) + if (input==NULL) { + len=0; + bEnd=input=(const xxh_u8*)(size_t)16; + } +#endif + + if (len>=16) { + const xxh_u8* const limit = bEnd - 15; + xxh_u32 v1 = seed + XXH_PRIME32_1 + XXH_PRIME32_2; + xxh_u32 v2 = seed + XXH_PRIME32_2; + xxh_u32 v3 = seed + 0; + xxh_u32 v4 = seed - XXH_PRIME32_1; + + do { + v1 = XXH32_round(v1, XXH_get32bits(input)); input += 4; + v2 = XXH32_round(v2, XXH_get32bits(input)); input += 4; + v3 = XXH32_round(v3, XXH_get32bits(input)); input += 4; + v4 = XXH32_round(v4, XXH_get32bits(input)); input += 4; + } while (input < limit); + + h32 = XXH_rotl32(v1, 1) + XXH_rotl32(v2, 7) + + XXH_rotl32(v3, 12) + XXH_rotl32(v4, 18); + } else { + h32 = seed + XXH_PRIME32_5; + } + + h32 += (xxh_u32)len; + + return XXH32_finalize(h32, input, len&15, align); +} + + +XXH_PUBLIC_API XXH32_hash_t XXH32 (const void* input, size_t len, XXH32_hash_t seed) +{ +#if 0 + /* Simple version, good for code maintenance, but unfortunately slow for small inputs */ + XXH32_state_t state; + XXH32_reset(&state, seed); + XXH32_update(&state, (const xxh_u8*)input, len); + return XXH32_digest(&state); + +#else + + if (XXH_FORCE_ALIGN_CHECK) { + if ((((size_t)input) & 3) == 0) { /* Input is 4-bytes aligned, leverage the speed benefit */ + return XXH32_endian_align((const xxh_u8*)input, len, seed, XXH_aligned); + } } + + return XXH32_endian_align((const xxh_u8*)input, len, seed, XXH_unaligned); +#endif +} + + + +/******* Hash streaming *******/ + +XXH_PUBLIC_API XXH32_state_t* XXH32_createState(void) +{ + return (XXH32_state_t*)XXH_malloc(sizeof(XXH32_state_t)); +} +XXH_PUBLIC_API XXH_errorcode XXH32_freeState(XXH32_state_t* statePtr) +{ + XXH_free(statePtr); + return XXH_OK; +} + +XXH_PUBLIC_API void XXH32_copyState(XXH32_state_t* dstState, const XXH32_state_t* srcState) +{ + memcpy(dstState, srcState, sizeof(*dstState)); +} + +XXH_PUBLIC_API XXH_errorcode XXH32_reset(XXH32_state_t* statePtr, XXH32_hash_t seed) +{ + XXH32_state_t state; /* using a local state to memcpy() in order to avoid strict-aliasing warnings */ + memset(&state, 0, sizeof(state)); + state.v1 = seed + XXH_PRIME32_1 + XXH_PRIME32_2; + state.v2 = seed + XXH_PRIME32_2; + state.v3 = seed + 0; + state.v4 = seed - XXH_PRIME32_1; + /* do not write into reserved, planned to be removed in a future version */ + memcpy(statePtr, &state, sizeof(state) - sizeof(state.reserved)); + return XXH_OK; +} + + +XXH_PUBLIC_API XXH_errorcode +XXH32_update(XXH32_state_t* state, const void* input, size_t len) +{ + if (input==NULL) +#if defined(XXH_ACCEPT_NULL_INPUT_POINTER) && (XXH_ACCEPT_NULL_INPUT_POINTER>=1) + return XXH_OK; +#else + return XXH_ERROR; +#endif + + { const xxh_u8* p = (const xxh_u8*)input; + const xxh_u8* const bEnd = p + len; + + state->total_len_32 += (XXH32_hash_t)len; + state->large_len |= (XXH32_hash_t)((len>=16) | (state->total_len_32>=16)); + + if (state->memsize + len < 16) { /* fill in tmp buffer */ + XXH_memcpy((xxh_u8*)(state->mem32) + state->memsize, input, len); + state->memsize += (XXH32_hash_t)len; + return XXH_OK; + } + + if (state->memsize) { /* some data left from previous update */ + XXH_memcpy((xxh_u8*)(state->mem32) + state->memsize, input, 16-state->memsize); + { const xxh_u32* p32 = state->mem32; + state->v1 = XXH32_round(state->v1, XXH_readLE32(p32)); p32++; + state->v2 = XXH32_round(state->v2, XXH_readLE32(p32)); p32++; + state->v3 = XXH32_round(state->v3, XXH_readLE32(p32)); p32++; + state->v4 = XXH32_round(state->v4, XXH_readLE32(p32)); + } + p += 16-state->memsize; + state->memsize = 0; + } + + if (p <= bEnd-16) { + const xxh_u8* const limit = bEnd - 16; + xxh_u32 v1 = state->v1; + xxh_u32 v2 = state->v2; + xxh_u32 v3 = state->v3; + xxh_u32 v4 = state->v4; + + do { + v1 = XXH32_round(v1, XXH_readLE32(p)); p+=4; + v2 = XXH32_round(v2, XXH_readLE32(p)); p+=4; + v3 = XXH32_round(v3, XXH_readLE32(p)); p+=4; + v4 = XXH32_round(v4, XXH_readLE32(p)); p+=4; + } while (p<=limit); + + state->v1 = v1; + state->v2 = v2; + state->v3 = v3; + state->v4 = v4; + } + + if (p < bEnd) { + XXH_memcpy(state->mem32, p, (size_t)(bEnd-p)); + state->memsize = (unsigned)(bEnd-p); + } + } + + return XXH_OK; +} + + +XXH_PUBLIC_API XXH32_hash_t XXH32_digest (const XXH32_state_t* state) +{ + xxh_u32 h32; + + if (state->large_len) { + h32 = XXH_rotl32(state->v1, 1) + + XXH_rotl32(state->v2, 7) + + XXH_rotl32(state->v3, 12) + + XXH_rotl32(state->v4, 18); + } else { + h32 = state->v3 /* == seed */ + XXH_PRIME32_5; + } + + h32 += state->total_len_32; + + return XXH32_finalize(h32, (const xxh_u8*)state->mem32, state->memsize, XXH_aligned); +} + + +/******* Canonical representation *******/ + +/* + * The default return values from XXH functions are unsigned 32 and 64 bit + * integers. + * + * The canonical representation uses big endian convention, the same convention + * as human-readable numbers (large digits first). + * + * This way, hash values can be written into a file or buffer, remaining + * comparable across different systems. + * + * The following functions allow transformation of hash values to and from their + * canonical format. + */ +XXH_PUBLIC_API void XXH32_canonicalFromHash(XXH32_canonical_t* dst, XXH32_hash_t hash) +{ + XXH_STATIC_ASSERT(sizeof(XXH32_canonical_t) == sizeof(XXH32_hash_t)); + if (XXH_CPU_LITTLE_ENDIAN) hash = XXH_swap32(hash); + memcpy(dst, &hash, sizeof(*dst)); +} + +XXH_PUBLIC_API XXH32_hash_t XXH32_hashFromCanonical(const XXH32_canonical_t* src) +{ + return XXH_readBE32(src); +} + + +#ifndef XXH_NO_LONG_LONG + +/* ******************************************************************* +* 64-bit hash functions +*********************************************************************/ + +/******* Memory access *******/ + +typedef XXH64_hash_t xxh_u64; + +#ifdef XXH_OLD_NAMES +# define U64 xxh_u64 +#endif + +/*! + * XXH_REROLL_XXH64: + * Whether to reroll the XXH64_finalize() loop. + * + * Just like XXH32, we can unroll the XXH64_finalize() loop. This can be a + * performance gain on 64-bit hosts, as only one jump is required. + * + * However, on 32-bit hosts, because arithmetic needs to be done with two 32-bit + * registers, and 64-bit arithmetic needs to be simulated, it isn't beneficial + * to unroll. The code becomes ridiculously large (the largest function in the + * binary on i386!), and rerolling it saves anywhere from 3kB to 20kB. It is + * also slightly faster because it fits into cache better and is more likely + * to be inlined by the compiler. + * + * If XXH_REROLL is defined, this is ignored and the loop is always rerolled. + */ +#ifndef XXH_REROLL_XXH64 +# if (defined(__ILP32__) || defined(_ILP32)) /* ILP32 is often defined on 32-bit GCC family */ \ + || !(defined(__x86_64__) || defined(_M_X64) || defined(_M_AMD64) /* x86-64 */ \ + || defined(_M_ARM64) || defined(__aarch64__) || defined(__arm64__) /* aarch64 */ \ + || defined(__PPC64__) || defined(__PPC64LE__) || defined(__ppc64__) || defined(__powerpc64__) /* ppc64 */ \ + || defined(__mips64__) || defined(__mips64)) /* mips64 */ \ + || (!defined(SIZE_MAX) || SIZE_MAX < ULLONG_MAX) /* check limits */ +# define XXH_REROLL_XXH64 1 +# else +# define XXH_REROLL_XXH64 0 +# endif +#endif /* !defined(XXH_REROLL_XXH64) */ + +#if (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==3)) +/* + * Manual byteshift. Best for old compilers which don't inline memcpy. + * We actually directly use XXH_readLE64 and XXH_readBE64. + */ +#elif (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==2)) + +/* Force direct memory access. Only works on CPU which support unaligned memory access in hardware */ +static xxh_u64 XXH_read64(const void* memPtr) { return *(const xxh_u64*) memPtr; } + +#elif (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==1)) + +/* + * __pack instructions are safer, but compiler specific, hence potentially + * problematic for some compilers. + * + * Currently only defined for GCC and ICC. + */ +#ifdef XXH_OLD_NAMES +typedef union { xxh_u32 u32; xxh_u64 u64; } __attribute__((packed)) unalign64; +#endif +static xxh_u64 XXH_read64(const void* ptr) +{ + typedef union { xxh_u32 u32; xxh_u64 u64; } __attribute__((packed)) xxh_unalign64; + return ((const xxh_unalign64*)ptr)->u64; +} + +#else + +/* + * Portable and safe solution. Generally efficient. + * see: https://stackoverflow.com/a/32095106/646947 + */ +static xxh_u64 XXH_read64(const void* memPtr) +{ + xxh_u64 val; + memcpy(&val, memPtr, sizeof(val)); + return val; +} + +#endif /* XXH_FORCE_DIRECT_MEMORY_ACCESS */ + +#if defined(_MSC_VER) /* Visual Studio */ +# define XXH_swap64 _byteswap_uint64 +#elif XXH_GCC_VERSION >= 403 +# define XXH_swap64 __builtin_bswap64 +#else +static xxh_u64 XXH_swap64 (xxh_u64 x) +{ + return ((x << 56) & 0xff00000000000000ULL) | + ((x << 40) & 0x00ff000000000000ULL) | + ((x << 24) & 0x0000ff0000000000ULL) | + ((x << 8) & 0x000000ff00000000ULL) | + ((x >> 8) & 0x00000000ff000000ULL) | + ((x >> 24) & 0x0000000000ff0000ULL) | + ((x >> 40) & 0x000000000000ff00ULL) | + ((x >> 56) & 0x00000000000000ffULL); +} +#endif + + +/* XXH_FORCE_MEMORY_ACCESS==3 is an endian-independent byteshift load. */ +#if (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==3)) + +XXH_FORCE_INLINE xxh_u64 XXH_readLE64(const void* memPtr) +{ + const xxh_u8* bytePtr = (const xxh_u8 *)memPtr; + return bytePtr[0] + | ((xxh_u64)bytePtr[1] << 8) + | ((xxh_u64)bytePtr[2] << 16) + | ((xxh_u64)bytePtr[3] << 24) + | ((xxh_u64)bytePtr[4] << 32) + | ((xxh_u64)bytePtr[5] << 40) + | ((xxh_u64)bytePtr[6] << 48) + | ((xxh_u64)bytePtr[7] << 56); +} + +XXH_FORCE_INLINE xxh_u64 XXH_readBE64(const void* memPtr) +{ + const xxh_u8* bytePtr = (const xxh_u8 *)memPtr; + return bytePtr[7] + | ((xxh_u64)bytePtr[6] << 8) + | ((xxh_u64)bytePtr[5] << 16) + | ((xxh_u64)bytePtr[4] << 24) + | ((xxh_u64)bytePtr[3] << 32) + | ((xxh_u64)bytePtr[2] << 40) + | ((xxh_u64)bytePtr[1] << 48) + | ((xxh_u64)bytePtr[0] << 56); +} + +#else +XXH_FORCE_INLINE xxh_u64 XXH_readLE64(const void* ptr) +{ + return XXH_CPU_LITTLE_ENDIAN ? XXH_read64(ptr) : XXH_swap64(XXH_read64(ptr)); +} + +static xxh_u64 XXH_readBE64(const void* ptr) +{ + return XXH_CPU_LITTLE_ENDIAN ? XXH_swap64(XXH_read64(ptr)) : XXH_read64(ptr); +} +#endif + +XXH_FORCE_INLINE xxh_u64 +XXH_readLE64_align(const void* ptr, XXH_alignment align) +{ + if (align==XXH_unaligned) + return XXH_readLE64(ptr); + else + return XXH_CPU_LITTLE_ENDIAN ? *(const xxh_u64*)ptr : XXH_swap64(*(const xxh_u64*)ptr); +} + + +/******* xxh64 *******/ + +static const xxh_u64 XXH_PRIME64_1 = 0x9E3779B185EBCA87ULL; /* 0b1001111000110111011110011011000110000101111010111100101010000111 */ +static const xxh_u64 XXH_PRIME64_2 = 0xC2B2AE3D27D4EB4FULL; /* 0b1100001010110010101011100011110100100111110101001110101101001111 */ +static const xxh_u64 XXH_PRIME64_3 = 0x165667B19E3779F9ULL; /* 0b0001011001010110011001111011000110011110001101110111100111111001 */ +static const xxh_u64 XXH_PRIME64_4 = 0x85EBCA77C2B2AE63ULL; /* 0b1000010111101011110010100111011111000010101100101010111001100011 */ +static const xxh_u64 XXH_PRIME64_5 = 0x27D4EB2F165667C5ULL; /* 0b0010011111010100111010110010111100010110010101100110011111000101 */ + +#ifdef XXH_OLD_NAMES +# define PRIME64_1 XXH_PRIME64_1 +# define PRIME64_2 XXH_PRIME64_2 +# define PRIME64_3 XXH_PRIME64_3 +# define PRIME64_4 XXH_PRIME64_4 +# define PRIME64_5 XXH_PRIME64_5 +#endif + +static xxh_u64 XXH64_round(xxh_u64 acc, xxh_u64 input) +{ + acc += input * XXH_PRIME64_2; + acc = XXH_rotl64(acc, 31); + acc *= XXH_PRIME64_1; + return acc; +} + +static xxh_u64 XXH64_mergeRound(xxh_u64 acc, xxh_u64 val) +{ + val = XXH64_round(0, val); + acc ^= val; + acc = acc * XXH_PRIME64_1 + XXH_PRIME64_4; + return acc; +} + +static xxh_u64 XXH64_avalanche(xxh_u64 h64) +{ + h64 ^= h64 >> 33; + h64 *= XXH_PRIME64_2; + h64 ^= h64 >> 29; + h64 *= XXH_PRIME64_3; + h64 ^= h64 >> 32; + return h64; +} + + +#define XXH_get64bits(p) XXH_readLE64_align(p, align) + +static xxh_u64 +XXH64_finalize(xxh_u64 h64, const xxh_u8* ptr, size_t len, XXH_alignment align) +{ +#define XXH_PROCESS1_64 do { \ + h64 ^= (*ptr++) * XXH_PRIME64_5; \ + h64 = XXH_rotl64(h64, 11) * XXH_PRIME64_1; \ +} while (0) + +#define XXH_PROCESS4_64 do { \ + h64 ^= (xxh_u64)(XXH_get32bits(ptr)) * XXH_PRIME64_1; \ + ptr += 4; \ + h64 = XXH_rotl64(h64, 23) * XXH_PRIME64_2 + XXH_PRIME64_3; \ +} while (0) + +#define XXH_PROCESS8_64 do { \ + xxh_u64 const k1 = XXH64_round(0, XXH_get64bits(ptr)); \ + ptr += 8; \ + h64 ^= k1; \ + h64 = XXH_rotl64(h64,27) * XXH_PRIME64_1 + XXH_PRIME64_4; \ +} while (0) + + /* Rerolled version for 32-bit targets is faster and much smaller. */ + if (XXH_REROLL || XXH_REROLL_XXH64) { + len &= 31; + while (len >= 8) { + XXH_PROCESS8_64; + len -= 8; + } + if (len >= 4) { + XXH_PROCESS4_64; + len -= 4; + } + while (len > 0) { + XXH_PROCESS1_64; + --len; + } + return XXH64_avalanche(h64); + } else { + switch(len & 31) { + case 24: XXH_PROCESS8_64; + /* fallthrough */ + case 16: XXH_PROCESS8_64; + /* fallthrough */ + case 8: XXH_PROCESS8_64; + return XXH64_avalanche(h64); + + case 28: XXH_PROCESS8_64; + /* fallthrough */ + case 20: XXH_PROCESS8_64; + /* fallthrough */ + case 12: XXH_PROCESS8_64; + /* fallthrough */ + case 4: XXH_PROCESS4_64; + return XXH64_avalanche(h64); + + case 25: XXH_PROCESS8_64; + /* fallthrough */ + case 17: XXH_PROCESS8_64; + /* fallthrough */ + case 9: XXH_PROCESS8_64; + XXH_PROCESS1_64; + return XXH64_avalanche(h64); + + case 29: XXH_PROCESS8_64; + /* fallthrough */ + case 21: XXH_PROCESS8_64; + /* fallthrough */ + case 13: XXH_PROCESS8_64; + /* fallthrough */ + case 5: XXH_PROCESS4_64; + XXH_PROCESS1_64; + return XXH64_avalanche(h64); + + case 26: XXH_PROCESS8_64; + /* fallthrough */ + case 18: XXH_PROCESS8_64; + /* fallthrough */ + case 10: XXH_PROCESS8_64; + XXH_PROCESS1_64; + XXH_PROCESS1_64; + return XXH64_avalanche(h64); + + case 30: XXH_PROCESS8_64; + /* fallthrough */ + case 22: XXH_PROCESS8_64; + /* fallthrough */ + case 14: XXH_PROCESS8_64; + /* fallthrough */ + case 6: XXH_PROCESS4_64; + XXH_PROCESS1_64; + XXH_PROCESS1_64; + return XXH64_avalanche(h64); + + case 27: XXH_PROCESS8_64; + /* fallthrough */ + case 19: XXH_PROCESS8_64; + /* fallthrough */ + case 11: XXH_PROCESS8_64; + XXH_PROCESS1_64; + XXH_PROCESS1_64; + XXH_PROCESS1_64; + return XXH64_avalanche(h64); + + case 31: XXH_PROCESS8_64; + /* fallthrough */ + case 23: XXH_PROCESS8_64; + /* fallthrough */ + case 15: XXH_PROCESS8_64; + /* fallthrough */ + case 7: XXH_PROCESS4_64; + /* fallthrough */ + case 3: XXH_PROCESS1_64; + /* fallthrough */ + case 2: XXH_PROCESS1_64; + /* fallthrough */ + case 1: XXH_PROCESS1_64; + /* fallthrough */ + case 0: return XXH64_avalanche(h64); + } + } + /* impossible to reach */ + XXH_ASSERT(0); + return 0; /* unreachable, but some compilers complain without it */ +} + +#ifdef XXH_OLD_NAMES +# define PROCESS1_64 XXH_PROCESS1_64 +# define PROCESS4_64 XXH_PROCESS4_64 +# define PROCESS8_64 XXH_PROCESS8_64 +#else +# undef XXH_PROCESS1_64 +# undef XXH_PROCESS4_64 +# undef XXH_PROCESS8_64 +#endif + +XXH_FORCE_INLINE xxh_u64 +XXH64_endian_align(const xxh_u8* input, size_t len, xxh_u64 seed, XXH_alignment align) +{ + const xxh_u8* bEnd = input + len; + xxh_u64 h64; + +#if defined(XXH_ACCEPT_NULL_INPUT_POINTER) && (XXH_ACCEPT_NULL_INPUT_POINTER>=1) + if (input==NULL) { + len=0; + bEnd=input=(const xxh_u8*)(size_t)32; + } +#endif + + if (len>=32) { + const xxh_u8* const limit = bEnd - 32; + xxh_u64 v1 = seed + XXH_PRIME64_1 + XXH_PRIME64_2; + xxh_u64 v2 = seed + XXH_PRIME64_2; + xxh_u64 v3 = seed + 0; + xxh_u64 v4 = seed - XXH_PRIME64_1; + + do { + v1 = XXH64_round(v1, XXH_get64bits(input)); input+=8; + v2 = XXH64_round(v2, XXH_get64bits(input)); input+=8; + v3 = XXH64_round(v3, XXH_get64bits(input)); input+=8; + v4 = XXH64_round(v4, XXH_get64bits(input)); input+=8; + } while (input<=limit); + + h64 = XXH_rotl64(v1, 1) + XXH_rotl64(v2, 7) + XXH_rotl64(v3, 12) + XXH_rotl64(v4, 18); + h64 = XXH64_mergeRound(h64, v1); + h64 = XXH64_mergeRound(h64, v2); + h64 = XXH64_mergeRound(h64, v3); + h64 = XXH64_mergeRound(h64, v4); + + } else { + h64 = seed + XXH_PRIME64_5; + } + + h64 += (xxh_u64) len; + + return XXH64_finalize(h64, input, len, align); +} + + +XXH_PUBLIC_API XXH64_hash_t XXH64 (const void* input, size_t len, XXH64_hash_t seed) +{ +#if 0 + /* Simple version, good for code maintenance, but unfortunately slow for small inputs */ + XXH64_state_t state; + XXH64_reset(&state, seed); + XXH64_update(&state, (const xxh_u8*)input, len); + return XXH64_digest(&state); + +#else + + if (XXH_FORCE_ALIGN_CHECK) { + if ((((size_t)input) & 7)==0) { /* Input is aligned, let's leverage the speed advantage */ + return XXH64_endian_align((const xxh_u8*)input, len, seed, XXH_aligned); + } } + + return XXH64_endian_align((const xxh_u8*)input, len, seed, XXH_unaligned); + +#endif +} + +/******* Hash Streaming *******/ + +XXH_PUBLIC_API XXH64_state_t* XXH64_createState(void) +{ + return (XXH64_state_t*)XXH_malloc(sizeof(XXH64_state_t)); +} +XXH_PUBLIC_API XXH_errorcode XXH64_freeState(XXH64_state_t* statePtr) +{ + XXH_free(statePtr); + return XXH_OK; +} + +XXH_PUBLIC_API void XXH64_copyState(XXH64_state_t* dstState, const XXH64_state_t* srcState) +{ + memcpy(dstState, srcState, sizeof(*dstState)); +} + +XXH_PUBLIC_API XXH_errorcode XXH64_reset(XXH64_state_t* statePtr, XXH64_hash_t seed) +{ + XXH64_state_t state; /* use a local state to memcpy() in order to avoid strict-aliasing warnings */ + memset(&state, 0, sizeof(state)); + state.v1 = seed + XXH_PRIME64_1 + XXH_PRIME64_2; + state.v2 = seed + XXH_PRIME64_2; + state.v3 = seed + 0; + state.v4 = seed - XXH_PRIME64_1; + /* do not write into reserved64, might be removed in a future version */ + memcpy(statePtr, &state, sizeof(state) - sizeof(state.reserved64)); + return XXH_OK; +} + +XXH_PUBLIC_API XXH_errorcode +XXH64_update (XXH64_state_t* state, const void* input, size_t len) +{ + if (input==NULL) +#if defined(XXH_ACCEPT_NULL_INPUT_POINTER) && (XXH_ACCEPT_NULL_INPUT_POINTER>=1) + return XXH_OK; +#else + return XXH_ERROR; +#endif + + { const xxh_u8* p = (const xxh_u8*)input; + const xxh_u8* const bEnd = p + len; + + state->total_len += len; + + if (state->memsize + len < 32) { /* fill in tmp buffer */ + XXH_memcpy(((xxh_u8*)state->mem64) + state->memsize, input, len); + state->memsize += (xxh_u32)len; + return XXH_OK; + } + + if (state->memsize) { /* tmp buffer is full */ + XXH_memcpy(((xxh_u8*)state->mem64) + state->memsize, input, 32-state->memsize); + state->v1 = XXH64_round(state->v1, XXH_readLE64(state->mem64+0)); + state->v2 = XXH64_round(state->v2, XXH_readLE64(state->mem64+1)); + state->v3 = XXH64_round(state->v3, XXH_readLE64(state->mem64+2)); + state->v4 = XXH64_round(state->v4, XXH_readLE64(state->mem64+3)); + p += 32-state->memsize; + state->memsize = 0; + } + + if (p+32 <= bEnd) { + const xxh_u8* const limit = bEnd - 32; + xxh_u64 v1 = state->v1; + xxh_u64 v2 = state->v2; + xxh_u64 v3 = state->v3; + xxh_u64 v4 = state->v4; + + do { + v1 = XXH64_round(v1, XXH_readLE64(p)); p+=8; + v2 = XXH64_round(v2, XXH_readLE64(p)); p+=8; + v3 = XXH64_round(v3, XXH_readLE64(p)); p+=8; + v4 = XXH64_round(v4, XXH_readLE64(p)); p+=8; + } while (p<=limit); + + state->v1 = v1; + state->v2 = v2; + state->v3 = v3; + state->v4 = v4; + } + + if (p < bEnd) { + XXH_memcpy(state->mem64, p, (size_t)(bEnd-p)); + state->memsize = (unsigned)(bEnd-p); + } + } + + return XXH_OK; +} + + +XXH_PUBLIC_API XXH64_hash_t XXH64_digest (const XXH64_state_t* state) +{ + xxh_u64 h64; + + if (state->total_len >= 32) { + xxh_u64 const v1 = state->v1; + xxh_u64 const v2 = state->v2; + xxh_u64 const v3 = state->v3; + xxh_u64 const v4 = state->v4; + + h64 = XXH_rotl64(v1, 1) + XXH_rotl64(v2, 7) + XXH_rotl64(v3, 12) + XXH_rotl64(v4, 18); + h64 = XXH64_mergeRound(h64, v1); + h64 = XXH64_mergeRound(h64, v2); + h64 = XXH64_mergeRound(h64, v3); + h64 = XXH64_mergeRound(h64, v4); + } else { + h64 = state->v3 /*seed*/ + XXH_PRIME64_5; + } + + h64 += (xxh_u64) state->total_len; + + return XXH64_finalize(h64, (const xxh_u8*)state->mem64, (size_t)state->total_len, XXH_aligned); +} + + +/******* Canonical representation *******/ + +XXH_PUBLIC_API void XXH64_canonicalFromHash(XXH64_canonical_t* dst, XXH64_hash_t hash) +{ + XXH_STATIC_ASSERT(sizeof(XXH64_canonical_t) == sizeof(XXH64_hash_t)); + if (XXH_CPU_LITTLE_ENDIAN) hash = XXH_swap64(hash); + memcpy(dst, &hash, sizeof(*dst)); +} + +XXH_PUBLIC_API XXH64_hash_t XXH64_hashFromCanonical(const XXH64_canonical_t* src) +{ + return XXH_readBE64(src); +} + + + +/* ********************************************************************* +* XXH3 +* New generation hash designed for speed on small keys and vectorization +************************************************************************ */ + +#include "xxh3.h" + + +#endif /* XXH_NO_LONG_LONG */ + + +#endif /* XXH_IMPLEMENTATION */ + + +#if defined (__cplusplus) +} +#endif \ No newline at end of file diff --git a/src/lib.cpp b/src/lib.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6b5b73771cf1d8ff60f004751882247f870e6aa3 --- /dev/null +++ b/src/lib.cpp @@ -0,0 +1,302 @@ +#pragma GCC diagnostic ignored "-Wunused-parameter" + +#include +#include +#include +#include +namespace py = pybind11; + +#include "common/encoder/encoder_all.h" + + +#include "common/midi_parsing/midi_io.h" +#include "./inference/dataset/jagged.h" +#include "./inference/enum/model_type.h" +#include "./inference/enum/encoder_types.h" +#include "./inference/sampling/control.h" +#include "./inference/sampling/callback_base.h" +#include "./inference/version.h" + +#include "./common/midi_parsing/feature_extraction.h" + +#ifndef NO_TORCH +#include "./inference/sampling/sample_internal.h" +#include "./inference/sampling/multi_step_sample.h" +#endif + +#include +#include +#include "../include/dataset_creation/dataset_manipulation/bytes_to_file.h" +#include "../libraries/protobuf/include/proto_library.h" +#include "../libraries/torch/include/torch_library.h" +#include "../libraries/protobuf/build/midi.pb.h" +#include "MidiFile.h" +#include "./common/data_structures/train_config.h" +#include "./lib_encoder.h" + +// ====================== + +namespace midigpt { // you can probably remove this namespace +std::string generate_py(std::string &status_str, std::string &piece_str, std::string ¶m_str) { + midi::Piece piece; + google::protobuf::util::JsonStringToMessage(piece_str.c_str(), &piece); + midi::Status status; + google::protobuf::util::JsonStringToMessage(status_str.c_str(), &status); + midi::HyperParam param; + google::protobuf::util::JsonStringToMessage(param_str.c_str(), ¶m); + #ifndef NO_TORCH + sampling::sample(&piece, &status, ¶m, NULL); + #endif + + std::string output_str; + google::protobuf::util::MessageToJsonString(piece, &output_str); + return output_str; +} +} + +// MAYBE THESE SHOULD GO IN A SEPARATE FILE FOR PYTHON WRAPPERS +midi::Piece string_to_piece(std::string json_string) { + midi::Piece x; + google::protobuf::util::JsonStringToMessage(json_string.c_str(), &x); + return x; +} + +std::string piece_to_string(midi::Piece x) { + std::string json_string; + google::protobuf::util::MessageToJsonString(x, &json_string); + return json_string; +} + +std::string select_random_segment_py(std::string json_string, int num_bars, int min_tracks, int max_tracks, int seed) { + std::mt19937 engine(seed); + midi::Piece x; + util_protobuf::string_to_protobuf(json_string, &x); + util_protobuf::select_random_segment(&x, num_bars, min_tracks, max_tracks, &engine); + return util_protobuf::protobuf_to_string(&x); +} +// MAYBE THESE SHOULD GO IN A SEPARATE FILE FOR PYTHON WRAPPERS + + +py::bytes midi_to_json_bytes(std::string &filepath, data_structures::TrainConfig *tc, std::string &metadata_labels) { + std::string x; + midi::Piece p; + auto config = std::make_shared(); + config->resolution = tc->resolution; + config->decode_resolution = tc->decode_resolution; + config->delta_resolution = tc->delta_resolution; + config->use_microtiming = tc->use_microtiming; + midi_io::ParseSong(filepath, &p, config); + util_protobuf::UpdateValidSegments(&p, tc->num_bars, tc->min_tracks); + if (!p.internal_valid_segments_size()) { + return py::bytes(x); // empty bytes + } + + // insert metadata labels here + midi::MetadataLabels *ml = new midi::MetadataLabels(); + google::protobuf::util::JsonStringToMessage(metadata_labels, ml); + p.set_allocated_internal_metadata_labels(ml); + + p.SerializeToString(&x); + return py::bytes(x); +} + +std::string json_bytes_to_string(py::bytes &json_bytes) { + midi::Piece p; + p.ParseFromString(json_bytes); + return util_protobuf::protobuf_to_string(&p); +} + + +PYBIND11_MODULE(midigpt,handle) { + + handle.def("select_random_segment", &select_random_segment_py); + handle.def("status_from_piece", &util_protobuf::status_from_piece_py); + handle.def("default_sample_param", &util_protobuf::default_sample_param_py); + handle.def("prune_tracks", &util_protobuf::prune_tracks_py); + + handle.def("version", &version); + handle.def("getEncoderSize", &enums::getEncoderSize); + handle.def("getEncoderType", &enums::getEncoderType); + handle.def("getEncoder", &enums::getEncoder); + handle.def("getEncoderTypeList", &enums::getEncoderTypeList); + handle.def("getAttributeControlStr", &encoder::getAttributeControlStr); + +#ifndef NO_TORCH + handle.def("sample_multi_step", &sampling::sample_multi_step_py); + handle.def("sample_multi_step_capture_output", [](std::string piece_json, std::string status_json, std::string param_json, int max_attempts, sampling::CallbackManager *callbacks) { + py::scoped_ostream_redirect stream( + std::cout, + py::module_::import("sys").attr("stdout") // Python output + ); + return sampling::sample_multi_step_py(piece_json, status_json, param_json, max_attempts, callbacks); + }); + handle.def("get_notes", &sampling::get_notes_py); +#endif + + handle.def("compute_all_attribute_controls", &encoder::compute_all_attribute_controls_py); + handle.def("get_instruments_by_category", &enums::get_instruments_by_category); + handle.def("get_instrument_and_track_type_from_gm_inst", &enums::get_instrument_and_track_type_from_gm_inst); + handle.def("midi_to_json_bytes", &midi_to_json_bytes); + handle.def("json_bytes_to_string", &json_bytes_to_string); + + py::enum_(handle, "MODEL_TYPE", py::arithmetic()) + .value("TRACK_MODEL", enums::MODEL_TYPE::TRACK_MODEL) + .value("BAR_INFILL_MODEL", enums::MODEL_TYPE::BAR_INFILL_MODEL) + .export_values(); + + py::class_(handle, "Jagged") + .def(py::init()) + .def("set_seed", &compression::Jagged::set_seed) + .def("set_num_bars", &compression::Jagged::set_num_bars) + .def("set_min_tracks", &compression::Jagged::set_min_tracks) + .def("set_max_tracks", &compression::Jagged::set_max_tracks) + .def("set_max_seq_len", &compression::Jagged::set_max_seq_len) + .def("enable_write", &compression::Jagged::enable_write) + .def("enable_read", &compression::Jagged::enable_read) + .def("append", &compression::Jagged::append) + .def("read", &compression::Jagged::read) + .def("read_bytes", &compression::Jagged::read_bytes) + .def("read_json", &compression::Jagged::read_json) + .def("read_batch", &compression::Jagged::read_batch) + .def("load_random_piece", &compression::Jagged::load_random_piece_py) + .def("load_piece", &compression::Jagged::load_piece) + .def("close", &compression::Jagged::close) + .def("get_size", &compression::Jagged::get_size) + .def("get_split_size", &compression::Jagged::get_split_size); + + py::class_(handle, "TrainConfig") + .def(py::init<>()) + .def_readwrite("num_bars", &data_structures::TrainConfig::num_bars) + .def_readwrite("min_tracks", &data_structures::TrainConfig::min_tracks) + .def_readwrite("max_tracks", &data_structures::TrainConfig::max_tracks) + .def_readwrite("max_mask_percentage", &data_structures::TrainConfig::max_mask_percentage) + .def_readwrite("no_max_length", &data_structures::TrainConfig::no_max_length) + .def_readwrite("resolution", &data_structures::TrainConfig::resolution) + .def_readwrite("use_microtiming", &data_structures::TrainConfig::use_microtiming) + .def_readwrite("microtiming", &data_structures::TrainConfig::microtiming) + .def_readwrite("decode_resolution", &data_structures::TrainConfig::decode_resolution) + .def_readwrite("delta_resolution", &data_structures::TrainConfig::delta_resolution) + .def("to_json", &data_structures::TrainConfig::ToJson) + .def("from_json", &data_structures::TrainConfig::FromJson); + + py::class_>(handle, "REPRESENTATION") + .def(py::init>>()) + .def("decode", &encoder::REPRESENTATION::decode) + .def("is_token_type", &encoder::REPRESENTATION::is_token_type) + .def("in_domain", &encoder::REPRESENTATION::in_domain) + .def("encode", &encoder::REPRESENTATION::encode) + .def("encode_partial", &encoder::REPRESENTATION::encode_partial_py_int) + .def("encode_to_one_hot", &encoder::REPRESENTATION::encode_to_one_hot) + .def("pretty", &encoder::REPRESENTATION::pretty) + .def_readonly("vocab_size", &encoder::REPRESENTATION::vocab_size) + .def("get_type_mask", &encoder::REPRESENTATION::get_type_mask) + .def("max_token", &encoder::REPRESENTATION::max_token) + .def_readonly("token_domains", &encoder::REPRESENTATION::token_domains); + + py::class_(handle, "TOKEN_DOMAIN") + .def(py::init()); + +py::class_>(handle, "EncoderConfig") + .def(py::init<>()) + .def("ToJson", &data_structures::EncoderConfig::ToJson) + .def("FromJson", &data_structures::EncoderConfig::FromJson) + .def_readwrite("both_in_one", &data_structures::EncoderConfig::both_in_one) + .def_readwrite("unquantized", &data_structures::EncoderConfig::unquantized) + .def_readwrite("do_multi_fill", &data_structures::EncoderConfig::do_multi_fill) + + .def_readwrite("use_velocity_levels", &data_structures::EncoderConfig::use_velocity_levels) + .def_readwrite("use_microtiming", &data_structures::EncoderConfig::use_microtiming) + .def_readwrite("transpose", &data_structures::EncoderConfig::transpose) + .def_readwrite("resolution", &data_structures::EncoderConfig::resolution) + .def_readwrite("decode_resolution", &data_structures::EncoderConfig::decode_resolution) + .def_readwrite("decode_final", &data_structures::EncoderConfig::decode_final) + .def_readwrite("delta_resolution", &data_structures::EncoderConfig::delta_resolution) + .def_readwrite("multi_fill", &data_structures::EncoderConfig::multi_fill); + +py::enum_(handle, "TOKEN_TYPE", py::arithmetic()) + .value("PIECE_START", midi::TOKEN_PIECE_START) + .value("NOTE_ONSET", midi::TOKEN_NOTE_ONSET) + .value("PITCH", midi::TOKEN_PITCH) + .value("VELOCITY", midi::TOKEN_VELOCITY) + .value("DELTA", midi::TOKEN_DELTA) + .value("DELTA_DIRECTION", midi::TOKEN_DELTA_DIRECTION) + .value("TIME_ABSOLUTE_POS", midi::TOKEN_TIME_ABSOLUTE_POS) + .value("INSTRUMENT", midi::TOKEN_INSTRUMENT) + .value("BAR", midi::TOKEN_BAR) + .value("BAR_END", midi::TOKEN_BAR_END) + .value("TRACK", midi::TOKEN_TRACK) + .value("TRACK_END", midi::TOKEN_TRACK_END) + .value("DRUM_TRACK", midi::TOKEN_DRUM_TRACK) + .value("FILL_IN", midi::TOKEN_FILL_IN) + .value("FILL_IN_PLACEHOLDER", midi::TOKEN_FILL_IN_PLACEHOLDER) + .value("FILL_IN_START", midi::TOKEN_FILL_IN_START) + .value("FILL_IN_END", midi::TOKEN_FILL_IN_END) + .value("VELOCITY_LEVEL", midi::TOKEN_VELOCITY_LEVEL) + .value("GENRE", midi::TOKEN_GENRE) + .value("DENSITY_LEVEL", midi::TOKEN_DENSITY_LEVEL) + .value("TIME_SIGNATURE", midi::TOKEN_TIME_SIGNATURE) + .value("NOTE_DURATION", midi::TOKEN_NOTE_DURATION) + .value("AV_POLYPHONY", midi::TOKEN_AV_POLYPHONY) + .value("MIN_POLYPHONY", midi::TOKEN_MIN_POLYPHONY) + .value("MAX_POLYPHONY", midi::TOKEN_MAX_POLYPHONY) + .value("MIN_NOTE_DURATION", midi::TOKEN_MIN_NOTE_DURATION) + .value("MAX_NOTE_DURATION", midi::TOKEN_MAX_NOTE_DURATION) + .value("NUM_BARS", midi::TOKEN_NUM_BARS) + .value("MIN_POLYPHONY_HARD", midi::TOKEN_MIN_POLYPHONY_HARD) + .value("MAX_POLYPHONY_HARD", midi::TOKEN_MAX_POLYPHONY_HARD) + .value("MIN_NOTE_DURATION_HARD", midi::TOKEN_MIN_NOTE_DURATION_HARD) + .value("MAX_NOTE_DURATION_HARD", midi::TOKEN_MAX_NOTE_DURATION_HARD) + .value("NONE", midi::TOKEN_NONE) + .export_values(); + + + +// ========================================================= +// ========================================================= +// ENCODERS +// ========================================================= +// ========================================================= +init_encoders(handle); + +// +// ========================================================= +// ========================================================= +// DATASET CREATION +// ========================================================= +// ========================================================= + +//dataset_manipulation folder definitions +py::class_(handle, "BytesToFile") +.def(py::init()) +.def("append_bytes_to_file_stream", &dataset_manipulation::BytesToFile::appendBytesToFileStream) +.def("write_file", &dataset_manipulation::BytesToFile::writeFile) +.def("close", &dataset_manipulation::BytesToFile::close); + +// callback wrappers +py::class_>(handle, "CallbackBase") + .def(py::init<>()) + .def("on_bar_end", &sampling::CallbackBase::on_bar_end) + .def("on_start", &sampling::CallbackBase::on_bar_end) + .def("on_prediction", &sampling::CallbackBase::on_prediction); + +py::class_>(handle, "LogLikelihoodCallback") + .def(py::init<>()) + .def_readwrite("loglik", &sampling::LogLikelihoodCallback::loglik) + .def_readwrite("sequence_length", &sampling::LogLikelihoodCallback::sequence_length); + +py::class_>(handle, "RecordTokenSequenceCallback") + .def(py::init<>()) + .def_readwrite("tokens", &sampling::RecordTokenSequenceCallback::tokens); + +py::class_>(handle, "TemperatureIncreaseCallback") + .def(py::init()) + .def_readwrite("current_temperature", &sampling::TemperatureIncreaseCallback::current_temperature); + +py::class_(handle, "CallbackManager") + .def(py::init<>()) + .def("add_callback", &sampling::CallbackManager::add_callback_ptr) + .def("on_bar_end", &sampling::CallbackManager::on_bar_end) + .def("on_prediction", &sampling::CallbackManager::on_prediction) + .def("on_start", &sampling::CallbackManager::on_start); + +} diff --git a/src/lib_encoder.h b/src/lib_encoder.h new file mode 100644 index 0000000000000000000000000000000000000000..b64684319c68ead4c27068e44a89c6484ed76862 --- /dev/null +++ b/src/lib_encoder.h @@ -0,0 +1,30 @@ +#include +namespace py = pybind11; + +void init_encoders(py::module &handle) { + + py::enum_(handle, "ENCODER_TYPE", py::arithmetic()) + .value("EXPRESSIVE_ENCODER", enums::ENCODER_TYPE::EXPRESSIVE_ENCODER) + .value("NO_ENCODER", enums::ENCODER_TYPE::NO_ENCODER) + .export_values(); + + py::class_(handle, "ExpressiveEncoder") + .def(py::init<>()) + .def("encode", &encoder::ExpressiveEncoder::encode) + .def("decode", &encoder::ExpressiveEncoder::decode) + .def("midi_to_json", &encoder::ExpressiveEncoder::midi_to_json) + .def("midi_to_tokens", &encoder::ExpressiveEncoder::midi_to_tokens) + .def("json_to_midi", &encoder::ExpressiveEncoder::json_to_midi) + .def("json_track_to_midi", &encoder::ExpressiveEncoder::json_track_to_midi) + .def("json_to_tokens", &encoder::ExpressiveEncoder::json_to_tokens) + .def("tokens_to_json", &encoder::ExpressiveEncoder::tokens_to_json) + .def("resample_delta_json", &encoder::ExpressiveEncoder::resample_delta_json) + .def("tokens_to_midi", &encoder::ExpressiveEncoder::tokens_to_midi) + .def("pretty", &encoder::ExpressiveEncoder::pretty) + .def("vocab_size", &encoder::ExpressiveEncoder::vocab_size) + .def("get_attribute_control_types", &encoder::ExpressiveEncoder::get_attribute_control_types) + .def("set_scheme", &encoder::ExpressiveEncoder::set_scheme) + .def_readonly("config", &encoder::ExpressiveEncoder::config) + .def_readonly("rep", &encoder::ExpressiveEncoder::rep); + +} diff --git a/src/trace.cpp b/src/trace.cpp new file mode 100644 index 0000000000000000000000000000000000000000..de5989cc78b09c70f319d339c986b97f282e513c --- /dev/null +++ b/src/trace.cpp @@ -0,0 +1,55 @@ +#define _GNU_SOURCE +#include + +#include +#include +#include + +static FILE *fp_trace; +int depth = 0; +static int MAX_DEPTH = 10; + +extern "C" { + +void __attribute__ ((constructor)) trace_begin (void) { + fp_trace = fopen("trace.out", "w"); + depth = 0; +} + +void __attribute__ ((destructor)) trace_end (void) { + if(fp_trace != NULL) { + fclose(fp_trace); + } +} + +void __cyg_profile_func_enter (void *func, void *caller) { + if (fp_trace != NULL) { + if (depth < MAX_DEPTH) { + Dl_info info; + if (dladdr(func, &info)) { + if (info.dli_sname) { + if (!strstr(info.dli_sname, "St3")) { + fprintf (fp_trace, "%i %p %p [%s] %s\n", + depth, + func, + caller, + info.dli_fname ? info.dli_fname : "?", + info.dli_sname ? info.dli_sname : "?"); + fflush(fp_trace); + } + } + } + } + depth++; + } +} + +void __cyg_profile_func_exit (void *func, void *caller) { + if(fp_trace != NULL) { + //fprintf(fp_trace, "x %p %p %lu\n", func, caller, time(NULL)); + depth--; + } + +} + +}