Create a file named ft.py and add the following code:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from train_eval_utils import ConvNet

tf.app.flags.DEFINE_string(
    'checkpoint_path', '', 'Where to restore checkpoint.')
tf.app.flags.DEFINE_string(
    'save_ckpt', '', 'Where to save checkpoint.')

FLAGS = tf.app.flags.FLAGS


def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.logging.info("Finetuning model")
    # Enable sparse training mode; this must happen before the model is created.
    tf.set_pruning_mode()
    net = ConvNet(True)
    net.build()
    net.train_eval(10, FLAGS.save_ckpt, FLAGS.checkpoint_path)


if __name__ == '__main__':
    tf.app.run()
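The effect of sparse training mode can be illustrated with a small NumPy sketch. This is not the Vitis AI implementation of tf.set_pruning_mode(). It assumes a conv kernel in TensorFlow's (kh, kw, in_ch, out_ch) layout and a hypothetical per-output-channel mask, and it re-applies the mask after every SGD step so that pruned channels stay exactly zero. Without the mask, the same updates make those channels nonzero again.

```python
import numpy as np

def sgd_step(weights, grads, mask=None, lr=0.1):
    """One plain SGD step on a conv kernel of shape (kh, kw, in_ch, out_ch).

    If mask (1-D, length out_ch, 0 = pruned) is given, it is re-applied after
    the update so pruned output channels stay exactly zero.
    """
    new_w = weights - lr * grads
    if mask is not None:
        new_w = new_w * mask.reshape(1, 1, 1, -1)  # broadcast over out_ch axis
    return new_w

rng = np.random.default_rng(0)
mask = np.array([1.0, 0.0, 1.0, 0.0])            # hypothetical: channels 1 and 3 pruned
w0 = rng.standard_normal((3, 3, 2, 4)) * mask.reshape(1, 1, 1, -1)

w_sparse, w_dense = w0.copy(), w0.copy()
for _ in range(5):
    g = rng.standard_normal(w0.shape)             # stand-in for real gradients
    w_sparse = sgd_step(w_sparse, g, mask)        # masked: like sparse training mode
    w_dense = sgd_step(w_dense, g)                # unmasked: pruned channels revive

print(np.count_nonzero(w_sparse[..., [1, 3]]))     # 0
print(np.count_nonzero(w_dense[..., [1, 3]]) > 0)  # True
```

The unmasked run corresponds to fine-tuning without the pruning-mode call: the pruned channels pick up nonzero weights, and the model is no longer sparse.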
Note: You must call tf.set_pruning_mode() before creating the model. This API enables sparse training mode: the weights of pruned channels are set to zero and are not updated during training. If you fine-tune a pruned model without calling this function, the pruned channels are updated and you end up with a normal, non-sparse model.

Fine-tune the pruned model by running ft.py:
WORKSPACE=./models
FT_CKPT=${WORKSPACE}/ft/model.ckpt
PRUNED_CKPT=${WORKSPACE}/pruned/sparse.ckpt
python -u ft.py \
--save_ckpt=${FT_CKPT} \
--checkpoint_path=${PRUNED_CKPT} \
2>&1 | tee ft.log
The output log looks like:
INFO:tensorflow:time:2019-01-09 17:17:10
INFO:tensorflow:Loss at step 1000: 13.077235221862793
INFO:tensorflow:Loss at step 1100: 41.67073440551758
INFO:tensorflow:Loss at step 1200: 31.98809242248535
INFO:tensorflow:Loss at step 1300: 34.46034240722656
INFO:tensorflow:Loss at step 1400: 32.12882995605469
INFO:tensorflow:Average loss at epoch 2: 28.96098704302489
INFO:tensorflow:train one epoch took: 3.0082509517669678 seconds
INFO:tensorflow:Evaluation took: 0.23403644561767578 seconds
INFO:tensorflow:Accuracy : 0.9539
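To track fine-tuning progress across epochs, you can scan ft.log for the per-epoch average loss and the accuracy lines. The sketch below assumes the line format shown in the sample log above. summarize_log is a hypothetical helper and is not part of Vitis AI.

```python
import re

# Patterns assume the tf.logging line format shown in the sample log.
LOSS_RE = re.compile(r"Average loss at epoch (\d+): ([\d.]+)")
ACC_RE = re.compile(r"Accuracy : ([\d.]+)")

def summarize_log(text):
    """Return (list of (epoch, avg_loss) pairs, list of accuracies) found in text."""
    losses = [(int(epoch), float(loss)) for epoch, loss in LOSS_RE.findall(text)]
    accuracies = [float(acc) for acc in ACC_RE.findall(text)]
    return losses, accuracies

sample = """INFO:tensorflow:Average loss at epoch 2: 28.96098704302489
INFO:tensorflow:train one epoch took: 3.0082509517669678 seconds
INFO:tensorflow:Evaluation took: 0.23403644561767578 seconds
INFO:tensorflow:Accuracy : 0.9539
"""

losses, accuracies = summarize_log(sample)
print(losses, accuracies)
```

For a real run, pass the contents of ft.log (for example, open("ft.log").read()) instead of the sample string.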
As a final step, transform and freeze the fine-tuned model to get a dense model.
WORKSPACE=./models
FT_CKPT=${WORKSPACE}/ft/model.ckpt
TRANSFORMED_CKPT=${WORKSPACE}/pruned/transformed.ckpt
PRUNED_GRAPH=${WORKSPACE}/pruned/graph.pbtxt
FROZEN_PB=${WORKSPACE}/pruned/mnist.pb
OUTPUT_NODES="logits/add"
vai_p_tensorflow \
--action=transform \
--input_ckpt=${FT_CKPT} \
--output_ckpt=${TRANSFORMED_CKPT}
freeze_graph \
--input_graph="${PRUNED_GRAPH}" \
--input_checkpoint="${TRANSFORMED_CKPT}" \
--input_binary=false \
--output_graph="${FROZEN_PB}" \
--output_node_names=${OUTPUT_NODES}
You should now have a frozen GraphDef file named mnist.pb in the models/pruned directory.
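The transform action turns the sparse fine-tuned checkpoint into a dense one. This guide does not describe what vai_p_tensorflow does internally, so the NumPy sketch below only illustrates the general idea under stated assumptions:

- The layers are conv layers in TensorFlow's (kh, kw, in_ch, out_ch) layout.
- Each layer is followed by ReLU, and there is no batch normalization.
- The kernels are 1x1, so each convolution reduces to a matrix multiply.

Under these assumptions, output channels whose weights and bias are all zero always produce zero activations. They can therefore be dropped, together with the matching input channels of the next layer, without changing the network's output. strip_pruned_channels and forward are hypothetical helpers.

```python
import numpy as np

def strip_pruned_channels(w1, b1, w2):
    """Drop all-zero output channels of w1/b1 and the matching inputs of w2.

    w1: (kh, kw, in_ch, out_ch) kernel, b1: (out_ch,) bias,
    w2: (kh, kw, out_ch, next_out) kernel of the following layer.
    """
    keep = np.any(w1 != 0, axis=(0, 1, 2)) | (b1 != 0)
    return w1[..., keep], b1[keep], w2[:, :, keep, :]

def forward(x, w1, b1, w2):
    # With 1x1 kernels, a conv is a matrix multiply over the channel axis.
    h = np.maximum(x @ w1[0, 0] + b1, 0.0)  # conv + bias + ReLU
    return h @ w2[0, 0]

rng = np.random.default_rng(0)
w1 = rng.standard_normal((1, 1, 3, 4))
b1 = rng.standard_normal(4)
w1[..., [1, 3]] = 0.0   # hypothetical: channels 1 and 3 were pruned
b1[[1, 3]] = 0.0
w2 = rng.standard_normal((1, 1, 4, 5))
x = rng.standard_normal((6, 3))

d1, db1, d2 = strip_pruned_channels(w1, b1, w2)
print(d1.shape, d2.shape)  # (1, 1, 3, 2) (1, 1, 2, 5)
print(np.allclose(forward(x, w1, b1, w2), forward(x, d1, db1, d2)))  # True
```

The dense kernels are smaller but compute the same function. This is why the transformed, frozen model can run faster than the sparse checkpoint.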