Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
Download

📚 The CoCalc Library - books, templates and other resources

132928 views
License: OTHER
1
"""Functions for downloading and reading MNIST data."""
2
from __future__ import print_function
3
import gzip
4
import os
5
import urllib
6
import numpy
7
# Base URL the MNIST .gz files are fetched from; maybe_download appends the
# file name to it. This HTTPS mirror is the one later TensorFlow releases use.
# The original host (http://yann.lecun.com/exdb/mnist/) serves plain HTTP and
# has been reported to reject scripted downloads (HTTP 403).
SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'
def maybe_download(filename, work_directory):
    """Download `filename` from SOURCE_URL into `work_directory` unless present.

    Args:
      filename: Name of the remote file; appended to SOURCE_URL to form the URL.
      work_directory: Local directory to store the file in. It is created,
        including any missing parent directories, if it does not exist.

    Returns:
      The local path of the (possibly freshly downloaded) file.
    """
    # urlretrieve lives in urllib.request on Python 3 and in urllib on Python 2.
    try:
        from urllib.request import urlretrieve
    except ImportError:
        from urllib import urlretrieve
    if not os.path.exists(work_directory):
        # makedirs (unlike mkdir) also creates missing parent directories.
        os.makedirs(work_directory)
    filepath = os.path.join(work_directory, filename)
    if not os.path.exists(filepath):
        # Download under a temporary name and rename only on success, so an
        # interrupted download never leaves a truncated file behind that a
        # later call would mistake for a complete one.
        tmp_path = filepath + '.part'
        try:
            urlretrieve(SOURCE_URL + filename, tmp_path)
        except BaseException:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
        os.rename(tmp_path, filepath)
        statinfo = os.stat(filepath)
        print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    return filepath
def _read32(bytestream):
19
dt = numpy.dtype(numpy.uint32).newbyteorder('>')
20
return numpy.frombuffer(bytestream.read(4), dtype=dt)
21
def extract_images(filename):
    """Extract the images into a 4D uint8 numpy array [index, y, x, depth].

    Args:
      filename: Path to a gzip-compressed MNIST image file.

    Returns:
      A uint8 array of shape (num_images, rows, cols, 1).

    Raises:
      ValueError: if the magic number is not 2051, or the file contains fewer
        pixel bytes than its header announces.
    """
    print('Extracting', filename)
    with gzip.open(filename) as bytestream:

        def read_uint32():
            # Normalize _read32's result to a plain int whether it returns a
            # scalar or a 1-element array.
            return numpy.asarray(_read32(bytestream)).item()

        magic = read_uint32()
        if magic != 2051:
            raise ValueError(
                'Invalid magic number %d in MNIST image file: %s' %
                (magic, filename))
        num_images = read_uint32()
        rows = read_uint32()
        cols = read_uint32()
        expected = rows * cols * num_images
        buf = bytestream.read(expected)
        # A short read would otherwise surface as a confusing reshape error.
        if len(buf) != expected:
            raise ValueError(
                'Truncated MNIST image file %s: expected %d bytes of pixel '
                'data, got %d' % (filename, expected, len(buf)))
        data = numpy.frombuffer(buf, dtype=numpy.uint8)
        data = data.reshape(num_images, rows, cols, 1)
        return data
def dense_to_one_hot(labels_dense, num_classes=10):
    """Convert class labels from scalars to one-hot vectors.

    Args:
      labels_dense: Array-like of integer class labels, one per example.
      num_classes: Length of each one-hot row.

    Returns:
      A float64 array of shape (num_labels, num_classes) with a single 1.0 per
      row, in the column given by that row's label.

    Raises:
      ValueError: if any label lies outside [0, num_classes).
    """
    labels = numpy.asarray(labels_dense).ravel()
    num_labels = labels.shape[0]
    # Out-of-range labels would otherwise silently set a bit in a neighbouring
    # row (or, when negative, wrap around from the end) of the flat buffer.
    if num_labels and (labels.min() < 0 or labels.max() >= num_classes):
        raise ValueError('Labels must be in [0, %d); got range [%d, %d]'
                         % (num_classes, labels.min(), labels.max()))
    # Row i's label column sits at flat index i * num_classes + label.
    index_offset = numpy.arange(num_labels) * num_classes
    labels_one_hot = numpy.zeros((num_labels, num_classes))
    labels_one_hot.flat[index_offset + labels] = 1
    return labels_one_hot
