This example demonstrates how to prune a Keras model, using a pre-defined ResNet50.
- Prepare an evaluation script for model analysis, named `ResNet50_model.py`.
from __future__ import absolute_import from __future__ import division from __future__ import print_function import tensorflow as tf import time from preprocessing.dataset import input_fn, NUM_IMAGES TRAIN_NUM = NUM_IMAGES['train'] EVAL_NUM = NUM_IMAGES['validation'] DATASET_DIR="/scratch/workspace/dataset/imagenet/tf_records" batch_size = 100 image_size = 224 def get_input_data(prefix_preprocessing="vgg"): eval_data = input_fn( is_training=False, data_dir=DATASET_DIR, output_width=image_size, output_height=image_size, batch_size=batch_size, num_epochs=1, num_gpus=1, dtype=tf.float32, prefix_preprocessing=prefix_preprocessing) return eval_data network_fn = tf.keras.applications.ResNet50(weights=None, include_top=True, input_tensor=None, input_shape=None, pooling=None, classes=1000) def evaluate(ckpt_path=''): network_fn.load_weights(ckpt_path) metric_top_5 = tf.keras.metrics.SparseTopKCategoricalAccuracy() accuracy = tf.keras.metrics.SparseCategoricalAccuracy() loss = tf.keras.losses.SparseCategoricalCrossentropy() network_fn.compile(loss=loss, metrics=[accuracy, metric_top_5]) # eval_data: validation dataset. You can refer to ‘tf.keras.model.evaluate’ method to find out eval_data format and write data processing function to get your evaluation dataset. eval_data = get_input_data() res = network_fn.evaluate(eval_data, steps=EVAL_NUM/batch_size, workers=16, verbose=1) delta_time = time.time() - start_time rescall5 = res[-1] eval_metric_ops = {'Recall_5': rescall5} return eval_metric_ops
- Export the inference graph.
# Build a ResNet50 and dump its inference graph as text so that the pruning
# tool can consume it.
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.python.framework import graph_util

# Inference mode: freeze learning-phase-dependent behavior.
tf.keras.backend.set_learning_phase(0)

resnet = tf.keras.applications.ResNet50(
    weights=None,
    include_top=True,
    input_tensor=None,
    input_shape=None,
    pooling=None,
    classes=1000)
resnet.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy())

# Grab the full session graph, then keep only the subgraph that feeds the
# softmax output node.
full_graph_def = K.get_session().graph.as_graph_def()
inference_graph_def = graph_util.extract_sub_graph(full_graph_def,
                                                   ["probs/Softmax"])
tf.train.write_graph(inference_graph_def,
                     "./models/ResNet50/train",
                     "ResNet50_inf_graph.pbtxt",
                     as_text=True)
- Convert weights from HDF5 to TensorFlow format. Note: skip this step if the weights are already in the TensorFlow format.
# Convert ImageNet-pretrained HDF5 weights into TensorFlow checkpoint format.
import tensorflow as tf

tf.keras.backend.set_learning_phase(0)

keras_net = tf.keras.applications.ResNet50(
    weights="imagenet",
    include_top=True,
    input_tensor=None,
    input_shape=None,
    pooling=None,
    classes=1000)
# save_format='tf' writes a TensorFlow checkpoint rather than HDF5.
keras_net.save_weights("./models/ResNet50/train/ResNet50.ckpt",
                       save_format='tf')
- Run model analysis.
# Model analysis: runs the evaluation function in ResNet50_model.py against
# the exported inference graph, targeting the top-5 metric.
vai_p_tensorflow \
  --action=ana \
  --input_graph=./models/ResNet50/train/ResNet50_inf_graph.pbtxt \
  --input_ckpt=./models/ResNet50/train/ResNet50.ckpt \
  --eval_fn_path=./ResNet50_model.py \
  --target=top-5 \
  --workspace=./ \
  --input_nodes="input_1" \
  --input_node_shapes="1,224,224,3" \
  --exclude="" \
  --output_nodes="probs/Softmax"
