Create a file named ft.py and add the following code:
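from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

# cnn_model_fn and train_input_fn are defined in est_cnn.py.
from est_cnn import cnn_model_fn, train_input_fn

tf.app.flags.DEFINE_string(
    'checkpoint_path', None, 'Path of a specific checkpoint to finetune.')
FLAGS = tf.app.flags.FLAGS

tf.logging.set_verbosity(tf.logging.INFO)


def main(unused_argv):
    # Enable pruning mode so the pruned weights stay at zero while finetuning.
    # tf.set_pruning_mode() comes with the Vitis AI Optimizer's TensorFlow
    # package; it is not part of stock TensorFlow.
    tf.set_pruning_mode()

    # Initialize the model's variables from the pruned checkpoint.
    ws = tf.estimator.WarmStartSettings(
        ckpt_to_initialize_from=FLAGS.checkpoint_path)
    mnist_classifier = tf.estimator.Estimator(
        model_fn=cnn_model_fn, model_dir="./models/ft/", warm_start_from=ws)

    # train_input_fn() is called here, so it must return the input function
    # that Estimator.train expects (as defined in est_cnn.py).
    mnist_classifier.train(
        input_fn=train_input_fn(),
        max_steps=20000)


if __name__ == "__main__":
    tf.app.run()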
The script uses tf.estimator.WarmStartSettings to initialize the model from the pruned checkpoint and then finetunes it.
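By default, tf.estimator.WarmStartSettings warm-starts only the variables in the TRAINABLE_VARIABLES collection, so global_step is not restored from the pruned checkpoint. This is consistent with the log further down, where training starts at step 0 and max_steps=20000 therefore means 20,000 finetuning steps. The following is a conceptual NumPy sketch of that behaviour, not TensorFlow's implementation. The helper warm_start, the variable names and the shapes are hypothetical, and raising an error on a shape mismatch is a choice made for this sketch.

```python
import re

import numpy as np


def warm_start(model_vars, trainable_names, checkpoint, pattern=".*"):
    """Conceptual sketch of warm-starting (NOT TensorFlow's implementation).

    Copies checkpoint values into `model_vars` for every trainable variable
    whose name matches `pattern`. Variables not listed as trainable, such as
    global_step, keep their fresh initial values.
    """
    regex = re.compile(pattern)
    restored = []
    for name in trainable_names:
        if regex.match(name) is None:
            continue  # not selected by the pattern
        value = checkpoint[name]
        if value.shape != model_vars[name].shape:
            raise ValueError(f"shape mismatch for {name}")
        model_vars[name] = value.copy()
        restored.append(name)
    return restored


# Hypothetical checkpoint contents and freshly initialized model variables.
checkpoint = {
    "conv2d/kernel": np.ones((3, 3, 1, 4)),
    "dense/kernel": np.full((16, 10), 0.5),
    "global_step": np.array(20000),
}
model_vars = {
    "conv2d/kernel": np.zeros((3, 3, 1, 4)),
    "dense/kernel": np.zeros((16, 10)),
    "global_step": np.array(0),
}

restored = warm_start(model_vars, ["conv2d/kernel", "dense/kernel"], checkpoint)
print(restored)                         # ['conv2d/kernel', 'dense/kernel']
print(int(model_vars["global_step"]))  # 0: step counting restarts
```

The weights come from the checkpoint while the step counter starts fresh, which is why the finetuning run saves its first checkpoint at step 0.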
Run ft.py to finetune the pruned model:
python -u ft.py --checkpoint_path=${PRUNED_CKPT}
The output log looks like the following:
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into ./models/ft/model.ckpt.
INFO:tensorflow:loss = 0.3675258, step = 0
INFO:tensorflow:global_step/sec: 162.673
INFO:tensorflow:loss = 0.31534952, step = 100 (0.615 sec)
INFO:tensorflow:global_step/sec: 210.058
INFO:tensorflow:loss = 0.2782951, step = 200 (0.476 sec)
...
INFO:tensorflow:loss = 0.022076223, step = 19800 (0.503 sec)
INFO:tensorflow:global_step/sec: 206.588
INFO:tensorflow:loss = 0.06927078, step = 19900 (0.484 sec)
INFO:tensorflow:Saving checkpoints for 20000 into ./models/ft/model.ckpt.
INFO:tensorflow:Loss for final step: 0.07726018.
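To compare finetuning runs, it can help to extract the loss values from this log. The following is a minimal sketch that assumes the log has been captured as text (TensorFlow writes these INFO lines to stderr, so the script's stderr must be redirected) and that loss lines keep the `loss = <value>, step = <n>` format shown above. The helper parse_losses is hypothetical.

```python
import re

# Matches lines such as "INFO:tensorflow:loss = 0.3675258, step = 0".
# A trailing "(0.615 sec)" is ignored, and "Loss for final step: ..." does
# not match because it lacks the "loss = ..., step = ..." pattern.
LOSS_RE = re.compile(r"loss = ([0-9.eE+-]+), step = (\d+)")


def parse_losses(log_text):
    """Return a list of (step, loss) tuples found in an Estimator training log."""
    points = []
    for line in log_text.splitlines():
        match = LOSS_RE.search(line)
        if match:
            points.append((int(match.group(2)), float(match.group(1))))
    return points


sample = """INFO:tensorflow:loss = 0.3675258, step = 0
INFO:tensorflow:global_step/sec: 162.673
INFO:tensorflow:loss = 0.31534952, step = 100 (0.615 sec)
INFO:tensorflow:Loss for final step: 0.07726018."""

print(parse_losses(sample))  # [(0, 0.3675258), (100, 0.31534952)]
```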
As a final step, transform and freeze the finetuned model to get a dense model.
FT_CKPT=${WORKSPACE}/ft/model.ckpt-20000
TRANSFORMED_CKPT=${WORKSPACE}/pruned/transformed.ckpt
FROZEN_PB=${WORKSPACE}/pruned/mnist.pb
vai_p_tensorflow \
--action=transform \
--input_ckpt=${FT_CKPT} \
--output_ckpt=${TRANSFORMED_CKPT}
freeze_graph \
--input_graph="${PRUNED_GRAPH}" \
--input_checkpoint="${TRANSFORMED_CKPT}" \
--input_binary=false \
--output_graph="${FROZEN_PB}" \
--output_node_names=${OUTPUT_NODES}
You now have a frozen GraphDef file named mnist.pb (the path given by ${FROZEN_PB}).
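Conceptually, getting a dense model from a channel-pruned one means physically removing the pruned channels. The pruned channels are dropped from each layer's output, and the matching input channels are dropped from the layer that follows. The following NumPy sketch illustrates that idea and is not how vai_p_tensorflow implements the transform action. It assumes TensorFlow's [height, width, in, out] kernel layout and a checkpoint that keeps the original shapes with the pruned output channels set exactly to zero. The helper drop_zero_channels is hypothetical.

```python
import numpy as np


def drop_zero_channels(kernel, next_kernel):
    """Conceptual sketch: remove all-zero output channels of a conv kernel
    and the matching input channels of the following conv kernel.

    Assumes the [height, width, in, out] layout. Biases and batch-norm
    parameters, which a real transform would also have to slice, are
    ignored here.
    """
    keep = np.any(kernel != 0, axis=(0, 1, 2))  # one flag per output channel
    return kernel[..., keep], next_kernel[:, :, keep, :]


rng = np.random.default_rng(0)
k1 = rng.standard_normal((3, 3, 2, 4))
k1[..., [1, 3]] = 0.0  # pretend output channels 1 and 3 were pruned
k2 = rng.standard_normal((3, 3, 4, 5))

d1, d2 = drop_zero_channels(k1, k2)
print(d1.shape, d2.shape)  # (3, 3, 2, 2) (3, 3, 2, 5)
```

Removing the zeroed channels leaves the layer outputs for the kept channels unchanged while shrinking both tensors. This is where a dense model saves computation compared with a sparse one of the same shape.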