Create a file called data_utils.py and add the following code:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gzip, os, sys
from six.moves import urllib
import numpy as np
import tensorflow as tf
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
# The URLs where the MNIST data can be downloaded.
_DATA_URL = 'http://yann.lecun.com/exdb/mnist/'
_TRAIN_DATA_FILENAME = 'train-images-idx3-ubyte.gz'
_TRAIN_LABELS_FILENAME = 'train-labels-idx1-ubyte.gz'
_TEST_DATA_FILENAME = 't10k-images-idx3-ubyte.gz'
_TEST_LABELS_FILENAME = 't10k-labels-idx1-ubyte.gz'
_LABELS_FILENAME = 'labels.txt'
_DATASET_DIR = 'data/mnist'
_IMAGE_SIZE = 28
_NUM_CHANNELS = 1
_NUM_LABELS = 10
# The names of the classes.
_CLASS_NAMES = [
'zero',
'one',
'two',
'three',
'four',
'five',
'size',
'seven',
'eight',
'nine',
]
def _extract_images(filename, num_images):
    """Read raw MNIST pixel data from a gzipped images file.

    Args:
        filename: Path of the gzipped MNIST images file.
        num_images: How many images to read from the file.

    Returns:
        A uint8 numpy array shaped
        [num_images, _IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS].
    """
    print('Extracting images from: ', filename)
    pixels_per_image = _IMAGE_SIZE * _IMAGE_SIZE * _NUM_CHANNELS
    with gzip.open(filename) as stream:
        # Discard the 16-byte file header; it is not validated here.
        stream.read(16)
        raw = stream.read(pixels_per_image * num_images)
        flat = np.frombuffer(raw, dtype=np.uint8)
        return flat.reshape(num_images, _IMAGE_SIZE, _IMAGE_SIZE,
                            _NUM_CHANNELS)
def _extract_labels(filename, num_labels):
"""Extract the labels into a vector of int64 label IDs.
Args:
filename: The path to an MNIST labels file.
num_labels: The number of labels in the file.
Returns:
A numpy array of shape [number_of_labels]
"""
print('Extracting labels from: ', filename)
with gzip.open(filename) as bytestream:
bytestream.read(8)
buf = bytestream.read(1 * num_labels)
labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)
return labels
def int64_feature(values):
    """Wrap one or more integers in a ``tf.train.Feature``.

    Args:
        values: A single integer, or a list/tuple of integers.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` holds ``values``.
    """
    wrapped = values if isinstance(values, (tuple, list)) else [values]
    int64_list = tf.train.Int64List(value=wrapped)
    return tf.train.Feature(int64_list=int64_list)
def bytes_feature(values):
    """Returns a TF-Feature of bytes.

    Mirrors ``int64_feature``: accepts either a single value or a list/tuple
    of values. Previously a list argument was wrapped a second time and
    rejected by ``BytesList``.

    Args:
        values: A single bytes/string value, or a list/tuple of them.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds ``values``.
    """
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))
def _image_to_tfexample(image_data, class_id):
    """Build a ``tf.train.Example`` holding one encoded image and its label.

    Args:
        image_data: Encoded image bytes (PNG bytes in this module's pipeline).
        class_id: Integer class label for the image.

    Returns:
        A ``tf.train.Example`` with ``image/encoded`` and
        ``image/class/label`` features.
    """
    feature_map = {
        'image/encoded': bytes_feature(image_data),
        'image/class/label': int64_feature(class_id),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature_map))
