Before pruning a model, analyze it to find a suitable pruning strategy.
To run model analysis, provide a Python script containing the functions that evaluate model performance. Assuming that your script is named eval_model.py, provide the required functions in one of three ways:
- A function named model_fn() that returns a Python dictionary of metric ops:
import tensorflow as tf  # uses TensorFlow 1.x APIs (tf.logging, tf.metrics)

def model_fn():
    tf.logging.set_verbosity(tf.logging.INFO)
    # get_one_shot_test_data, net_fn, and TEST_BATCH are defined elsewhere in your script.
    img, labels = get_one_shot_test_data(TEST_BATCH)
    logits = net_fn(img, is_training=False)
    predictions = tf.argmax(logits, 1)
    labels = tf.argmax(labels, 1)  # one-hot labels -> class indices
    eval_metric_ops = {
        'accuracy': tf.metrics.accuracy(labels, predictions),
        'recall_5': tf.metrics.recall_at_k(labels, logits, 5),
    }
    return eval_metric_ops
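- A function named model_fn() that returns an instance of tf.estimator.Estimator, and a function named eval_input_fn() that feeds test data to the estimator:
import tensorflow as tf

def model_fn():
    # cnn_model_fn is your Estimator model function; model_dir is its model directory.
    return tf.estimator.Estimator(
        model_fn=cnn_model_fn, model_dir="./models/train/")

def eval_input_fn():
    # eval_data and eval_labels are NumPy arrays loaded elsewhere in your script.
    # num_epochs=1 and shuffle=False read the test set exactly once, in order.
    return tf.estimator.inputs.numpy_input_fn(
        x={"x": eval_data}, y=eval_labels, num_epochs=1, shuffle=False)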
- A function named evaluate() that takes a checkpoint path as its only argument and returns the metric score:
import tensorflow as tf

def evaluate(checkpoint_path):
    with tf.Graph().as_default():
        # ConvNet is a model class defined in your own code.
        net = ConvNet(False)
        net.build(test_only=True)
        score = net.evaluate(checkpoint_path)
        return score
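To confirm which of the three ways a script implements before running the analysis, a small check like the following may help. It is a hypothetical helper, not part of vai_p_tensorflow: it imports eval_model.py with importlib (which runs the script's top-level code, including its TensorFlow imports) and looks only for the function names described above. It does not call the functions, and it makes no claim about how vai_p_tensorflow itself chooses between the ways.

```python
import importlib.util
import inspect
import sys
from types import ModuleType
from typing import List


def load_script(path: str) -> ModuleType:
    """Import a Python file as a module. This executes its top-level code."""
    spec = importlib.util.spec_from_file_location("eval_model", path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load {path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def detect_eval_interfaces(module: ModuleType) -> List[str]:
    """List which of the three documented ways a module appears to implement.

    Only the presence of callables (and the arity of evaluate) is checked;
    nothing is called, so return types are not verified.
    """
    def has(name: str) -> bool:
        return callable(getattr(module, name, None))

    found = []
    if has("model_fn") and has("eval_input_fn"):
        found.append("model_fn() returning an Estimator, plus eval_input_fn()")
    elif has("model_fn"):
        found.append("model_fn() returning a dict of metric ops")
    if has("evaluate"):
        n_params = len(inspect.signature(module.evaluate).parameters)
        if n_params == 1:
            found.append("evaluate(checkpoint_path)")
        else:
            found.append(f"evaluate() with {n_params} parameters (expected 1)")
    return found


if __name__ == "__main__":
    script = sys.argv[1] if len(sys.argv) > 1 else "eval_model.py"
    for item in detect_eval_interfaces(load_script(script)) or ["nothing found"]:
        print(item)
```

Saved under an arbitrary name such as check_eval_script.py, it would be run as python check_eval_script.py eval_model.py. More than one line of output means the script defines functions for more than one way, a case the description above does not cover.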
If you use the tf.keras API, the third way (an evaluate() function) is recommended, as in this example:
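import tensorflow as tf

def evaluate(checkpoint_path):
    net = tf.keras.applications.ResNet50(weights=None,
                                         include_top=True,
                                         input_tensor=None,
                                         input_shape=None,
                                         pooling=None,
                                         classes=1000)
    net.load_weights(checkpoint_path)
    metric_top_5 = tf.keras.metrics.SparseTopKCategoricalAccuracy()
    accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    loss = tf.keras.losses.SparseCategoricalCrossentropy()
    # A Keras model must be compiled before evaluate(); the default optimizer is
    # fine because no training happens here.
    net.compile(loss=loss, metrics=[accuracy, metric_top_5])
    # eval_data: the validation dataset, in any input format accepted by tf.keras.Model.evaluate.
    # EVAL_NUM: the number of samples in the validation dataset.
    # batch_size: the batch size of eval_data.
    res = net.evaluate(eval_data,
                       steps=EVAL_NUM // batch_size,  # steps must be an integer
                       workers=16,
                       verbose=1)
    # res is [loss, accuracy, top-5 accuracy], in the order given to compile().
    # The dictionary key is the metric name to pass as --target.
    eval_metric_ops = {'Recall_5': res[-1]}
    return eval_metric_ops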
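The recall_5 metric in the first example and the top-5 accuracy stored under 'Recall_5' in the tf.keras example measure essentially the same thing for single-label classification: the fraction of samples whose true class is among the five highest-scoring classes. The framework-free sketch below computes that quantity on made-up scores to show what the target value means. The function names are illustrative, and ties are broken toward the lower class index here; the TensorFlow ops may treat ties differently.

```python
from typing import List, Sequence


def top_k_indices(scores: Sequence[float], k: int) -> List[int]:
    """Indices of the k largest scores; ties go to the lower index in this sketch."""
    return sorted(range(len(scores)), key=lambda i: (-scores[i], i))[:k]


def recall_at_k(labels: Sequence[int], logits: Sequence[Sequence[float]], k: int) -> float:
    """Fraction of samples whose single true label is among the top-k logits."""
    if len(labels) != len(logits):
        raise ValueError("labels and logits must have the same length")
    if not labels:
        raise ValueError("need at least one sample")
    hits = sum(1 for y, row in zip(labels, logits) if y in top_k_indices(row, k))
    return hits / len(labels)


if __name__ == "__main__":
    # Three toy samples over six classes (made-up scores).
    logits = [
        [0.1, 0.9, 0.3, 0.2, 0.05, 0.0],  # label 1 has the highest score: hit
        [0.5, 0.4, 0.3, 0.2, 0.1, 0.0],   # label 5 has the lowest score: miss
        [0.0, 0.1, 0.2, 0.3, 0.4, 0.5],   # label 1 is 5th highest: hit
    ]
    labels = [1, 5, 1]
    print(recall_at_k(labels, logits, k=5))  # 0.6666666666666666
```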
If you use the first way to write the script, the following command calls vai_p_tensorflow to perform model analysis. The results of the analysis are saved in a file named .ana in the workspace directory. Do not move or delete this file, because the pruner loads it during pruning.
vai_p_tensorflow \
--action=ana \
--input_graph=inference_graph.pbtxt \
--input_ckpt=model.ckpt \
--eval_fn_path=eval_model.py \
--target="recall_5" \
--max_num_batches=500 \
--workspace=/tmp \
--exclude="conv node names to exclude from pruning" \
--output_nodes="output node names of the network"
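When the analysis is launched from Python instead of a shell, the command can be assembled as an argument list, which keeps every option in the --name=value form and avoids shell quoting. This is a hypothetical sketch: build_ana_command is an illustrative name, the option names are the ones in the command above, and the node names are placeholders to replace with the names in your own graph. It runs vai_p_tensorflow only if the executable is found on PATH.

```python
import shutil
import subprocess
import sys
from typing import Dict, List


def build_ana_command(options: Dict[str, str]) -> List[str]:
    """Build a vai_p_tensorflow argv list for --action=ana, one --name=value per option."""
    if "action" in options:
        raise ValueError("--action is fixed to ana by this helper")
    argv = ["vai_p_tensorflow", "--action=ana"]
    for name, value in options.items():
        argv.append(f"--{name}={value}")
    return argv


if __name__ == "__main__":
    cmd = build_ana_command({
        "input_graph": "inference_graph.pbtxt",
        "input_ckpt": "model.ckpt",
        "eval_fn_path": "eval_model.py",
        "target": "recall_5",  # a metric name returned by eval_model.py
        "max_num_batches": "500",
        "workspace": "/tmp",
        # Hypothetical node names; replace them with the ones in your graph.
        "exclude": "conv_node_to_keep",
        "output_nodes": "output_node",
    })
    print(" ".join(cmd))
    if shutil.which(cmd[0]):
        # No shell is involved, so values are passed to the tool unchanged.
        subprocess.run(cmd, check=True)
    else:
        print("vai_p_tensorflow not found on PATH; command not run", file=sys.stderr)
```

Passing a list to subprocess.run without shell=True means values containing spaces or quotes, such as lists of node names, need no extra escaping.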
The vai_p_tensorflow command above uses the following arguments. See vai_p_tensorflow Usage for a full list of options.
- --action: The action to perform. Use ana for model analysis.
- --input_graph: A GraphDef proto file that represents the inference graph of the network.
- --input_ckpt: The path to a checkpoint to use for pruning.
- --eval_fn_path: The path to the Python script that defines the evaluation functions, such as eval_model.py.
- --target: The name of the score used to evaluate the performance of the network, such as recall_5. If the network reports more than one score, choose the most important one.
- --max_num_batches: The number of batches to run in the evaluation phase. Larger values make the analysis take longer but make it more accurate. The maximum value is the size of the validation set divided by the batch size, at which point all the data in the validation set is used for evaluation.
- --workspace: The directory for saving output files.
- --exclude: Convolution nodes excluded from pruning.
- --output_nodes: Output nodes of the inference graph.
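The upper bound on --max_num_batches is simple arithmetic. The sketch below assumes the evaluation dataset keeps its final partial batch, so covering the whole validation set takes the ceiling of the set size divided by the batch size; the helper names and sample sizes are hypothetical. For example, 50,000 validation images at a batch size of 100 give 500 batches, so --max_num_batches=500 would evaluate every image, while 100 batches would cover only a fifth of them.

```python
def max_useful_batches(num_samples: int, batch_size: int) -> int:
    """Batches needed to see every validation sample once (last batch may be partial)."""
    if num_samples <= 0 or batch_size <= 0:
        raise ValueError("num_samples and batch_size must be positive")
    # Integer ceiling division avoids floating-point rounding on large sizes.
    return -(-num_samples // batch_size)


def fraction_evaluated(max_num_batches: int, num_samples: int, batch_size: int) -> float:
    """Share of the validation set covered by the first max_num_batches batches."""
    if max_num_batches < 0:
        raise ValueError("max_num_batches must be non-negative")
    covered = min(max_num_batches * batch_size, num_samples)
    return covered / num_samples


if __name__ == "__main__":
    n, bs = 50_000, 100  # hypothetical validation-set size and batch size
    print(max_useful_batches(n, bs))       # 500
    print(fraction_evaluated(100, n, bs))  # 0.2
```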