|
|
|
|
|
""" |
|
Convert a Keras model file into a frozen TensorFlow graph (.pb) file. The resulting
TensorFlow model holds both the model architecture and its associated weights.
|
""" |
|
import os, sys, argparse, logging |
|
from pathlib import Path |
|
import tensorflow as tf |
|
from tensorflow.python.framework import graph_util |
|
from tensorflow.python.framework import graph_io |
|
from tensorflow.keras import backend as K |
|
from tensorflow.keras.models import model_from_json, model_from_yaml, load_model |
|
|
|
|
|
if tf.__version__.startswith('2'): |
|
import tensorflow.compat.v1 as tf |
|
from tensorflow.compat.v1.keras import backend as K |
|
from tensorflow.compat.v1.keras.models import model_from_json, model_from_yaml, load_model |
|
tf.disable_eager_execution() |
|
|
|
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', '..')) |
|
from common.utils import get_custom_objects |
|
|
|
# Select Keras learning phase 0 (test/inference) at import time, before any
# model is loaded, so the graph that gets exported is built in that phase.
K.set_learning_phase(0)

# Root-logger configuration for this script: timestamp, level and message,
# emitting everything from DEBUG upwards.
LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=LOG_FORMAT, level=logging.DEBUG)
|
|
|
|
|
def _load_from_description(weights_path, description_path, fmt, deserializer,
                           custom_objects=None):
    """Rebuild a model from a json/yaml architecture file, then load its weights.

    Args:
        weights_path: path of the weights file passed to ``model.load_weights``.
        description_path: path of the architecture description file.
        fmt: ``'json'`` or ``'yaml'``; only used in messages.
        deserializer: ``model_from_json`` or ``model_from_yaml``.
        custom_objects: optional custom-object mapping forwarded to the
            deserializer so custom layers resolve on this path too.

    Raises:
        FileNotFoundError: if ``description_path`` does not exist.
        Exception: any error while deserializing or loading weights is logged
            and re-raised unchanged.
    """
    if not Path(description_path).exists():
        raise FileNotFoundError(
            'Model description {} file `{}` does not exist.'.format(
                fmt, description_path))
    try:
        # Context manager so the description file handle is always closed.
        with open(str(description_path)) as description_file:
            model = deserializer(description_file.read(),
                                 custom_objects=custom_objects)
        model.load_weights(weights_path)
        return model
    except Exception:
        logging.error("Couldn't load model from %s.", fmt)
        raise


def load_input_model(input_model_path, input_json_path=None, input_yaml_path=None, custom_objects=None):
    """Load a Keras model from disk.

    ``load_model`` is tried first on ``input_model_path``. If it raises
    ``ValueError`` (for example because the file only holds weights), the
    architecture is rebuilt from ``input_json_path`` or, if that is not
    given, ``input_yaml_path``, and the weights are loaded into it.

    Args:
        input_model_path: path to the saved model (or weights-only) file.
        input_json_path: optional path to a JSON architecture description.
        input_yaml_path: optional path to a YAML architecture description;
            only consulted when ``input_json_path`` is not given.
        custom_objects: optional mapping of custom object names, passed to
            ``load_model`` and to the json/yaml deserializers.

    Returns:
        The loaded Keras model.

    Raises:
        FileNotFoundError: if the model file or the chosen description file
            does not exist.
        ValueError: re-raised from ``load_model`` when no description file
            was supplied to fall back on.
        Exception: errors raised while rebuilding from a description file
            are logged and re-raised.
    """
    if not Path(input_model_path).exists():
        raise FileNotFoundError(
            'Model file `{}` does not exist.'.format(input_model_path))
    try:
        return load_model(input_model_path, custom_objects=custom_objects)
    except FileNotFoundError:
        # Can still happen if the file disappears between the check and the load.
        logging.error('Input model file (%s) does not exist.', input_model_path)
        raise
    except ValueError:
        if input_json_path:
            return _load_from_description(input_model_path, input_json_path,
                                          'json', model_from_json,
                                          custom_objects)
        if input_yaml_path:
            return _load_from_description(input_model_path, input_yaml_path,
                                          'yaml', model_from_yaml,
                                          custom_objects)
        logging.error(
            'Input file specified only holds the weights, and not '
            'the model definition. Save the model using '
            'model.save(filename.h5) which will contain the network '
            'architecture as well as its weights. '
            'If the model is saved using the '
            'model.save_weights(filename) function, either '
            'input_model_json or input_model_yaml flags should be set '
            'to import the network architecture prior to loading the '
            'weights. \n'
            'Check the keras documentation for more details '
            '(https://keras.io/getting-started/faq/)')
        raise