- Run model pruning.
# Model pruning: produce a pruned graph and sparse checkpoint at
# --sparsity=0.5 from the analysis results stored in the workspace.
vai_p_tensorflow \
  --action=prune \
  --input_graph=./models/ResNet50/train/ResNet50_inf_graph.pbtxt \
  --input_ckpt=./models/ResNet50/train/ResNet50.ckpt \
  --output_graph=./models/ResNet50/pruned/graph.pbtxt \
  --output_ckpt=./models/ResNet50/pruned/sparse.ckpt \
  --workspace=./ \
  --input_nodes="input_1" \
  --input_node_shapes="1,224,224,3" \
  --exclude="" \
  --sparsity=0.5 \
  --output_nodes="probs/Softmax"
- Prepare the model training code, `train.py`.
from __future__ import absolute_import from __future__ import division from __future__ import print_function import os, time import tensorflow as tf import numpy as np from preprocessing import preprocessing_factory from preprocessing.dataset import input_fn, NUM_IMAGES TRAIN_NUM = NUM_IMAGES['train'] EVAL_NUM = NUM_IMAGES['validation'] tf.flags.DEFINE_string('model_name', 'ResNet50', 'The keras model name.') tf.flags.DEFINE_boolean('pruning', True, 'If running with pruning masks.') tf.flags.DEFINE_string('data_dir', '', 'The directory where put the evaluation tfrecord data.') tf.flags.DEFINE_string('checkpoint_path', './models/ResNet50/pruned/sparse.ckpt ', 'Model weights path from which to fine-tune.') tf.flags.DEFINE_string('train_dir', './models/ResNet50/pruned/ft', 'The directory where save model') tf.flags.DEFINE_string('ckpt_filename', "trained_model_{epoch}.ckpt", 'Model filename to be saved.') tf.flags.DEFINE_string('ft_ckpt', '', 'The model path to be saved from last epoch.') tf.flags.DEFINE_integer('batch_size', 100, 'Train batch size.') tf.flags.DEFINE_integer('train_image_size', 224, 'Train image size.') tf.flags.DEFINE_integer('epoches', 1, 'Train epochs') tf.flags.DEFINE_integer('eval_every_epoch', 1, '') tf.flags.DEFINE_integer('steps_per_epoch', None, 'How many steps one epoch contains.') tf.flags.DEFINE_float('learning_rate', 5e-3, 'Learning rate.') FLAGS = tf.flags.FLAGS def get_input_data(num_epochs=1, prefix_preprocessing="vgg"): train_data = input_fn( is_training=True, data_dir=FLAGS.data_dir, output_width=FLAGS.train_image_size, output_height=FLAGS.train_image_size, batch_size=FLAGS.batch_size, num_epochs=num_epochs, num_gpus=1, dtype=tf.float32, prefix_preprocessing=prefix_preprocessing) eval_data = input_fn( is_training=False, data_dir=FLAGS.data_dir, output_width=FLAGS.train_image_size, output_height=FLAGS.train_image_size, batch_size=FLAGS.batch_size, num_epochs=1, num_gpus=1, dtype=tf.float32, prefix_preprocessing=prefix_preprocessing) 
return train_data, eval_data tf.logging.info('Fine-tuning from %s' % FLAGS.checkpoint_path) tf.logging.set_verbosity(tf.logging.INFO) if FLAGS.pruning: tf.set_pruning_mode() module_name = 'tf.keras.applications.' + FLAGS.model_name model = eval(module_name)(weights=None, include_top=True, input_tensor=None, input_shape=None, pooling=None, classes=1000) os.makedirs(FLAGS.train_dir, exist_ok=True) def main(): config = tf.ConfigProto() config.gpu_options.per_process_gpu_memory_fraction = 0.5 config.gpu_options.allow_growth = True prefix_preprocessing = preprocessing_factory.get_preprocessing_method(FLAGS.model_name) train_data, eval_data = get_input_data(num_epochs=FLAGS.epoches+1, prefix_preprocessing=prefix_preprocessing) callbacks = [ tf.keras.callbacks.ModelCheckpoint( filepath=os.path.join(FLAGS.train_dir, FLAGS.ckpt_filename), save_best_only=True, save_weights_only=True, monitor="sparse_categorical_accuracy", verbose=1, ) ] opt = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate) metric_top_5 = tf.keras.metrics.SparseTopKCategoricalAccuracy() accuracy = tf.keras.metrics.SparseCategoricalAccuracy() loss = tf.keras.losses.SparseCategoricalCrossentropy() model.compile(loss=loss, metrics=[accuracy, metric_top_5], optimizer=opt) model.load_weights(FLAGS.checkpoint_path) start = time.time() steps_per_epoch = FLAGS.steps_per_epoch if FLAGS.steps_per_epoch else np.ceil(TRAIN_NUM/FLAGS.batch_size) model.fit(train_data, epochs=FLAGS.epoches, callbacks=callbacks, steps_per_epoch=steps_per_epoch, # max_queue_size=16, workers=16) t_delta = round(1000*(time.time()-start), 2) print("Training {} epoch needs {}ms".format(FLAGS.epoches, t_delta)) model.save_weights(FLAGS.ft_ckpt, save_format='tf') print('Finished training!') if __name__ == "__main__": main()
- Run the model training code to fine-tune the pruned model.
# Fine-tune the pruned model.
# BUG FIX: the first dash of --pruning was a Unicode en-dash ("–-pruning"),
# which the flag parser would reject.
python train.py --pruning=True --checkpoint_path=./models/ResNet50/pruned/sparse.ckpt
- Transform the sparse model into a dense model.
# Transform the fine-tuned sparse checkpoint into a dense one.
# BUG FIX: removed the stray space in "./models/ ResNet50/..." and aligned the
# input path with train.py's --train_dir (./models/ResNet50/pruned/ft) and its
# "trained_model_{epoch}.ckpt" filename template — verify the actual epoch
# number of the checkpoint produced by your run.
vai_p_tensorflow \
  --action=transform \
  --input_ckpt=./models/ResNet50/pruned/ft/trained_model_1.ckpt \
  --output_ckpt=./models/ResNet50/pruned/transformed.ckpt
- Freeze the graph.
from __future__ import absolute_import from __future__ import division from __future__ import print_function import argparse import sys from google.protobuf import text_format from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import saver_pb2 from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef from tensorflow.python import pywrap_tensorflow from tensorflow.python.client import session from tensorflow.python.framework import graph_util from tensorflow.python.framework import importer from tensorflow.python.platform import app from tensorflow.python.platform import gfile from tensorflow.python.saved_model import loader from tensorflow.python.saved_model import tag_constants from tensorflow.python.tools import saved_model_utils from tensorflow.python.training import saver as saver_lib def freeze_graph_with_def_protos(input_graph_def, input_saver_def, input_checkpoint, output_node_names, restore_op_name, filename_tensor_name, output_graph, clear_devices, initializer_nodes, variable_names_whitelist="", variable_names_blacklist="", input_meta_graph_def=None, input_saved_model_dir=None, saved_model_tags=None, checkpoint_version=saver_pb2.SaverDef.V2): """Converts all variables in a graph and checkpoint into constants.""" del restore_op_name, filename_tensor_name # Unused by updated loading code. # 'input_checkpoint' may be a prefix if we're using Saver V2 format if (not input_saved_model_dir and not saver_lib.checkpoint_exists(input_checkpoint)): print("Input checkpoint '" + input_checkpoint + "' doesn't exist!") return -1 if not output_node_names: print("You need to supply the name of a node to --output_node_names.") return -1 # Remove all the explicit device specifications for this node. This helps to # make the graph more portable. 
if clear_devices: if input_meta_graph_def: for node in input_meta_graph_def.graph_def.node: node.device = "" elif input_graph_def: for node in input_graph_def.node: node.device = "" if input_graph_def: _ = importer.import_graph_def(input_graph_def, name="") with session.Session() as sess: if input_saver_def: saver = saver_lib.Saver(saver_def=input_saver_def, write_version=checkpoint_version) saver.restore(sess, input_checkpoint) elif input_meta_graph_def: restorer = saver_lib.import_meta_graph(input_meta_graph_def, clear_devices=True) restorer.restore(sess, input_checkpoint) if initializer_nodes: sess.run(initializer_nodes.replace(" ", "").split(",")) elif input_saved_model_dir: if saved_model_tags is None: saved_model_tags = [] loader.load(sess, saved_model_tags, input_saved_model_dir) else: var_list = {} reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint) var_to_shape_map = reader.get_variable_to_shape_map() for key in var_to_shape_map: try: tensor = sess.graph.get_tensor_by_name(key + ":0") except KeyError: # This tensor doesn't exist in the graph (for example it's # 'global_step' or a similar housekeeping element) so skip it. 
continue var_list[key] = tensor saver = saver_lib.Saver(var_list=var_list, write_version=checkpoint_version) saver.restore(sess, input_checkpoint) if initializer_nodes: sess.run(initializer_nodes.replace(" ", "").split(",")) variable_names_whitelist = (variable_names_whitelist.replace( " ", "").split(",") if variable_names_whitelist else None) variable_names_blacklist = (variable_names_blacklist.replace( " ", "").split(",") if variable_names_blacklist else None) if input_meta_graph_def: output_graph_def = graph_util.convert_variables_to_constants( sess, input_meta_graph_def.graph_def, output_node_names.replace(" ", "").split(","), variable_names_whitelist=variable_names_whitelist, variable_names_blacklist=variable_names_blacklist) else: output_graph_def = graph_util.convert_variables_to_constants( sess, input_graph_def, output_node_names.replace(" ", "").split(","), variable_names_whitelist=variable_names_whitelist, variable_names_blacklist=variable_names_blacklist) # Write GraphDef to file if output path has been given. 
if output_graph: with gfile.GFile(output_graph, "wb") as f: f.write(output_graph_def.SerializeToString()) return output_graph_def def _parse_input_graph_proto(input_graph, input_binary): """Parser input tensorflow graph into GraphDef proto.""" if not gfile.Exists(input_graph): print("Input graph file '" + input_graph + "' does not exist!") return -1 input_graph_def = graph_pb2.GraphDef() mode = "rb" if input_binary else "r" with gfile.FastGFile(input_graph, mode) as f: if input_binary: input_graph_def.ParseFromString(f.read()) else: text_format.Merge(f.read(), input_graph_def) return input_graph_def def _parse_input_meta_graph_proto(input_graph, input_binary): """Parser input tensorflow graph into MetaGraphDef proto.""" if not gfile.Exists(input_graph): print("Input meta graph file '" + input_graph + "' does not exist!") return -1 input_meta_graph_def = MetaGraphDef() mode = "rb" if input_binary else "r" with gfile.FastGFile(input_graph, mode) as f: if input_binary: input_meta_graph_def.ParseFromString(f.read()) else: text_format.Merge(f.read(), input_meta_graph_def) print("Loaded meta graph file '" + input_graph) return input_meta_graph_def def _parse_input_saver_proto(input_saver, input_binary): """Parser input tensorflow Saver into SaverDef proto.""" if not gfile.Exists(input_saver): print("Input saver file '" + input_saver + "' does not exist!") return -1 mode = "rb" if input_binary else "r" with gfile.FastGFile(input_saver, mode) as f: saver_def = saver_pb2.SaverDef() if input_binary: saver_def.ParseFromString(f.read()) else: text_format.Merge(f.read(), saver_def) return saver_def def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint, output_node_names, restore_op_name, filename_tensor_name, output_graph, clear_devices, initializer_nodes, variable_names_whitelist="", variable_names_blacklist="", input_meta_graph=None, input_saved_model_dir=None, saved_model_tags=tag_constants.SERVING, checkpoint_version=saver_pb2.SaverDef.V2): """Converts 
all variables in a graph and checkpoint into constants.""" input_graph_def = None if input_saved_model_dir: input_graph_def = saved_model_utils.get_meta_graph_def( input_saved_model_dir, saved_model_tags).graph_def elif input_graph: input_graph_def = _parse_input_graph_proto(input_graph, input_binary) input_meta_graph_def = None if input_meta_graph: input_meta_graph_def = _parse_input_meta_graph_proto( input_meta_graph, input_binary) input_saver_def = None if input_saver: input_saver_def = _parse_input_saver_proto(input_saver, input_binary) freeze_graph_with_def_protos(input_graph_def, input_saver_def, input_checkpoint, output_node_names, restore_op_name, filename_tensor_name, output_graph, clear_devices, initializer_nodes, variable_names_whitelist, variable_names_blacklist, input_meta_graph_def, input_saved_model_dir, saved_model_tags.replace(" ", "").split(","), checkpoint_version=checkpoint_version) def main(unused_args, flags): if flags.checkpoint_version == 1: checkpoint_version = saver_pb2.SaverDef.V1 elif flags.checkpoint_version == 2: checkpoint_version = saver_pb2.SaverDef.V2 else: print("Invalid checkpoint version (must be '1' or '2'): %d" % flags.checkpoint_version) return -1 freeze_graph(flags.input_graph, flags.input_saver, flags.input_binary, flags.input_checkpoint, flags.output_node_names, flags.restore_op_name, flags.filename_tensor_name, flags.output_graph, flags.clear_devices, flags.initializer_nodes, flags.variable_names_whitelist, flags.variable_names_blacklist, flags.input_meta_graph, flags.input_saved_model_dir, flags.saved_model_tags, checkpoint_version) def run_main(): parser = argparse.ArgumentParser() parser.register("type", "bool", lambda v: v.lower() == "true") parser.add_argument("--input_graph", type=str, default="./models/ ResNet50/pruned/graph.pbtxt", help="TensorFlow \'GraphDef\' file to load.") parser.add_argument("--input_saver", type=str, default="", help="TensorFlow saver file to load.") parser.add_argument("--input_checkpoint", 
type=str, default="./models/ ResNet50/pruned/transformed.ckpt", help="TensorFlow variables file to load.") parser.add_argument("--checkpoint_version", type=int, default=2, help="Tensorflow variable file format") parser.add_argument("--output_graph", type=str, default="./models/ ResNet50/pruned/frozen_ResNet50.pb", help="Output \'GraphDef\' file name.") parser.add_argument("--input_binary", nargs="", const=True, type="bool", default=False, help="Whether the input files are in binary format.") parser.add_argument("--output_node_names", type=str, default="probs/Softmax", help="The name of the output nodes, comma separated.") parser.add_argument("--restore_op_name", type=str, default="save/restore_all", help="""\ The name of the master restore operator. Deprecated, unused by updated \ loading code. """) parser.add_argument("--filename_tensor_name", type=str, default="save/Const:0", help="""\ The name of the tensor holding the save path. Deprecated, unused by \ updated loading code. """) parser.add_argument("--clear_devices", nargs="", const=True, type="bool", default=True, help="Whether to remove device specifications.") parser.add_argument( "--initializer_nodes", type=str, default="", help="Comma separated list of initializer nodes to run before freezing.") parser.add_argument("--variable_names_whitelist", type=str, default="", help="""\ Comma separated list of variables to convert to constants. 
If specified, \ only those variables will be converted to constants.\ """) parser.add_argument("--variable_names_blacklist", type=str, default="", help="""\ Comma separated list of variables to skip converting to constants.\ """) parser.add_argument("--input_meta_graph", type=str, default="", help="TensorFlow \'MetaGraphDef\' file to load.") parser.add_argument( "--input_saved_model_dir", type=str, default="", help="Path to the dir with TensorFlow \'SavedModel\' file and variables.") parser.add_argument("--saved_model_tags", type=str, default="serve", help="""\ Group of tag(s) of the MetaGraphDef to load, in string format,\ separated by \',\'. For tag-set contains multiple tags, all tags \ must be passed in.\ """) flags, unparsed = parser.parse_known_args() my_main = lambda unused_args: main(unused_args, flags) app.run(main=my_main, argv=[sys.argv[0]] + unparsed) if __name__ == '__main__': run_main()