モデルに学習させるには、train.py という名前のファイルを作成し、次のコードを追加します。
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from train_eval_utils import ConvNet
tf.app.flags.DEFINE_string(
'save_ckpt', '', 'Where to save checkpoint.')
FLAGS = tf.app.flags.FLAGS
def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO)
tf.logging.info("Training model from scratch")
net = ConvNet(True)
net.build()
net.train_eval(10, FLAGS.save_ckpt)
if __name__ == '__main__':
tf.app.run()
シェルで train.py を実行します。
$ WORKSPACE=./models
$ BASELINE_CKPT=${WORKSPACE}/train/model.ckpt
$ mkdir -p $(dirname "${BASELINE_CKPT}")
$ python train.py --save_ckpt=${BASELINE_CKPT}
実行によって次のようなログが出力されます。
INFO:tensorflow:time:2019-01-09 16:14:44
INFO:tensorflow:Loss at step 500: 421.8246154785156
INFO:tensorflow:Loss at step 600: 305.761474609375
INFO:tensorflow:Loss at step 700: 167.25115966796875
INFO:tensorflow:Loss at step 800: 399.25732421875
INFO:tensorflow:Loss at step 900: 246.51300048828125
INFO:tensorflow:Average loss at epoch 1: 390.06004813383385
INFO:tensorflow:train one epoch took: 2.353825569152832 seconds
INFO:tensorflow:Evaluation took: 0.22740554809570312 seconds
INFO:tensorflow:Accuracy : 0.9435
数分後、学習済みチェックポイントが models/train/model.ckpt に生成されます。