|
|
|
|
|
def keras_to_tensorflow(args):
    """Convert the Keras model described by ``args`` into a frozen TensorFlow graph.

    Steps:
      1. Resolve the output path (a bare file name is placed in the current
         working directory) and create its parent directory.
      2. Set the Keras image data format from ``args.channels_first``.
      3. Load the model with ``load_input_model`` and the project's custom
         objects.
      4. If ``args.output_nodes_prefix`` is set, add ``tf.identity`` nodes
         named ``<prefix><i>`` for each model output.
      5. Optionally save a checkpoint (``args.output_meta_ckpt``) and the
         unfrozen graph definition as text (``args.save_graph_def``).
      6. Freeze variables into constants, optionally apply the 8-bit
         ``quantize_weights``/``quantize_nodes`` transforms (``args.quantize``),
         and write the binary graph to the output path.

    Args:
        args: ``argparse.Namespace`` produced by ``main``'s parser.
    """
    output_model = args.output_model
    if str(Path(output_model).parent) == '.':
        output_model = str(Path.cwd() / output_model)

    output_path = Path(output_model)
    output_fld = output_path.parent
    output_model_name = output_path.name
    output_model_stem = output_path.stem
    output_model_pbtxt_name = output_model_stem + '.pbtxt'

    output_fld.mkdir(parents=True, exist_ok=True)

    K.set_image_data_format(
        'channels_first' if args.channels_first else 'channels_last')

    model = load_input_model(args.input_model, args.input_model_json,
                             args.input_model_yaml,
                             custom_objects=get_custom_objects())

    if args.output_nodes_prefix:
        converted_output_node_names = [
            '{}{}'.format(args.output_nodes_prefix, i)
            for i in range(len(model.outputs))]
        # Only the side effect matters here: each call adds a node with the
        # requested name to the session graph, giving stable output names.
        for output_tensor, node_name in zip(model.outputs,
                                            converted_output_node_names):
            tf.identity(output_tensor, name=node_name)
    else:
        # Tensor names look like "<op_name>:<index>"; the freezing and
        # transform tools expect bare op names.
        converted_output_node_names = [
            node.name.split(':')[0] for node in model.outputs]
    logging.info('Converted output node names are: %s',
                 str(converted_output_node_names))

    sess = K.get_session()

    if args.output_meta_ckpt:
        saver = tf.train.Saver()
        saver.save(sess, str(output_fld / output_model_stem))

    if args.save_graph_def:
        tf.train.write_graph(sess.graph.as_graph_def(), str(output_fld),
                             output_model_pbtxt_name, as_text=True)
        logging.info('Saved the graph definition in ascii format at %s',
                     str(output_fld / output_model_pbtxt_name))

    # Freeze first: quantize_weights only rewrites float Const nodes, so it
    # must see the graph after variables have been folded into constants.
    constant_graph = graph_util.convert_variables_to_constants(
        sess,
        sess.graph.as_graph_def(),
        converted_output_node_names)

    if args.quantize:
        from tensorflow.tools.graph_transforms import TransformGraph
        transforms = ["quantize_weights", "quantize_nodes"]
        constant_graph = TransformGraph(constant_graph, [],
                                        converted_output_node_names,
                                        transforms)

    graph_io.write_graph(constant_graph, str(output_fld), output_model_name,
                         as_text=False)
    logging.info('Saved the freezed graph at %s',
                 str(output_fld / output_model_name))
|
|
|
|
|
def main(argv=None):
    """Parse command-line arguments and run ``keras_to_tensorflow``.

    Args:
        argv: optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, so a plain
            ``main()`` call behaves as before.
    """
    parser = argparse.ArgumentParser(
        description='Convert a Keras model file into a frozen TensorFlow '
                    'graph (.pb) file.')
    parser.add_argument('--input_model', required=True, type=str,
                        help='Path to the input model.')
    parser.add_argument('--input_model_json', required=False, type=str,
                        help='Path to the input model architecture in json format.')
    parser.add_argument('--input_model_yaml', required=False, type=str,
                        help='Path to the input model architecture in yaml format.')
    parser.add_argument('--output_model', required=True, type=str,
                        help='Path where the converted model will be stored.')

    parser.add_argument('--save_graph_def', default=False, action='store_true',
                        help='Whether to also save the (unfrozen) graph definition in '
                             'ASCII format as a .pbtxt file with the same name stem as '
                             '--output_model, in the same directory. default=%(default)s')
    parser.add_argument('--output_nodes_prefix', required=False, type=str,
                        help='If set, the output nodes will be renamed to '
                             '`output_nodes_prefix`+i, where `i` is the index of each '
                             'output node of the network.')
    parser.add_argument('--quantize', default=False, action='store_true',
                        help='If set, the resultant TensorFlow graph weights will be '
                             'converted from float into eight-bit equivalents. See '
                             'documentation here: '
                             'https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms. '
                             'default=%(default)s')
    parser.add_argument('--channels_first', default=False, action='store_true',
                        help='Whether channels are the first dimension of a tensor. '
                             'The default is TensorFlow behaviour where channels are '
                             'the last dimension. default=%(default)s')
    parser.add_argument('--output_meta_ckpt', default=False, action='store_true',
                        help='If set, also exports the model as .meta, .index, and .data '
                             'files, with a checkpoint file. These can be later loaded in '
                             'TensorFlow to continue training. default=%(default)s')

    args = parser.parse_args(argv)

    keras_to_tensorflow(args)
|
|
|
|
|
if __name__ == "__main__":
    # Only run the conversion when executed as a script, never on import.
    main()
|
|