def extract_labels(filename, one_hot=False):
    """Extract the labels into a 1D uint8 numpy array [index].

    Args:
      filename: Path to a gzip-compressed MNIST label file.
      one_hot: If True, convert the labels with dense_to_one_hot (default
        num_classes=10) into a float array of one-hot rows.

    Returns:
      A uint8 array of shape (num_items,), or the one-hot array when
      `one_hot` is True.

    Raises:
      ValueError: if the magic number is not 2049, or the file contains fewer
        label bytes than its header announces.
    """
    print('Extracting', filename)
    with gzip.open(filename) as bytestream:

        def read_uint32():
            # Normalize _read32's result to a plain int whether it returns a
            # scalar or a 1-element array.
            return numpy.asarray(_read32(bytestream)).item()

        magic = read_uint32()
        if magic != 2049:
            raise ValueError(
                'Invalid magic number %d in MNIST label file: %s' %
                (magic, filename))
        num_items = read_uint32()
        buf = bytestream.read(num_items)
        # Catch a truncated file here instead of returning too few labels.
        if len(buf) != num_items:
            raise ValueError(
                'Truncated MNIST label file %s: expected %d label bytes, '
                'got %d' % (filename, num_items, len(buf)))
        labels = numpy.frombuffer(buf, dtype=numpy.uint8)
    if one_hot:
        return dense_to_one_hot(labels)
    return labels
class DataSet(object):
    """In-memory image/label collection that serves mini-batches.

    Real images are flattened to [num examples, rows * columns] float32 values
    scaled from [0, 255] to [0.0, 1.0]; labels are stored unchanged. The data
    is reshuffled each time an epoch is completed in next_batch.
    """

    def __init__(self, images, labels, fake_data=False):
        """Store (and, for real data, validate and convert) images and labels.

        Args:
          images: Array [num examples, rows, columns, depth] with depth == 1,
            presumably uint8 pixels in [0, 255] as produced by extract_images.
            Stored unconverted when `fake_data` is True.
          labels: Array whose first dimension matches `images`.
          fake_data: If True, skip validation and conversion and report 10000
            examples; batches should then be requested with fake_data=True.

        Raises:
          ValueError: if the image and label counts differ, or the images are
            not 4-D with a depth of 1.
        """
        if fake_data:
            self._num_examples = 10000
        else:
            # Explicit checks instead of assert, which is stripped under -O.
            if images.shape[0] != labels.shape[0]:
                raise ValueError(
                    "images.shape: %s labels.shape: %s" % (images.shape,
                                                           labels.shape))
            self._num_examples = images.shape[0]
            # Convert shape from [num examples, rows, columns, depth]
            # to [num examples, rows*columns] (requires depth == 1).
            if len(images.shape) != 4 or images.shape[3] != 1:
                raise ValueError(
                    "images must have shape [num examples, rows, columns, 1]; "
                    "got %s" % (images.shape,))
            images = images.reshape(images.shape[0],
                                    images.shape[1] * images.shape[2])
            # Convert from [0, 255] -> [0.0, 1.0].
            images = images.astype(numpy.float32)
            images = numpy.multiply(images, 1.0 / 255.0)
        self._images = images
        self._labels = labels
        self._epochs_completed = 0
        self._index_in_epoch = 0

    @property
    def images(self):
        return self._images

    @property
    def labels(self):
        return self._labels

    @property
    def num_examples(self):
        return self._num_examples

    @property
    def epochs_completed(self):
        return self._epochs_completed

    def next_batch(self, batch_size, fake_data=False):
        """Return the next `batch_size` examples from this data set.

        If the batch would run past the end of the current epoch, the rest of
        that epoch is skipped: epochs_completed is incremented, the images and
        labels are reshuffled together, and the batch is taken from the start
        of the new epoch.

        Args:
          batch_size: Number of examples to return.
          fake_data: If True, return `batch_size` constant 784-value images
            and zero labels as Python lists, without touching stored data.

        Returns:
          An (images, labels) pair.

        Raises:
          ValueError: if batch_size exceeds num_examples. This is checked
            before any state (epoch counter, shuffle) is modified.
        """
        if fake_data:
            # range works on both Python 2 and 3 (xrange is Python 2 only).
            fake_image = [1.0 for _ in range(784)]
            fake_label = 0
            return [fake_image for _ in range(batch_size)], [
                fake_label for _ in range(batch_size)]
        if batch_size > self._num_examples:
            raise ValueError(
                "batch_size %d exceeds the number of examples %d"
                % (batch_size, self._num_examples))
        start = self._index_in_epoch
        self._index_in_epoch += batch_size
        if self._index_in_epoch > self._num_examples:
            # Finished epoch
            self._epochs_completed += 1
            # Shuffle images and labels with the same permutation so pairs
            # stay aligned.
            perm = numpy.arange(self._num_examples)
            numpy.random.shuffle(perm)
            self._images = self._images[perm]
            self._labels = self._labels[perm]
            # Start next epoch
            start = 0
            self._index_in_epoch = batch_size
        end = self._index_in_epoch
        return self._images[start:end], self._labels[start:end]
