|
|
|
|
|
"""
|
|
Script to convert a Deeplab Keras model to an ONNX model
|
|
"""
|
|
import os, sys, argparse
|
|
import shutil
|
|
import subprocess
|
|
import tensorflow as tf
|
|
from tensorflow.keras.models import load_model
|
|
import keras2onnx
|
|
import onnx
|
|
|
|
# Add the directory two levels above this script to the import path so the
# project-local `common` package imported on the next line can be resolved.
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', '..'))
from common.utils import get_custom_objects

# Set the TF_KERAS environment variable for the rest of the process.
# NOTE(review): presumably consumed by keras2onnx to select the tf.keras
# backend; keras2onnx is already imported above, so if it reads TF_KERAS at
# import time this assignment comes too late to take effect -- confirm.
os.environ['TF_KERAS'] = '1'
|
|
|
|
def onnx_convert(keras_model_file, output_file, op_set):
    """Convert a Keras model file to ONNX with keras2onnx and save it.

    Args:
        keras_model_file: path of the Keras model file to load.
        output_file: path the resulting ONNX model is saved to.
        op_set: target ONNX opset handed to ``keras2onnx.convert_keras``.
    """
    # The same custom-object mapping is used both to deserialize the model
    # and as the converter's custom op conversions.
    custom_objects = get_custom_objects()
    keras_model = load_model(keras_model_file, custom_objects=custom_objects)

    converted = keras2onnx.convert_keras(
        keras_model,
        keras_model.name,
        custom_op_conversions=custom_objects,
        target_opset=op_set,
    )

    onnx.save_model(converted, output_file)
|
|
|
|
|
|
def onnx_convert_with_savedmodel(keras_model_file, output_file, op_set):
    """Convert a Keras model file to ONNX through a temporary TF SavedModel.

    The model is loaded with the project's custom objects, exported as a
    SavedModel into a scratch directory, and then converted by running the
    ``tf2onnx.convert`` module in a child process. The scratch directory is
    removed afterwards, whether or not the conversion succeeded.

    Args:
        keras_model_file: path of the Keras model file to load.
        output_file: path the ONNX model is written to.
        op_set: ONNX opset number passed to tf2onnx via ``--opset``.

    Raises:
        ValueError: if the installed TensorFlow version is not 2.x.
        subprocess.CalledProcessError: if the tf2onnx process exits with a
            non-zero status (its captured stdout is attached to the error).
    """
    import tempfile  # local import: only this conversion path needs it

    if not tf.__version__.startswith('2'):
        raise ValueError('savedmodel convert only support in TF 2.x env')

    custom_object_dict = get_custom_objects()
    model = load_model(keras_model_file, custom_objects=custom_object_dict)

    # A unique scratch directory avoids clobbering an existing
    # "tmp_savedmodel" in the working directory and guarantees cleanup even
    # when saving or converting raises.
    with tempfile.TemporaryDirectory(prefix='tmp_savedmodel_') as savedmodel_dir:
        model.save(savedmodel_dir, save_format='tf')

        # Build the argument vector directly instead of splitting a formatted
        # string, so paths containing whitespace stay intact. Use the running
        # interpreter so tf2onnx is resolved from the same environment.
        cmd = [
            sys.executable or 'python',
            '-m', 'tf2onnx.convert',
            '--saved-model', savedmodel_dir,
            '--output', str(output_file),
            '--opset', str(op_set),
        ]

        # stdout stays captured (as before); check=True surfaces a failed
        # conversion instead of silently ignoring the exit status.
        subprocess.run(cmd, stdout=subprocess.PIPE, check=True)
|
|
|
|
|
|
def main():
    """Parse command-line options and run the selected ONNX conversion.

    ``--with_savedmodel`` routes the conversion through
    ``onnx_convert_with_savedmodel``; otherwise ``onnx_convert`` is used.
    Both receive the model path, the output path and the opset number.
    """
    parser = argparse.ArgumentParser(
        argument_default=argparse.SUPPRESS,
        description='Convert Deeplab tf.keras model to ONNX model',
    )
    parser.add_argument(
        '--keras_model_file',
        required=True,
        type=str,
        help='path to keras model file',
    )
    parser.add_argument(
        '--output_file',
        required=True,
        type=str,
        help='output onnx model file',
    )
    parser.add_argument(
        '--op_set',
        required=False,
        type=int,
        help='onnx op set, default=%(default)s',
        default=10,
    )
    parser.add_argument(
        '--with_savedmodel',
        default=False,
        action="store_true",
        help='use a temp savedmodel for convert',
    )

    options = parser.parse_args()

    # Pick the converter first, then invoke it with the shared arguments.
    convert = onnx_convert_with_savedmodel if options.with_savedmodel else onnx_convert
    convert(options.keras_model_file, options.output_file, options.op_set)
|
|
|
|
|
|
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
|
|
|
|
|