📚 The CoCalc Library - books, templates and other resources
License: OTHER
""" Using convolutional net on MNIST dataset of handwritten digits1MNIST dataset: http://yann.lecun.com/exdb/mnist/2CS 20: "TensorFlow for Deep Learning Research"3cs20.stanford.edu4Chip Huyen ([email protected])5Lecture 076"""7import os8os.environ['TF_CPP_MIN_LOG_LEVEL']='2'9import time1011import tensorflow as tf1213import utils1415def conv_relu(inputs, filters, k_size, stride, padding, scope_name):16'''17A method that does convolution + relu on inputs18'''19#############################20########## TO DO ############21#############################22return None2324def maxpool(inputs, ksize, stride, padding='VALID', scope_name='pool'):25'''A method that does max pooling on inputs'''26#############################27########## TO DO ############28#############################29return None3031def fully_connected(inputs, out_dim, scope_name='fc'):32'''33A fully connected linear layer on inputs34'''35#############################36########## TO DO ############37#############################38return None3940class ConvNet(object):41def __init__(self):42self.lr = 0.00143self.batch_size = 12844self.keep_prob = tf.constant(0.75)45self.gstep = tf.Variable(0, dtype=tf.int32,46trainable=False, name='global_step')47self.n_classes = 1048self.skip_step = 2049self.n_test = 100005051def get_data(self):52with tf.name_scope('data'):53train_data, test_data = utils.get_mnist_dataset(self.batch_size)54iterator = tf.data.Iterator.from_structure(train_data.output_types,55train_data.output_shapes)56img, self.label = iterator.get_next()57self.img = tf.reshape(img, shape=[-1, 28, 28, 1])58# reshape the image to make it work with tf.nn.conv2d5960self.train_init = iterator.make_initializer(train_data) # initializer for train_data61self.test_init = iterator.make_initializer(test_data) # initializer for train_data6263def inference(self):64'''65Build the model according to the description we've shown in class66'''67#############################68########## TO DO ############69#############################70self.logits = None7172def loss(self):73'''74define loss function75use softmax cross entropy with logits as the loss function76tf.nn.softmax_cross_entropy_with_logits77softmax is applied internally78don't forget to compute mean cross all sample in a batch79'''80#############################81########## TO DO ############82#############################83self.loss = None8485def optimize(self):86'''87Define training op88using Adam Gradient Descent to minimize cost89Don't forget to use global step90'''91#############################92########## TO DO ############93#############################94self.opt = None9596def summary(self):97'''98Create summaries to write on TensorBoard99Remember to track both training loss and test accuracy100'''101#############################102########## TO DO ############103#############################104self.summary_op = None105106def eval(self):107'''108Count the number of right predictions in a batch109'''110with tf.name_scope('predict'):111preds = tf.nn.softmax(self.logits)112correct_preds = tf.equal(tf.argmax(preds, 1), tf.argmax(self.label, 1))113self.accuracy = tf.reduce_sum(tf.cast(correct_preds, tf.float32))114115def build(self):116'''117Build the computation graph118'''119self.get_data()120self.inference()121self.loss()122self.optimize()123self.eval()124self.summary()125126def train_one_epoch(self, sess, saver, init, writer, epoch, step):127start_time = time.time()128sess.run(init)129total_loss = 0130n_batches = 0131try:132while True:133_, l, summaries = sess.run([self.opt, self.loss, self.summary_op])134writer.add_summary(summaries, global_step=step)135if (step + 1) % self.skip_step == 0:136print('Loss at step {0}: {1}'.format(step, l))137step += 1138total_loss += l139n_batches += 1140except tf.errors.OutOfRangeError:141pass142saver.save(sess, 'checkpoints/convnet_starter/mnist-convnet', step)143print('Average loss at epoch {0}: {1}'.format(epoch, total_loss/n_batches))144print('Took: {0} seconds'.format(time.time() - start_time))145return step146147def eval_once(self, sess, init, writer, epoch, step):148start_time = time.time()149sess.run(init)150total_correct_preds = 0151try:152while True:153accuracy_batch, summaries = sess.run([self.accuracy, self.summary_op])154writer.add_summary(summaries, global_step=step)155total_correct_preds += accuracy_batch156except tf.errors.OutOfRangeError:157pass158159print('Accuracy at epoch {0}: {1} '.format(epoch, total_correct_preds/self.n_test))160print('Took: {0} seconds'.format(time.time() - start_time))161162def train(self, n_epochs):163'''164The train function alternates between training one epoch and evaluating165'''166utils.safe_mkdir('checkpoints')167utils.safe_mkdir('checkpoints/convnet_starter')168writer = tf.summary.FileWriter('./graphs/convnet_starter', tf.get_default_graph())169170with tf.Session() as sess:171sess.run(tf.global_variables_initializer())172saver = tf.train.Saver()173ckpt = tf.train.get_checkpoint_state(os.path.dirname('checkpoints/convnet_starter/checkpoint'))174if ckpt and ckpt.model_checkpoint_path:175saver.restore(sess, ckpt.model_checkpoint_path)176177step = self.gstep.eval()178179for epoch in range(n_epochs):180step = self.train_one_epoch(sess, saver, self.train_init, writer, epoch, step)181self.eval_once(sess, self.test_init, writer, epoch, step)182writer.close()183184if __name__ == '__main__':185model = ConvNet()186model.build()187model.train(n_epochs=15)188189