def read_data_sets(train_dir, fake_data=False, one_hot=False,
                   validation_size=5000):
    """Load MNIST into train/validation/test DataSet objects.

    Args:
      train_dir: Directory that holds (or will receive) the MNIST .gz files.
      fake_data: If True, return fake DataSets (fake_data=True) without
        downloading or reading any files.
      one_hot: Forwarded to extract_labels for both label files.
      validation_size: Number of leading training examples held out as the
        validation set. Defaults to 5000, the previously hard-coded value.

    Returns:
      An object with `train`, `validation` and `test` DataSet attributes.

    Raises:
      ValueError: if validation_size is negative or exceeds the number of
        training images.
    """
    class DataSets(object):
        pass
    data_sets = DataSets()
    if fake_data:
        data_sets.train = DataSet([], [], fake_data=True)
        data_sets.validation = DataSet([], [], fake_data=True)
        data_sets.test = DataSet([], [], fake_data=True)
        return data_sets
    TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
    TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
    TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
    TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
    local_file = maybe_download(TRAIN_IMAGES, train_dir)
    train_images = extract_images(local_file)
    local_file = maybe_download(TRAIN_LABELS, train_dir)
    train_labels = extract_labels(local_file, one_hot=one_hot)
    local_file = maybe_download(TEST_IMAGES, train_dir)
    test_images = extract_images(local_file)
    local_file = maybe_download(TEST_LABELS, train_dir)
    test_labels = extract_labels(local_file, one_hot=one_hot)
    # An oversized split would silently leave an empty training set.
    if not 0 <= validation_size <= len(train_images):
        raise ValueError(
            'validation_size should be between 0 and %d; got %d'
            % (len(train_images), validation_size))
    validation_images = train_images[:validation_size]
    validation_labels = train_labels[:validation_size]
    train_images = train_images[validation_size:]
    train_labels = train_labels[validation_size:]
    data_sets.train = DataSet(train_images, train_labels)
    data_sets.validation = DataSet(validation_images, validation_labels)
    data_sets.test = DataSet(test_images, test_labels)
    return data_sets