📚 The CoCalc Library - books, templates and other resources
cocalc-examples / data-science-ipython-notebooks / deep-learning / tensor-flow-examples / input_data.py
132928 views · License: OTHER
"""Functions for downloading and reading MNIST data."""1from __future__ import print_function2import gzip3import os4import urllib5import numpy6SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'7def maybe_download(filename, work_directory):8"""Download the data from Yann's website, unless it's already here."""9if not os.path.exists(work_directory):10os.mkdir(work_directory)11filepath = os.path.join(work_directory, filename)12if not os.path.exists(filepath):13filepath, _ = urllib.urlretrieve(SOURCE_URL + filename, filepath)14statinfo = os.stat(filepath)15print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')16return filepath17def _read32(bytestream):18dt = numpy.dtype(numpy.uint32).newbyteorder('>')19return numpy.frombuffer(bytestream.read(4), dtype=dt)20def extract_images(filename):21"""Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""22print('Extracting', filename)23with gzip.open(filename) as bytestream:24magic = _read32(bytestream)25if magic != 2051:26raise ValueError(27'Invalid magic number %d in MNIST image file: %s' %28(magic, filename))29num_images = _read32(bytestream)30rows = _read32(bytestream)31cols = _read32(bytestream)32buf = bytestream.read(rows * cols * num_images)33data = numpy.frombuffer(buf, dtype=numpy.uint8)34data = data.reshape(num_images, rows, cols, 1)35return data36def dense_to_one_hot(labels_dense, num_classes=10):37"""Convert class labels from scalars to one-hot vectors."""38num_labels = labels_dense.shape[0]39index_offset = numpy.arange(num_labels) * num_classes40labels_one_hot = numpy.zeros((num_labels, num_classes))41labels_one_hot.flat[index_offset + labels_dense.ravel()] = 142return labels_one_hot43def extract_labels(filename, one_hot=False):44"""Extract the labels into a 1D uint8 numpy array [index]."""45print('Extracting', filename)46with gzip.open(filename) as bytestream:47magic = _read32(bytestream)48if magic != 2049:49raise ValueError(50'Invalid magic number %d in MNIST label file: %s' %51(magic, 
filename))52num_items = _read32(bytestream)53buf = bytestream.read(num_items)54labels = numpy.frombuffer(buf, dtype=numpy.uint8)55if one_hot:56return dense_to_one_hot(labels)57return labels58class DataSet(object):59def __init__(self, images, labels, fake_data=False):60if fake_data:61self._num_examples = 1000062else:63assert images.shape[0] == labels.shape[0], (64"images.shape: %s labels.shape: %s" % (images.shape,65labels.shape))66self._num_examples = images.shape[0]67# Convert shape from [num examples, rows, columns, depth]68# to [num examples, rows*columns] (assuming depth == 1)69assert images.shape[3] == 170images = images.reshape(images.shape[0],71images.shape[1] * images.shape[2])72# Convert from [0, 255] -> [0.0, 1.0].73images = images.astype(numpy.float32)74images = numpy.multiply(images, 1.0 / 255.0)75self._images = images76self._labels = labels77self._epochs_completed = 078self._index_in_epoch = 079@property80def images(self):81return self._images82@property83def labels(self):84return self._labels85@property86def num_examples(self):87return self._num_examples88@property89def epochs_completed(self):90return self._epochs_completed91def next_batch(self, batch_size, fake_data=False):92"""Return the next `batch_size` examples from this data set."""93if fake_data:94fake_image = [1.0 for _ in xrange(784)]95fake_label = 096return [fake_image for _ in xrange(batch_size)], [97fake_label for _ in xrange(batch_size)]98start = self._index_in_epoch99self._index_in_epoch += batch_size100if self._index_in_epoch > self._num_examples:101# Finished epoch102self._epochs_completed += 1103# Shuffle the data104perm = numpy.arange(self._num_examples)105numpy.random.shuffle(perm)106self._images = self._images[perm]107self._labels = self._labels[perm]108# Start next epoch109start = 0110self._index_in_epoch = batch_size111assert batch_size <= self._num_examples112end = self._index_in_epoch113return self._images[start:end], self._labels[start:end]114def read_data_sets(train_dir, 
fake_data=False, one_hot=False):115class DataSets(object):116pass117data_sets = DataSets()118if fake_data:119data_sets.train = DataSet([], [], fake_data=True)120data_sets.validation = DataSet([], [], fake_data=True)121data_sets.test = DataSet([], [], fake_data=True)122return data_sets123TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'124TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'125TEST_IMAGES = 't10k-images-idx3-ubyte.gz'126TEST_LABELS = 't10k-labels-idx1-ubyte.gz'127VALIDATION_SIZE = 5000128local_file = maybe_download(TRAIN_IMAGES, train_dir)129train_images = extract_images(local_file)130local_file = maybe_download(TRAIN_LABELS, train_dir)131train_labels = extract_labels(local_file, one_hot=one_hot)132local_file = maybe_download(TEST_IMAGES, train_dir)133test_images = extract_images(local_file)134local_file = maybe_download(TEST_LABELS, train_dir)135test_labels = extract_labels(local_file, one_hot=one_hot)136validation_images = train_images[:VALIDATION_SIZE]137validation_labels = train_labels[:VALIDATION_SIZE]138train_images = train_images[VALIDATION_SIZE:]139train_labels = train_labels[VALIDATION_SIZE:]140data_sets.train = DataSet(train_images, train_labels)141data_sets.validation = DataSet(validation_images, validation_labels)142data_sets.test = DataSet(test_images, test_labels)143return data_sets144145