def _add_to_tfrecord(data_filename, labels_filename, num_images,
                     tfrecord_writer):
    """Loads data from the binary MNIST files and writes it to a TFRecord.

    Each image is PNG-encoded through a small TF1 graph and written, together
    with its label, as a serialized ``tf.train.Example``.

    Args:
        data_filename: The filename of the MNIST images.
        labels_filename: The filename of the MNIST labels.
        num_images: The number of images in the dataset.
        tfrecord_writer: The TFRecord writer to use for writing.
    """
    images = _extract_images(data_filename, num_images)
    labels = _extract_labels(labels_filename, num_images)
    shape = (_IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS)
    with tf.Graph().as_default():
        # Build the PNG-encoding op once, then feed every image through it.
        image = tf.placeholder(dtype=tf.uint8, shape=shape)
        encoded_png = tf.image.encode_png(image)
        with tf.Session('') as sess:
            for j in range(num_images):
                sys.stdout.write(
                    '\r>> Converting image %d/%d' % (j + 1, num_images))
                sys.stdout.flush()
                png_string = sess.run(encoded_png,
                                      feed_dict={image: images[j]})
                example = _image_to_tfexample(png_string, labels[j])
                tfrecord_writer.write(example.SerializeToString())
    # The progress output uses '\r' without a newline; end the line so the next
    # message is not appended to it.
    print()
def _get_output_filename(dataset_dir, split_name):
"""Creates the output filename.
Args:
dataset_dir: The directory where the temporary files are stored.
split_name: The name of the train/test split.
Returns:
An absolute file path.
"""
return '%s/mnist_%s.tfrecord' % (dataset_dir, split_name)
def _download_progress(count, block_size, total_size):
    """urlretrieve reporthook: print an in-place download progress line.

    ``total_size`` can be -1 (or 0) when the size is unknown, so a percentage
    is printed only when it is positive; otherwise the byte count is shown.
    """
    downloaded = count * block_size
    if total_size > 0:
        # The final block can overshoot the file size; cap the display at 100%.
        percent = min(100.0, float(downloaded) / float(total_size) * 100.0)
        sys.stdout.write('\r>> Downloading %.1f%%' % percent)
    else:
        sys.stdout.write('\r>> Downloading %d bytes' % downloaded)
    sys.stdout.flush()


def _download_dataset(dataset_dir):
    """Downloads the four MNIST archives into ``dataset_dir`` if missing.

    A file that already exists is skipped. Each archive is therefore first
    fetched to ``<name>.part`` and renamed only after ``urlretrieve``
    returns, so an interrupted download is not mistaken for a complete file
    on the next run.

    Args:
        dataset_dir: The directory where the temporary files are stored. It
            is expected to exist already (``download_and_convert`` creates it).
    """
    for filename in [_TRAIN_DATA_FILENAME,
                     _TRAIN_LABELS_FILENAME,
                     _TEST_DATA_FILENAME,
                     _TEST_LABELS_FILENAME]:
        filepath = os.path.join(dataset_dir, filename)
        if os.path.exists(filepath):
            continue
        print('Downloading file %s...' % filename)
        partial_path = filepath + '.part'
        urllib.request.urlretrieve(_DATA_URL + filename, partial_path,
                                   _download_progress)
        print()
        os.rename(partial_path, filepath)
        size = os.path.getsize(filepath)
        print('Successfully downloaded', filename, size, 'bytes.')
def _write_label_file(labels_to_class_names, dataset_dir,
                      filename=_LABELS_FILENAME):
    """Write one ``<label>:<class name>`` line per entry of a label map.

    Args:
        labels_to_class_names: Dict mapping integer labels to class names.
        dataset_dir: Directory that receives the labels file.
        filename: Name of the labels file inside ``dataset_dir``.
    """
    target_path = os.path.join(dataset_dir, filename)
    with tf.gfile.Open(target_path, 'w') as out:
        for label_id, name in labels_to_class_names.items():
            out.write('%d:%s\n' % (label_id, name))
def _clean_up_temporary_files(dataset_dir):
    """Delete the four downloaded MNIST ``.gz`` archives from ``dataset_dir``.

    Args:
        dataset_dir: Directory holding the downloaded archives.
    """
    archives = (_TRAIN_DATA_FILENAME, _TRAIN_LABELS_FILENAME,
                _TEST_DATA_FILENAME, _TEST_LABELS_FILENAME)
    for archive in archives:
        tf.gfile.Remove(os.path.join(dataset_dir, archive))
