Analyze the model before pruning to find a suitable pruning strategy.
To run model analysis, provide a Python script containing the functions that
evaluate model performance. Define the required functions in the script using
one of the following three methods:
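- A function named model_fn() that returns a Python dictionary of measurements:
      import tensorflow as tf
      def model_fn():
          tf.logging.set_verbosity(tf.logging.INFO)
          # get_one_shot_test_data(), TEST_BATCH, and net_fn() are defined elsewhere in your script.
          img, labels = get_one_shot_test_data(TEST_BATCH)
          logits = net_fn(img, is_training=False)
          predictions = tf.argmax(logits, 1)
          # Convert one-hot labels to class indices.
          labels = tf.argmax(labels, 1)
          # Each value is a (value, update_op) pair from tf.metrics; each key is a metric name.
          eval_metric_ops = {
              'accuracy': tf.metrics.accuracy(labels, predictions),
              'recall_5': tf.metrics.recall_at_k(labels, logits, 5)
          }
          return eval_metric_ops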
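- A function named model_fn() that returns an instance of tf.estimator.Estimator
  and a function named eval_input_fn() that feeds test data to the estimator:
      import tensorflow as tf
      def model_fn():
          # cnn_model_fn is your Estimator model function, defined elsewhere in the script.
          return tf.estimator.Estimator(
              model_fn=cnn_model_fn,
              model_dir="./models/train/")
      def eval_input_fn():
          # eval_data and eval_labels are NumPy arrays holding the test set.
          return tf.estimator.inputs.numpy_input_fn(
              x={"x": eval_data},
              y=eval_labels,
              num_epochs=1,
              shuffle=False)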
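- A function named evaluate() that takes a single parameter, the checkpoint path,
  and returns the metric score:
      import tensorflow as tf
      def evaluate(checkpoint_path):
          # Build the network in a fresh graph so that each call starts from an empty graph.
          with tf.Graph().as_default():
              # ConvNet is a user-defined model class with its own evaluate() method.
              net = ConvNet(False)
              score = net.evaluate(checkpoint_path)
              return score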
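Because the three methods are distinguished only by which functions the script defines, a small checker can catch naming mistakes before a long analysis run. The following is a minimal sketch and is not part of Vitis AI: load_script() and detect_methods() are hypothetical helpers that check only function names and the single-parameter signature of evaluate(), not what the functions return. Loading a script executes its top-level code, so run the checker in the same environment you use for vai_p_tensorflow.

```python
import importlib.util
import inspect
import types
from typing import List


def load_script(path: str) -> types.ModuleType:
    """Import a Python file by path (this executes its top-level code)."""
    spec = importlib.util.spec_from_file_location("eval_script", path)
    if spec is None or spec.loader is None:
        raise ImportError("cannot load " + path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def _func(module: types.ModuleType, name: str):
    """Return module.<name> if it exists and is callable, else None."""
    obj = getattr(module, name, None)
    return obj if callable(obj) else None


def detect_methods(module: types.ModuleType) -> List[str]:
    """List which of the three documented methods the module appears to follow.

    Only function names and the arity of evaluate() are checked.
    """
    found = []
    model_fn = _func(module, "model_fn")
    eval_input_fn = _func(module, "eval_input_fn")
    evaluate = _func(module, "evaluate")
    if model_fn is not None and eval_input_fn is None:
        found.append("method 1: model_fn() -> dict of metrics")
    if model_fn is not None and eval_input_fn is not None:
        found.append("method 2: model_fn() -> Estimator, plus eval_input_fn()")
    if evaluate is not None and len(inspect.signature(evaluate).parameters) == 1:
        found.append("method 3: evaluate(checkpoint_path) -> score")
    return found
```

For example, detect_methods(load_script("my_eval.py")), where my_eval.py is a placeholder name, returns a one-element list for a script that follows exactly one method; an empty list suggests a misspelled function name.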
If you use the tf.keras API, the third method (evaluate()) is recommended. For example:
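import tensorflow as tf
def evaluate(checkpoint_path):
    net = tf.keras.applications.ResNet50(weights=None)
    # Load the checkpoint passed in as the single argument.
    net.load_weights(checkpoint_path)
    metric_top_5 = tf.keras.metrics.SparseTopKCategoricalAccuracy()
    accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    loss = tf.keras.losses.SparseCategoricalCrossentropy()
    # After compiling, evaluate() returns [loss, accuracy, top-5 accuracy] in this order.
    net.compile(loss=loss, metrics=[accuracy, metric_top_5])
    # eval_data: validation dataset. You can refer to the 'tf.keras.Model.evaluate' method to generate your validation dataset.
    # EVAL_NUM: the number of samples in the validation dataset
    # BATCH_SIZE: hypothetical name for the batch size used to build eval_data
    res = net.evaluate(eval_data, steps=EVAL_NUM // BATCH_SIZE)
    # res[-1] is the top-5 accuracy.
    eval_metric_ops = {'Recall_5': res[-1]}
    return eval_metric_ops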
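Both the first method and the tf.keras example report a top-5 score: tf.metrics.recall_at_k(labels, logits, 5) in one case and SparseTopKCategoricalAccuracy (whose k defaults to 5) in the other. For single-label classification, both count a sample as correct when its true class is among the five highest-scoring classes. The plain-Python sketch below illustrates that computation for intuition only; it is not the TensorFlow implementation and does not reproduce its tie-breaking behavior.

```python
from typing import Sequence


def top_k_accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int],
                   k: int = 5) -> float:
    """Fraction of samples whose true class index is among the k highest scores."""
    if len(labels) == 0 or len(logits) != len(labels):
        raise ValueError("logits and labels must be non-empty and the same length")
    hits = 0
    for scores, label in zip(logits, labels):
        # Class indices ordered from highest to lowest score; keep the first k.
        ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        if label in ranked[:k]:
            hits += 1
    return hits / len(labels)
```

With two samples where only one true label falls inside its sample's top five, top_k_accuracy returns 0.5.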
If you use the first method to write the script, the following snippet shows how
to call vai_p_tensorflow to perform model analysis. The results of the analysis
are saved in a file named .ana in the workspace directory. Do not move or delete
this file, because the pruner loads it during pruning.
vai_p_tensorflow \
--action=ana \
--input_graph=inference_graph.pbtxt \
--input_ckpt=model.ckpt \
--eval_fn_path="path to the evaluation script" \
--target="recall_5" \
--max_num_batches=500 \
--workspace=/tmp \
--exclude="conv node names that should be excluded from pruning" \
--output_nodes="output node names of the network"
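When the analysis is launched from a Python driver rather than a shell, building the argument list explicitly avoids quoting problems with values that contain spaces, such as node-name lists. A minimal sketch follows; build_ana_command() and format_command() are hypothetical helpers that emit only the flags described in this section, treat --exclude as optional (an assumption), and leave all paths and node names for you to supply.

```python
import shlex
from typing import List, Optional


def build_ana_command(input_graph: str, input_ckpt: str, eval_fn_path: str,
                      target: str, output_nodes: str, workspace: str = "/tmp",
                      max_num_batches: int = 500,
                      exclude: Optional[str] = None) -> List[str]:
    """Return the argv list for a vai_p_tensorflow model-analysis run."""
    cmd = [
        "vai_p_tensorflow",
        "--action=ana",
        "--input_graph=" + input_graph,
        "--input_ckpt=" + input_ckpt,
        "--eval_fn_path=" + eval_fn_path,
        "--target=" + target,
        "--max_num_batches=" + str(max_num_batches),
        "--workspace=" + workspace,
    ]
    # Treated as optional in this sketch: omitted when nothing is excluded.
    if exclude:
        cmd.append("--exclude=" + exclude)
    cmd.append("--output_nodes=" + output_nodes)
    return cmd


def format_command(cmd: List[str]) -> str:
    """Shell-quote an argv list for logging or copy-pasting."""
    return " ".join(shlex.quote(arg) for arg in cmd)
```

Pass the returned list to subprocess.run(cmd, check=True) to launch the analysis; format_command(cmd) produces a shell-quoted string for logs.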
The arguments used in this command are described below. See vai_p_tensorflow Usage for a full list of options.
- --action
- The action to perform. Use ana to run model analysis.
- --input_graph
- A GraphDef proto file that represents the inference graph of the network.
- --input_ckpt
- The path to a checkpoint to use for pruning.
- --eval_fn_path
- The path to the Python script that defines the evaluation functions.
- --target
- The name of the metric used to evaluate the performance of the network, such as recall_5 in the first method. If the evaluation script reports more than one score, choose the most important one.
- --max_num_batches
- The number of batches to run in the evaluation phase. This parameter affects the time taken to analyze the model: a larger value results in a longer analysis time but more accurate results. The maximum value of this parameter is the size of the validation set divided by the batch size, at which point all the data in the validation set is used for evaluation.
- --workspace
- Directory for saving output files.
- --exclude
- Convolution nodes excluded from pruning.
- --output_nodes
- Output nodes of the inference graph.
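The upper bound on --max_num_batches is simple arithmetic: one full pass over the validation set. For example, a hypothetical validation set of 50,000 images read in batches of 100 gives 50,000 / 100 = 500 batches, so a larger value cannot use more data than a full pass. The sketch below computes that bound. Whether a final partial batch counts depends on your input pipeline, so the drop_remainder flag (a name borrowed from tf.data for illustration, not a vai_p_tensorflow option) lets you choose.

```python
def full_pass_batches(num_val_samples: int, batch_size: int,
                      drop_remainder: bool = False) -> int:
    """Number of batches in one pass over the validation set."""
    if num_val_samples <= 0 or batch_size <= 0:
        raise ValueError("num_val_samples and batch_size must be positive")
    if drop_remainder:
        # A trailing partial batch is discarded by the input pipeline.
        return num_val_samples // batch_size
    # Integer ceiling division: a trailing partial batch still counts.
    return (num_val_samples + batch_size - 1) // batch_size


def cap_max_num_batches(requested: int, num_val_samples: int, batch_size: int,
                        drop_remainder: bool = False) -> int:
    """Limit a requested --max_num_batches value to one full pass."""
    return min(requested,
               full_pass_batches(num_val_samples, batch_size, drop_remainder))
```

Lower values trade accuracy of the analysis for speed, so capping only removes values that cannot add data.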