Performing Model Analysis - 3.0 English

Vitis AI Optimizer User Guide (UG1333)

Document ID: UG1333
Release Date: 2023-01-12
Version: 3.0 English

Analyze the model before pruning to find a suitable pruning strategy.

To run model analysis, provide a Python script containing the functions that evaluate model performance. Assuming that your script is named eval_model.py, define the required functions using one of three methods:

  • A function named model_fn() that returns a Python dictionary of measurements:
    def model_fn():
      tf.logging.set_verbosity(tf.logging.INFO)
      img, labels = get_one_shot_test_data(TEST_BATCH)
    
      logits = net_fn(img, is_training=False)
      predictions = tf.argmax(logits, 1)
      labels = tf.argmax(labels, 1)
      eval_metric_ops = {
          'accuracy': tf.metrics.accuracy(labels, predictions),
          'recall_5': tf.metrics.recall_at_k(labels, logits, 5)
      }
      return eval_metric_ops
    
  • A function named model_fn() that returns an instance of tf.estimator.Estimator and a function named eval_input_fn() that feeds test data to the estimator:
    def model_fn():
      return tf.estimator.Estimator(
          model_fn=cnn_model_fn, model_dir="./models/train/")
    
    def eval_input_fn():
      return tf.estimator.inputs.numpy_input_fn(
          x={"x": eval_data},
          y=eval_labels,
          num_epochs=1,
          shuffle=False)
    
  • A function named evaluate() that takes a single argument, the checkpoint path, and returns the metric score:
    def evaluate(checkpoint_path):
      with tf.Graph().as_default():
        net = ConvNet(False)
        net.build(test_only=True)
        score = net.evaluate(checkpoint_path)
        return score
    

If you are using the tf.keras API, the following evaluate() function is the recommended method:

import tensorflow as tf

def evaluate(checkpoint_path):
  net = tf.keras.applications.ResNet50(weights=None,
                                       include_top=True,
                                       input_tensor=None,
                                       input_shape=None,
                                       pooling=None,
                                       classes=1000)
  net.load_weights(checkpoint_path)
  metric_top_5 = tf.keras.metrics.SparseTopKCategoricalAccuracy()
  accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
  loss = tf.keras.losses.SparseCategoricalCrossentropy()
  # The model must be compiled with the loss and metrics before evaluate() can report them.
  net.compile(loss=loss, metrics=[accuracy, metric_top_5])

  # eval_data: the validation dataset. See the tf.keras.Model.evaluate method for how to prepare it.
  # EVAL_NUM: the number of samples in the validation dataset.
  # batch_size: the batch size of eval_data.
  res = net.evaluate(eval_data,
                     steps=EVAL_NUM // batch_size,
                     workers=16,
                     verbose=1)
  # res holds the loss followed by the metrics in compile order; the last entry is the top-5 score.
  eval_metric_ops = {'Recall_5': res[-1]}
  return eval_metric_ops
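
Before starting an analysis run, it can help to confirm that the evaluation script defines one of the three sets of functions described above. The following is a minimal sketch of such a pre-flight check. The helper name detect_eval_style, its return labels, and the order in which it checks the function names are assumptions made for illustration; this is not how vai_p_tensorflow itself inspects the script. The source is parsed with the standard ast module rather than imported, so TensorFlow is not needed to run the check.

```python
import ast


def detect_eval_style(source: str) -> str:
    """Guess which documented style a script's source text follows.

    Hypothetical helper: the labels and the precedence used here are
    illustrative assumptions, not behavior of vai_p_tensorflow.
    """
    tree = ast.parse(source)
    # Only top-level function definitions are considered.
    names = {node.name for node in tree.body
             if isinstance(node, ast.FunctionDef)}
    if "evaluate" in names:
        return "evaluate"
    if "model_fn" in names and "eval_input_fn" in names:
        return "estimator"
    if "model_fn" in names:
        return "metrics_dict"
    raise ValueError("no model_fn() or evaluate() function found")
```

To use it, read the contents of eval_model.py into a string and pass that string to detect_eval_style(). A ValueError indicates that none of the expected function names is defined at the top level of the script.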

If you wrote the script using the first method, use the following command to call vai_p_tensorflow and perform model analysis. The results of the analysis are saved in a file named .ana in the workspace directory. Do not move or delete this file, because the pruner loads it for pruning.

vai_p_tensorflow \
  --action=ana \
  --input_graph=inference_graph.pbtxt \
  --input_ckpt=model.ckpt \
  --eval_fn_path=eval_model.py \
  --target="recall_5" \
  --max_num_batches=500 \
  --workspace=/tmp \
  --exclude="conv node names that should be excluded from pruning" \
  --output_nodes="output node names of the network"
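
Because the pruner later loads the .ana file from the workspace, it may be worth confirming that the file is present once the command above has finished. The following is a minimal sketch of such a check. The helper name analysis_result_exists is hypothetical, and the check only tests for the presence of the file, not the validity of its contents.

```python
from pathlib import Path


def analysis_result_exists(workspace: str) -> bool:
    """Return True if the workspace directory contains a file named '.ana'.

    Hypothetical helper for illustration; it does not inspect the file.
    """
    return (Path(workspace) / ".ana").is_file()
```

With the workspace used in the command above, the call would be analysis_result_exists("/tmp").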

The vai_p_tensorflow command above uses the following arguments. See vai_p_tensorflow Usage for a full list of options.

--action
The action to perform.
--input_graph
A GraphDef proto file that represents the inference graph of the network.
--input_ckpt
The path to a checkpoint to use for pruning.
--eval_fn_path
The path to a Python script defining an evaluation graph.
--target
The target metric that evaluates the performance of the network. If there is more than one score in the network, choose the one that is the most important.
--max_num_batches
The number of batches to run in the evaluation phase. This parameter affects the time taken to analyze the model. A larger value results in a longer analysis time but more accurate results. The maximum value of this parameter is the size of the validation set divided by the batch_size, in which case all the data in the validation set is used for evaluation.
--workspace
Directory for saving output files.
--exclude
Convolution nodes excluded from pruning.
--output_nodes
Output nodes of the inference graph.
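
The upper bound for --max_num_batches described above can be worked out with simple arithmetic. The following is a minimal sketch; the function name and the example numbers are hypothetical, and it assumes that only full batches are counted, so any partial final batch is ignored.

```python
def max_num_batches(num_eval_examples: int, batch_size: int) -> int:
    """Number of full batches available in the validation set.

    Assumption: a partial final batch is not counted.
    """
    if num_eval_examples <= 0 or batch_size <= 0:
        raise ValueError("both arguments must be positive")
    return num_eval_examples // batch_size


# Hypothetical numbers: 50,000 validation samples with a batch size of 100.
print(max_num_batches(50000, 100))  # prints 500
```

Any value up to this bound is usable; smaller values shorten the analysis at the cost of a less accurate estimate.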