def download_and_convert(dataset_dir, clean=False):
    """Download MNIST and convert it into train/test TFRecord files.

    If both TFRecord outputs already exist, this only ensures ``dataset_dir``
    exists and returns without re-creating anything.

    Args:
        dataset_dir: Directory where downloads and outputs are stored; it is
            created if missing.
        clean: If True, delete the downloaded ``.gz`` archives afterwards.
    """
    if not tf.gfile.Exists(dataset_dir):
        tf.gfile.MakeDirs(dataset_dir)

    train_record = _get_output_filename(dataset_dir, 'train')
    test_record = _get_output_filename(dataset_dir, 'test')
    if all(tf.gfile.Exists(path) for path in (train_record, test_record)):
        print('Dataset files already exist. Exiting without re-creating them.')
        return

    _download_dataset(dataset_dir)

    # Convert the training split first, then the test split.
    splits = (
        (train_record, _TRAIN_DATA_FILENAME, _TRAIN_LABELS_FILENAME, 60000),
        (test_record, _TEST_DATA_FILENAME, _TEST_LABELS_FILENAME, 10000),
    )
    for record_path, images_name, labels_name, count in splits:
        with tf.python_io.TFRecordWriter(record_path) as writer:
            _add_to_tfrecord(os.path.join(dataset_dir, images_name),
                             os.path.join(dataset_dir, labels_name),
                             count, writer)

    # Label file maps each digit index to its class name.
    _write_label_file(dict(enumerate(_CLASS_NAMES)), dataset_dir)

    if clean:
        _clean_up_temporary_files(dataset_dir)
    print('\nFinished converting the MNIST dataset!')
def _parse_function(tfrecord_serialized):
    """Parse one serialized MNIST ``tf.train.Example`` into image and label.

    Args:
        tfrecord_serialized: A scalar string tensor holding one serialized
            example, as yielded by ``tf.data.TFRecordDataset``.

    Returns:
        A tuple ``(image, label)``. ``image`` is a float32 tensor of static
        shape ``[_IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS]`` with pixel values
        scaled into [0, 1]. ``label`` is a scalar int64 class id.
    """
    features = {'image/encoded': tf.FixedLenFeature([], tf.string),
                'image/class/label': tf.FixedLenFeature([], tf.int64)}
    parsed_features = tf.parse_single_example(tfrecord_serialized, features)
    label = parsed_features['image/class/label']
    # Decode with an explicit channel count and pin the static shape, so the
    # dataset's output_shapes are fully known downstream.
    image = tf.image.decode_png(parsed_features['image/encoded'],
                                channels=_NUM_CHANNELS)
    image.set_shape([_IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS])
    # Same values and dtype as dividing the uint8 tensor by 255, but with the
    # float32 conversion spelled out.
    image = tf.cast(image, tf.float32) / 255.0
    return image, label
