📚 The CoCalc Library - books, templates and other resources
License: OTHER
import os
import gzip
import shutil
import struct
import urllib
import urllib.request  # the submodule must be imported explicitly to use urlretrieve

# Must be set before TensorFlow is imported to take effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

from matplotlib import pyplot as plt
import numpy as np
import tensorflow as tf


def huber_loss(labels, predictions, delta=14.0):
    """Return the Huber loss between labels and predictions.

    The loss is quadratic (0.5 * r^2) where the absolute residual r is
    below delta and linear (delta * r - 0.5 * delta^2) elsewhere.

    The selection is done element-wise, so both scalar and batched
    inputs are supported.
    """
    residual = tf.abs(labels - predictions)
    quadratic = 0.5 * tf.square(residual)
    linear = delta * residual - 0.5 * tf.square(delta)
    # tf.cond only accepts a scalar predicate; tf.where picks per element.
    return tf.where(residual < delta, quadratic, linear)


def safe_mkdir(path):
    """Create a directory (and any missing parents) if it doesn't exist.

    An already existing directory is not an error. Other failures
    (e.g. permission denied, path is a regular file) raise OSError.
    """
    os.makedirs(path, exist_ok=True)


def read_birth_life_data(filename):
    """Read a tab-separated birth/life file such as birth_life_2010.txt.

    The first line is treated as a header and skipped. Column 1 is read
    as the birth rate and column 2 as the life expectancy.

    Returns:
        data: NumPy float32 array with one (birth, life) row per sample
        n_samples: number of samples
    """
    with open(filename, 'r') as f:
        lines = f.readlines()[1:]
    # Strip only the newline so that a final line without a trailing
    # newline keeps its last character; skip blank lines.
    rows = [line.rstrip('\n').split('\t') for line in lines if line.strip()]
    births = [float(row[1]) for row in rows]
    lifes = [float(row[2]) for row in rows]
    data = list(zip(births, lifes))
    n_samples = len(data)
    data = np.asarray(data, dtype=np.float32)
    return data, n_samples


def download_one_file(download_url,
                      local_dest,
                      expected_byte=None,
                      unzip_and_remove=False):
    """Download download_url into local_dest if it doesn't already exist.

    The file counts as existing when local_dest exists or, for a name
    ending in '.gz', when the name without that suffix exists.

    If expected_byte is provided, the size of the downloaded file is
    compared against it; on mismatch a message is printed and the file
    is left untouched.
    If unzip_and_remove is True, the file is gunzipped to local_dest
    minus its last three characters (the '.gz' suffix is expected) and
    the compressed file is removed.
    """
    unzipped_dest = local_dest[:-3]
    # Only look for the unzipped name when there really is a '.gz'
    # suffix; otherwise the slice may point at an unrelated path
    # (e.g. the parent directory) and wrongly skip the download.
    already_unzipped = (local_dest.endswith('.gz')
                        and os.path.exists(unzipped_dest))
    if os.path.exists(local_dest) or already_unzipped:
        print('%s already exists' % local_dest)
        return

    print('Downloading %s' % download_url)
    urllib.request.urlretrieve(download_url, local_dest)

    if expected_byte:
        if os.stat(local_dest).st_size != expected_byte:
            print('The downloaded file has unexpected number of bytes')
            return
        print('Successfully downloaded %s' % local_dest)

    # Unzip also when no expected size was given to check against.
    if unzip_and_remove:
        with gzip.open(local_dest, 'rb') as f_in, \
                open(unzipped_dest, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)
        os.remove(local_dest)


