TensorFlow Model
First, build a TensorFlow graph for training and evaluation. Each part must be written in a separate script. If you have already trained a baseline model and still have the training code, you only need to prepare the evaluation code.
The evaluation script must contain a function named model_fn that creates all the needed nodes from input to output. The function should return either a dictionary that maps the names of output nodes to their operations, or a tf.estimator.Estimator. For example, if your network is an image classifier, the returned dictionary usually includes operations to calculate top-1 and top-5 accuracy, as shown in the following snippet:
import tensorflow as tf

slim = tf.contrib.slim

def model_fn():
    # graph definition
    # ……
    return {
        'top-1': slim.metrics.streaming_accuracy(predictions, labels),
        'top-5': slim.metrics.streaming_recall_at_k(logits, org_labels, 5)
    }
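For readers unfamiliar with these metrics, the following is a minimal sketch in plain Python of the quantity that top-k accuracy is conventionally understood to report: the fraction of samples whose true label is among the k highest-scoring classes. It is not the slim implementation; the function name top_k_accuracy and the sample scores are made up for illustration.

```python
def top_k_accuracy(logits, labels, k):
    """Fraction of samples whose true label is among the k highest scores."""
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same length")
    if not logits:
        return 0.0
    hits = 0
    for scores, label in zip(logits, labels):
        # Class indices sorted by descending score, truncated to the top k.
        top_k = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
        if label in top_k:
            hits += 1
    return hits / len(labels)


# Three samples, four classes (made-up scores for illustration).
example_logits = [
    [0.1, 0.7, 0.15, 0.05],  # true class 1 has the highest score
    [0.4, 0.3, 0.2, 0.1],    # true class 1 has the second-highest score
    [0.25, 0.05, 0.1, 0.6],  # true class 2 has the third-highest score
]
example_labels = [1, 1, 2]

top1 = top_k_accuracy(example_logits, example_labels, k=1)  # 1 of 3 correct
top2 = top_k_accuracy(example_logits, example_labels, k=2)  # 2 of 3 correct
```

Top-1 is the special case k=1, which is ordinary classification accuracy.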
If you instead use the TensorFlow Estimator API to train and evaluate your network, your model_fn function must return an instance of tf.estimator.Estimator. You must also provide a function called eval_input_fn, which the Estimator uses to get the evaluation data.
import numpy as np
import tensorflow as tf

def cnn_model_fn(features, labels, mode):
    # assemble the graph here
    # …
    eval_metric_ops = {
        "accuracy": tf.metrics.accuracy(
            labels=labels, predictions=predictions["classes"])}
    return tf.estimator.EstimatorSpec(
        mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

def model_fn():
    # Wrap the model function in an Estimator that loads the trained weights.
    return tf.estimator.Estimator(
        model_fn=cnn_model_fn, model_dir="./models/train/")

mnist = tf.contrib.learn.datasets.load_dataset("mnist")
train_data = mnist.train.images  # Returns np.array
train_labels = np.asarray(mnist.train.labels, dtype=np.int32)
eval_data = mnist.test.images  # Returns np.array
eval_labels = np.asarray(mnist.test.labels, dtype=np.int32)

def eval_input_fn():
    # One pass over the test set, in a fixed order.
    return tf.estimator.inputs.numpy_input_fn(
        x={"x": eval_data},
        y=eval_labels,
        num_epochs=1,
        shuffle=False)
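The naming requirements described above can be checked before you start pruning. The following is a minimal sketch in plain Python, not part of any tool: check_eval_script is a hypothetical helper that takes the top-level names of an evaluation script (for example, vars(module)) and reports which required functions are missing. The uses_estimator flag is an assumption of this sketch; set it to True when model_fn returns a tf.estimator.Estimator.

```python
def check_eval_script(namespace, uses_estimator):
    """Return a list of problems found in an evaluation script's namespace.

    namespace: dict of top-level names, for example vars(module).
    uses_estimator: True if model_fn returns a tf.estimator.Estimator.
    """
    required = ["model_fn"]
    if uses_estimator:
        # The Estimator path additionally needs an input function.
        required.append("eval_input_fn")
    problems = []
    for name in required:
        if name not in namespace:
            problems.append("missing function: " + name)
        elif not callable(namespace[name]):
            problems.append("not callable: " + name)
    return problems


# Stand-in namespace for illustration; a real script defines real functions.
fake_script = {"model_fn": lambda: None}
dict_style = check_eval_script(fake_script, uses_estimator=False)
estimator_style = check_eval_script(fake_script, uses_estimator=True)
```

Here dict_style is an empty list, while estimator_style reports that eval_input_fn is missing.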
You must also include code that can be used to export an inference GraphDef file and evaluate network performance during pruning.
To export a GraphDef proto file, use the following code:
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.python.platform import gfile

with tf.Graph().as_default() as graph:
    # your graph definition here
    # ……
    graph_def = graph.as_graph_def()
    # Write the GraphDef in protobuf text format.
    with gfile.GFile('inference_graph.pbtxt', 'w') as f:
        f.write(text_format.MessageToString(graph_def))
Keras Model
A Keras model has no explicit graph definition, so first obtain a GraphDef object and then export it. The following example uses the predefined tf.keras ResNet50 model:
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.python.framework import graph_util

# Build the graph in inference mode.
tf.keras.backend.set_learning_phase(0)
model = tf.keras.applications.ResNet50(weights=None,
                                       include_top=True,
                                       input_tensor=None,
                                       input_shape=None,
                                       pooling=None,
                                       classes=1000)
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy())
graph_def = K.get_session().graph.as_graph_def()

# "probs/Softmax": Output node of ResNet50 graph.
graph_def = graph_util.extract_sub_graph(graph_def, ["probs/Softmax"])
tf.train.write_graph(graph_def,
                     "./",
                     "inference_graph.pbtxt",
                     as_text=True)
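If you do not know the name of the output node for your own model, you can inspect the exported text file. The following is a minimal sketch in plain Python that lists the nodes that no other node consumes, which are candidates for output nodes. It assumes the file uses the usual protobuf text layout, with each top-level node { block starting at column 0 and its name and input fields indented by two spaces. The helper name find_candidate_outputs and the sample graph are made up for illustration.

```python
import re


def find_candidate_outputs(pbtxt_text):
    """Return names of nodes that no other node lists as an input."""
    node_names = []
    consumed = set()
    in_node = False
    for line in pbtxt_text.splitlines():
        if line.startswith("node {"):
            in_node = True
            continue
        if line.startswith("}"):
            # A closing brace at column 0 ends the current top-level block.
            in_node = False
            continue
        if not in_node:
            continue
        # Node-level fields are indented by exactly two spaces (assumption).
        name_match = re.match(r'^  name: "(.*)"$', line)
        if name_match:
            node_names.append(name_match.group(1))
            continue
        input_match = re.match(r'^  input: "(.*)"$', line)
        if input_match:
            # Strip the control-dependency marker (^) and output index (:N).
            consumed.add(input_match.group(1).lstrip("^").split(":")[0])
    return [name for name in node_names if name not in consumed]


# Tiny hand-written graph for illustration only.
sample_pbtxt = '''node {
  name: "input_1"
  op: "Placeholder"
}
node {
  name: "probs/MatMul"
  op: "MatMul"
  input: "input_1"
}
node {
  name: "probs/Softmax"
  op: "Softmax"
  input: "probs/MatMul:0"
}
'''

candidates = find_candidate_outputs(sample_pbtxt)
```

For the sample graph, candidates contains only "probs/Softmax". To apply the helper to a real export, pass it the contents of inference_graph.pbtxt. A graph that still contains training or initialization nodes may yield several candidates, so treat the result as a starting point, not a definitive answer.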