def get_init_data(train_batch,
                  test_batch,
                  dataset_dir=_DATASET_DIR,
                  test_only=False,
                  num_parallel_calls=8,
                  shuffle_buffer_size=10000):
    """Build a reinitializable MNIST input pipeline from the TFRecord files.

    The train and test datasets share one iterator. Before pulling
    ``img``/``label``, run ``train_init`` or ``test_init`` with ``sess.run``
    to point the iterator at that dataset.

    Args:
        train_batch: Batch size of the training dataset.
        test_batch: Batch size of the test dataset.
        dataset_dir: Optional. Directory holding the TFRecord files.
        test_only: If True, build only the test pipeline; ``train_init`` is
            then None.
        num_parallel_calls: Number of records parsed in parallel.
        shuffle_buffer_size: Optional. Shuffle buffer size for the training
            dataset (default 10000, the previously hard-coded value).

    Returns:
        img: Image batch tensor reshaped to
            [-1, _IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS].
        label: One-hot label batch tensor with _NUM_LABELS classes.
        train_init: Initializer op for the training data, or None if
            test_only.
        test_init: Initializer op for the test data.
    """
    with tf.name_scope('data'):
        testing_filename = _get_output_filename(dataset_dir, 'test')
        test_data = tf.data.TFRecordDataset(testing_filename)
        test_data = test_data.map(_parse_function,
                                  num_parallel_calls=num_parallel_calls)
        test_data = test_data.batch(test_batch)
        # NOTE(review): after .batch(), prefetch's buffer counts batches, so
        # this buffers up to test_batch batches; a small constant may suffice.
        test_data = test_data.prefetch(test_batch)
        iterator = tf.data.Iterator.from_structure(test_data.output_types,
                                                   test_data.output_shapes)
        test_init = iterator.make_initializer(test_data)  # initializer for test_data
        img, label = iterator.get_next()
        # Reshape to [-1, 28, 28, 1] so the batch works with tf.nn.conv2d.
        img = tf.reshape(img,
                         shape=[-1, _IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS])
        label = tf.one_hot(label, _NUM_LABELS)

        train_init = None
        if not test_only:
            training_filename = _get_output_filename(dataset_dir, 'train')
            train_data = tf.data.TFRecordDataset([training_filename])
            train_data = train_data.shuffle(shuffle_buffer_size)
            train_data = train_data.map(_parse_function,
                                        num_parallel_calls=num_parallel_calls)
            train_data = train_data.batch(train_batch)
            train_data = train_data.prefetch(train_batch)
            train_init = iterator.make_initializer(train_data)  # initializer for train_data
    return img, label, train_init, test_init
def get_one_shot_test_data(
        test_batch,
        dataset_dir=_DATASET_DIR,
        num_parallel_calls=8):
    """Build a test-set input pipeline that needs no explicit initializer.

    Uses a one-shot iterator, so the returned tensors can be evaluated
    directly (e.g. by ``vai_p_tensorflow --ana``).

    Args:
        test_batch: Batch size of the test dataset.
        dataset_dir: Optional. Directory holding the TFRecord files.
        num_parallel_calls: Number of records parsed in parallel.

    Returns:
        img: Image batch tensor reshaped to
            [-1, _IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS].
        label: One-hot label batch tensor with _NUM_LABELS classes.
    """
    with tf.name_scope('data'):
        record_path = _get_output_filename(dataset_dir, 'test')
        dataset = (tf.data.TFRecordDataset([record_path])
                   .map(_parse_function, num_parallel_calls=num_parallel_calls)
                   .batch(test_batch)
                   .prefetch(test_batch))
        next_img, next_label = dataset.make_one_shot_iterator().get_next()
        # Give the batch a fixed 4-D shape so it works with tf.nn.conv2d.
        images = tf.reshape(next_img,
                            shape=[-1, _IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS])
        one_hot_labels = tf.one_hot(next_label, _NUM_LABELS)
    return images, one_hot_labels
if __name__ == '__main__':
    # Running this file directly downloads MNIST into _DATASET_DIR and
    # converts it to TFRecord files.
    download_and_convert(dataset_dir=_DATASET_DIR)
The `get_init_data` function in data_utils.py takes `train_batch` and `test_batch` as arguments. It returns image and label tensors, plus initializer operations for the training and test data. Run the appropriate initializer before training or evaluation.
data_utils.py is imported as a module to provide an input data pipeline. You can also run it in a shell to download the MNIST dataset and convert it into TFRecord format using the following command:
$ python data_utils.py
This generates the following:
data/mnist/labels.txt
data/mnist/mnist_test.tfrecord
data/mnist/mnist_train.tfrecord
data/mnist/t10k-images-idx3-ubyte.gz
data/mnist/t10k-labels-idx1-ubyte.gz
data/mnist/train-images-idx3-ubyte.gz
data/mnist/train-labels-idx1-ubyte.gz