def download_mnist(path):
    """Download and unzip the MNIST dataset into path if not already there.

    Files are fetched from http://yann.lecun.com/exdb/mnist
    """
    safe_mkdir(path)
    url = 'http://yann.lecun.com/exdb/mnist'
    filenames = ['train-images-idx3-ubyte.gz',
                 'train-labels-idx1-ubyte.gz',
                 't10k-images-idx3-ubyte.gz',
                 't10k-labels-idx1-ubyte.gz']
    expected_bytes = [9912422, 28881, 1648877, 4542]

    for filename, byte in zip(filenames, expected_bytes):
        # URLs always use '/', so don't build them with os.path.join.
        download_url = url + '/' + filename
        local_dest = os.path.join(path, filename)
        download_one_file(download_url, local_dest, byte, True)


def parse_data(path, dataset, flatten):
    """Parse the unzipped label and image files of one MNIST split.

    Args:
        path: directory holding '<dataset>-labels-idx1-ubyte' and
            '<dataset>-images-idx3-ubyte'
        dataset: 'train' or 't10k'
        flatten: if True, each image is reshaped to a 1-D vector

    Returns:
        imgs: float32 array scaled to [0, 1], shaped (num, rows, cols)
            or (num, rows * cols) when flatten is True
        new_labels: one-hot array of shape (num, 10)

    Raises:
        NameError: if dataset is neither 'train' nor 't10k'
    """
    if dataset != 'train' and dataset != 't10k':
        raise NameError('dataset must be train or t10k')

    label_file = os.path.join(path, dataset + '-labels-idx1-ubyte')
    with open(label_file, 'rb') as f:
        # Header: two big-endian uint32 values; the second is the count.
        _, num = struct.unpack(">II", f.read(8))
        labels = np.fromfile(f, dtype=np.int8)
        # One-hot encode the labels.
        new_labels = np.zeros((num, 10))
        new_labels[np.arange(num), labels] = 1

    img_file = os.path.join(path, dataset + '-images-idx3-ubyte')
    with open(img_file, 'rb') as f:
        # Header: four big-endian uint32 values; count, rows and columns
        # follow the first one.
        _, num, rows, cols = struct.unpack(">IIII", f.read(16))
        imgs = np.fromfile(f, dtype=np.uint8).reshape(num, rows, cols)
        imgs = imgs.astype(np.float32) / 255.0
        if flatten:
            imgs = imgs.reshape([num, -1])

    return imgs, new_labels


def read_mnist(path, flatten=True, num_train=55000):
    """Read the MNIST dataset stored (unzipped) in path.

    The 'train' split is randomly permuted; the first num_train samples
    form the training set and the remainder the validation set.

    Returns three tuples of NumPy arrays:
        (train_imgs, train_labels), (val_imgs, val_labels),
        (test_imgs, test_labels)
    """
    imgs, labels = parse_data(path, 'train', flatten)
    indices = np.random.permutation(labels.shape[0])
    train_idx, val_idx = indices[:num_train], indices[num_train:]
    train_img, train_labels = imgs[train_idx, :], labels[train_idx, :]
    val_img, val_labels = imgs[val_idx, :], labels[val_idx, :]
    test = parse_data(path, 't10k', flatten)
    return (train_img, train_labels), (val_img, val_labels), test


def get_mnist_dataset(batch_size):
    """Download MNIST if needed and return batched train/test datasets.

    Returns:
        train_data: shuffled tf.data.Dataset batched by batch_size
        test_data: tf.data.Dataset batched by batch_size
    """
    # Step 1: Read in data
    mnist_folder = 'data/mnist'
    download_mnist(mnist_folder)
    train, _, test = read_mnist(mnist_folder, flatten=False)

    # Step 2: Create datasets and iterator
    train_data = tf.data.Dataset.from_tensor_slices(train)
    train_data = train_data.shuffle(10000)  # if you want to shuffle your data
    train_data = train_data.batch(batch_size)

    test_data = tf.data.Dataset.from_tensor_slices(test)
    test_data = test_data.batch(batch_size)

    return train_data, test_data


def show(image):
    """Render a given 2D array of pixel data as a grayscale image."""
    plt.imshow(image, cmap='gray')
    plt.show()