#pragma GCC diagnostic ignored "-Wunused-parameter" #include #include #include #include namespace py = pybind11; #include "common/encoder/encoder_all.h" #include "common/midi_parsing/midi_io.h" #include "./inference/dataset/jagged.h" #include "./inference/enum/model_type.h" #include "./inference/enum/encoder_types.h" #include "./inference/sampling/control.h" #include "./inference/sampling/callback_base.h" #include "./inference/version.h" #include "./common/midi_parsing/feature_extraction.h" #ifndef NO_TORCH #include "./inference/sampling/sample_internal.h" #include "./inference/sampling/multi_step_sample.h" #endif #include #include #include "../include/dataset_creation/dataset_manipulation/bytes_to_file.h" #include "../libraries/protobuf/include/proto_library.h" #include "../libraries/torch/include/torch_library.h" #include "../libraries/protobuf/build/midi.pb.h" #include "MidiFile.h" #include "./common/data_structures/train_config.h" #include "./lib_encoder.h" // ====================== namespace midigpt { // you can probably remove this namespace std::string generate_py(std::string &status_str, std::string &piece_str, std::string ¶m_str) { midi::Piece piece; google::protobuf::util::JsonStringToMessage(piece_str.c_str(), &piece); midi::Status status; google::protobuf::util::JsonStringToMessage(status_str.c_str(), &status); midi::HyperParam param; google::protobuf::util::JsonStringToMessage(param_str.c_str(), ¶m); #ifndef NO_TORCH sampling::sample(&piece, &status, ¶m, NULL); #endif std::string output_str; google::protobuf::util::MessageToJsonString(piece, &output_str); return output_str; } } // MAYBE THESE SHOULD GO IN A SEPARATE FILE FOR PYTHON WRAPPERS midi::Piece string_to_piece(std::string json_string) { midi::Piece x; google::protobuf::util::JsonStringToMessage(json_string.c_str(), &x); return x; } std::string piece_to_string(midi::Piece x) { std::string json_string; google::protobuf::util::MessageToJsonString(x, &json_string); return json_string; } std::string select_random_segment_py(std::string json_string, int num_bars, int min_tracks, int max_tracks, int seed) { std::mt19937 engine(seed); midi::Piece x; util_protobuf::string_to_protobuf(json_string, &x); util_protobuf::select_random_segment(&x, num_bars, min_tracks, max_tracks, &engine); return util_protobuf::protobuf_to_string(&x); } // MAYBE THESE SHOULD GO IN A SEPARATE FILE FOR PYTHON WRAPPERS py::bytes midi_to_json_bytes(std::string &filepath, data_structures::TrainConfig *tc, std::string &metadata_labels) { std::string x; midi::Piece p; auto config = std::make_shared(); config->resolution = tc->resolution; config->decode_resolution = tc->decode_resolution; config->delta_resolution = tc->delta_resolution; config->use_microtiming = tc->use_microtiming; midi_io::ParseSong(filepath, &p, config); util_protobuf::UpdateValidSegments(&p, tc->num_bars, tc->min_tracks); if (!p.internal_valid_segments_size()) { return py::bytes(x); // empty bytes } // insert metadata labels here midi::MetadataLabels *ml = new midi::MetadataLabels(); google::protobuf::util::JsonStringToMessage(metadata_labels, ml); p.set_allocated_internal_metadata_labels(ml); p.SerializeToString(&x); return py::bytes(x); } std::string json_bytes_to_string(py::bytes &json_bytes) { midi::Piece p; p.ParseFromString(json_bytes); return util_protobuf::protobuf_to_string(&p); } PYBIND11_MODULE(midigpt,handle) { handle.def("select_random_segment", &select_random_segment_py); handle.def("status_from_piece", &util_protobuf::status_from_piece_py); handle.def("default_sample_param", &util_protobuf::default_sample_param_py); handle.def("prune_tracks", &util_protobuf::prune_tracks_py); handle.def("version", &version); handle.def("getEncoderSize", &enums::getEncoderSize); handle.def("getEncoderType", &enums::getEncoderType); handle.def("getEncoder", &enums::getEncoder); handle.def("getEncoderTypeList", &enums::getEncoderTypeList); handle.def("getAttributeControlStr", &encoder::getAttributeControlStr); #ifndef NO_TORCH handle.def("sample_multi_step", &sampling::sample_multi_step_py); handle.def("sample_multi_step_capture_output", [](std::string piece_json, std::string status_json, std::string param_json, int max_attempts, sampling::CallbackManager *callbacks) { py::scoped_ostream_redirect stream( std::cout, py::module_::import("sys").attr("stdout") // Python output ); return sampling::sample_multi_step_py(piece_json, status_json, param_json, max_attempts, callbacks); }); handle.def("get_notes", &sampling::get_notes_py); #endif handle.def("compute_all_attribute_controls", &encoder::compute_all_attribute_controls_py); handle.def("get_instruments_by_category", &enums::get_instruments_by_category); handle.def("get_instrument_and_track_type_from_gm_inst", &enums::get_instrument_and_track_type_from_gm_inst); handle.def("midi_to_json_bytes", &midi_to_json_bytes); handle.def("json_bytes_to_string", &json_bytes_to_string); py::enum_(handle, "MODEL_TYPE", py::arithmetic()) .value("TRACK_MODEL", enums::MODEL_TYPE::TRACK_MODEL) .value("BAR_INFILL_MODEL", enums::MODEL_TYPE::BAR_INFILL_MODEL) .export_values(); py::class_(handle, "Jagged") .def(py::init()) .def("set_seed", &compression::Jagged::set_seed) .def("set_num_bars", &compression::Jagged::set_num_bars) .def("set_min_tracks", &compression::Jagged::set_min_tracks) .def("set_max_tracks", &compression::Jagged::set_max_tracks) .def("set_max_seq_len", &compression::Jagged::set_max_seq_len) .def("enable_write", &compression::Jagged::enable_write) .def("enable_read", &compression::Jagged::enable_read) .def("append", &compression::Jagged::append) .def("read", &compression::Jagged::read) .def("read_bytes", &compression::Jagged::read_bytes) .def("read_json", &compression::Jagged::read_json) .def("read_batch", &compression::Jagged::read_batch) .def("load_random_piece", &compression::Jagged::load_random_piece_py) .def("load_piece", &compression::Jagged::load_piece) .def("close", &compression::Jagged::close) .def("get_size", &compression::Jagged::get_size) .def("get_split_size", &compression::Jagged::get_split_size); py::class_(handle, "TrainConfig") .def(py::init<>()) .def_readwrite("num_bars", &data_structures::TrainConfig::num_bars) .def_readwrite("min_tracks", &data_structures::TrainConfig::min_tracks) .def_readwrite("max_tracks", &data_structures::TrainConfig::max_tracks) .def_readwrite("max_mask_percentage", &data_structures::TrainConfig::max_mask_percentage) .def_readwrite("no_max_length", &data_structures::TrainConfig::no_max_length) .def_readwrite("resolution", &data_structures::TrainConfig::resolution) .def_readwrite("use_microtiming", &data_structures::TrainConfig::use_microtiming) .def_readwrite("microtiming", &data_structures::TrainConfig::microtiming) .def_readwrite("decode_resolution", &data_structures::TrainConfig::decode_resolution) .def_readwrite("delta_resolution", &data_structures::TrainConfig::delta_resolution) .def("to_json", &data_structures::TrainConfig::ToJson) .def("from_json", &data_structures::TrainConfig::FromJson); py::class_>(handle, "REPRESENTATION") .def(py::init>>()) .def("decode", &encoder::REPRESENTATION::decode) .def("is_token_type", &encoder::REPRESENTATION::is_token_type) .def("in_domain", &encoder::REPRESENTATION::in_domain) .def("encode", &encoder::REPRESENTATION::encode) .def("encode_partial", &encoder::REPRESENTATION::encode_partial_py_int) .def("encode_to_one_hot", &encoder::REPRESENTATION::encode_to_one_hot) .def("pretty", &encoder::REPRESENTATION::pretty) .def_readonly("vocab_size", &encoder::REPRESENTATION::vocab_size) .def("get_type_mask", &encoder::REPRESENTATION::get_type_mask) .def("max_token", &encoder::REPRESENTATION::max_token) .def_readonly("token_domains", &encoder::REPRESENTATION::token_domains); py::class_(handle, "TOKEN_DOMAIN") .def(py::init()); py::class_>(handle, "EncoderConfig") .def(py::init<>()) .def("ToJson", &data_structures::EncoderConfig::ToJson) .def("FromJson", &data_structures::EncoderConfig::FromJson) .def_readwrite("both_in_one", &data_structures::EncoderConfig::both_in_one) .def_readwrite("unquantized", &data_structures::EncoderConfig::unquantized) .def_readwrite("do_multi_fill", &data_structures::EncoderConfig::do_multi_fill) .def_readwrite("use_velocity_levels", &data_structures::EncoderConfig::use_velocity_levels) .def_readwrite("use_microtiming", &data_structures::EncoderConfig::use_microtiming) .def_readwrite("transpose", &data_structures::EncoderConfig::transpose) .def_readwrite("resolution", &data_structures::EncoderConfig::resolution) .def_readwrite("decode_resolution", &data_structures::EncoderConfig::decode_resolution) .def_readwrite("decode_final", &data_structures::EncoderConfig::decode_final) .def_readwrite("delta_resolution", &data_structures::EncoderConfig::delta_resolution) .def_readwrite("multi_fill", &data_structures::EncoderConfig::multi_fill); py::enum_(handle, "TOKEN_TYPE", py::arithmetic()) .value("PIECE_START", midi::TOKEN_PIECE_START) .value("NOTE_ONSET", midi::TOKEN_NOTE_ONSET) .value("PITCH", midi::TOKEN_PITCH) .value("VELOCITY", midi::TOKEN_VELOCITY) .value("DELTA", midi::TOKEN_DELTA) .value("DELTA_DIRECTION", midi::TOKEN_DELTA_DIRECTION) .value("TIME_ABSOLUTE_POS", midi::TOKEN_TIME_ABSOLUTE_POS) .value("INSTRUMENT", midi::TOKEN_INSTRUMENT) .value("BAR", midi::TOKEN_BAR) .value("BAR_END", midi::TOKEN_BAR_END) .value("TRACK", midi::TOKEN_TRACK) .value("TRACK_END", midi::TOKEN_TRACK_END) .value("DRUM_TRACK", midi::TOKEN_DRUM_TRACK) .value("FILL_IN", midi::TOKEN_FILL_IN) .value("FILL_IN_PLACEHOLDER", midi::TOKEN_FILL_IN_PLACEHOLDER) .value("FILL_IN_START", midi::TOKEN_FILL_IN_START) .value("FILL_IN_END", midi::TOKEN_FILL_IN_END) .value("VELOCITY_LEVEL", midi::TOKEN_VELOCITY_LEVEL) .value("GENRE", midi::TOKEN_GENRE) .value("DENSITY_LEVEL", midi::TOKEN_DENSITY_LEVEL) .value("TIME_SIGNATURE", midi::TOKEN_TIME_SIGNATURE) .value("NOTE_DURATION", midi::TOKEN_NOTE_DURATION) .value("AV_POLYPHONY", midi::TOKEN_AV_POLYPHONY) .value("MIN_POLYPHONY", midi::TOKEN_MIN_POLYPHONY) .value("MAX_POLYPHONY", midi::TOKEN_MAX_POLYPHONY) .value("MIN_NOTE_DURATION", midi::TOKEN_MIN_NOTE_DURATION) .value("MAX_NOTE_DURATION", midi::TOKEN_MAX_NOTE_DURATION) .value("NUM_BARS", midi::TOKEN_NUM_BARS) .value("MIN_POLYPHONY_HARD", midi::TOKEN_MIN_POLYPHONY_HARD) .value("MAX_POLYPHONY_HARD", midi::TOKEN_MAX_POLYPHONY_HARD) .value("MIN_NOTE_DURATION_HARD", midi::TOKEN_MIN_NOTE_DURATION_HARD) .value("MAX_NOTE_DURATION_HARD", midi::TOKEN_MAX_NOTE_DURATION_HARD) .value("NONE", midi::TOKEN_NONE) .export_values(); // ========================================================= // ========================================================= // ENCODERS // ========================================================= // ========================================================= init_encoders(handle); // // ========================================================= // ========================================================= // DATASET CREATION // ========================================================= // ========================================================= //dataset_manipulation folder definitions py::class_(handle, "BytesToFile") .def(py::init()) .def("append_bytes_to_file_stream", &dataset_manipulation::BytesToFile::appendBytesToFileStream) .def("write_file", &dataset_manipulation::BytesToFile::writeFile) .def("close", &dataset_manipulation::BytesToFile::close); // callback wrappers py::class_>(handle, "CallbackBase") .def(py::init<>()) .def("on_bar_end", &sampling::CallbackBase::on_bar_end) .def("on_start", &sampling::CallbackBase::on_bar_end) .def("on_prediction", &sampling::CallbackBase::on_prediction); py::class_>(handle, "LogLikelihoodCallback") .def(py::init<>()) .def_readwrite("loglik", &sampling::LogLikelihoodCallback::loglik) .def_readwrite("sequence_length", &sampling::LogLikelihoodCallback::sequence_length); py::class_>(handle, "RecordTokenSequenceCallback") .def(py::init<>()) .def_readwrite("tokens", &sampling::RecordTokenSequenceCallback::tokens); py::class_>(handle, "TemperatureIncreaseCallback") .def(py::init()) .def_readwrite("current_temperature", &sampling::TemperatureIncreaseCallback::current_temperature); py::class_(handle, "CallbackManager") .def(py::init<>()) .def("add_callback", &sampling::CallbackManager::add_callback_ptr) .def("on_bar_end", &sampling::CallbackManager::on_bar_end) .def("on_prediction", &sampling::CallbackManager::on_prediction) .def("on_start", &sampling::CallbackManager::on_start); }