📚 The CoCalc Library - books, templates and other resources
License: OTHER
""" Using convolutional net on MNIST dataset of handwritten digits
MNIST dataset: http://yann.lecun.com/exdb/mnist/
CS 20: "TensorFlow for Deep Learning Research"
cs20.stanford.edu
Chip Huyen ([email protected])
Lecture 07
"""
import os
# Must be set before TensorFlow is imported so the C++ log level takes effect.
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import time

import tensorflow as tf

import utils


class ConvNet(object):
    '''
    Two conv/pool stages followed by a dense layer with dropout and a
    10-way logits layer, built with the TF 1.x graph API (tf.layers).

    Call build() once to construct the graph, then train(n_epochs).
    '''

    def __init__(self):
        self.lr = 0.001
        self.batch_size = 128
        # Probability of KEEPING a unit in the dropout layer.
        self.keep_prob = tf.constant(0.75)
        self.gstep = tf.Variable(0, dtype=tf.int32,
                                 trainable=False, name='global_step')
        self.n_classes = 10
        self.skip_step = 20
        self.n_test = 10000
        # Python-side mirror of the current mode; kept for backward
        # compatibility. It does NOT drive the graph -- see is_training.
        self.training = False
        # Graph-side switch for dropout. A Python bool would be frozen into
        # the graph at build time, so a placeholder that defaults to False
        # (inference) is used and fed True on training steps.
        self.is_training = tf.placeholder_with_default(False, shape=[],
                                                       name='is_training')

    def get_data(self):
        '''
        Create one re-initializable iterator shared by the train and test
        datasets, and expose self.img, self.label, self.train_init and
        self.test_init.
        '''
        with tf.name_scope('data'):
            train_data, test_data = utils.get_mnist_dataset(self.batch_size)
            iterator = tf.data.Iterator.from_structure(train_data.output_types,
                                                       train_data.output_shapes)
            img, self.label = iterator.get_next()
            # reshape the image to make it work with tf.nn.conv2d
            self.img = tf.reshape(img, shape=[-1, 28, 28, 1])

            self.train_init = iterator.make_initializer(train_data)  # initializer for train_data
            self.test_init = iterator.make_initializer(test_data)    # initializer for test_data

    def inference(self):
        '''
        Build the forward pass and store the unnormalized class scores in
        self.logits.
        '''
        conv1 = tf.layers.conv2d(inputs=self.img,
                                 filters=32,
                                 kernel_size=[5, 5],
                                 padding='SAME',
                                 activation=tf.nn.relu,
                                 name='conv1')
        pool1 = tf.layers.max_pooling2d(inputs=conv1,
                                        pool_size=[2, 2],
                                        strides=2,
                                        name='pool1')

        conv2 = tf.layers.conv2d(inputs=pool1,
                                 filters=64,
                                 kernel_size=[5, 5],
                                 padding='SAME',
                                 activation=tf.nn.relu,
                                 name='conv2')
        pool2 = tf.layers.max_pooling2d(inputs=conv2,
                                        pool_size=[2, 2],
                                        strides=2,
                                        name='pool2')

        # Flatten everything except the batch dimension for the dense layer.
        feature_dim = pool2.shape[1] * pool2.shape[2] * pool2.shape[3]
        pool2 = tf.reshape(pool2, [-1, feature_dim])
        fc = tf.layers.dense(pool2, 1024, activation=tf.nn.relu, name='fc')
        # tf.layers.dropout takes the DROP rate, not the keep probability,
        # and `training` must be a tensor to be switchable at run time.
        dropout = tf.layers.dropout(fc,
                                    rate=1 - self.keep_prob,
                                    training=self.is_training,
                                    name='dropout')
        self.logits = tf.layers.dense(dropout, self.n_classes, name='logits')

    def loss(self):
        '''
        define loss function
        use softmax cross entropy with logits as the loss function
        compute mean cross entropy, softmax is applied internally

        NOTE: after this call the instance attribute self.loss (a tensor)
        shadows this method, so it can only be invoked once per instance.
        '''
        with tf.name_scope('loss'):
            entropy = tf.nn.softmax_cross_entropy_with_logits(labels=self.label, logits=self.logits)
            self.loss = tf.reduce_mean(entropy, name='loss')

    def optimize(self):
        '''
        Define training op
        using Adam Gradient Descent to minimize cost
        '''
        self.opt = tf.train.AdamOptimizer(self.lr).minimize(self.loss,
                                                            global_step=self.gstep)

    def summary(self):
        '''
        Create summaries to write on TensorBoard
        '''
        with tf.name_scope('summaries'):
            tf.summary.scalar('loss', self.loss)
            tf.summary.scalar('accuracy', self.accuracy)
            tf.summary.histogram('histogram loss', self.loss)
            self.summary_op = tf.summary.merge_all()

    def eval(self):
        '''
        Count the number of right predictions in a batch
        '''
        with tf.name_scope('predict'):
            preds = tf.nn.softmax(self.logits)
            correct_preds = tf.equal(tf.argmax(preds, 1), tf.argmax(self.label, 1))
            # A per-batch COUNT (not a ratio); eval_once divides by n_test.
            self.accuracy = tf.reduce_sum(tf.cast(correct_preds, tf.float32))

    def build(self):
        '''
        Build the computation graph
        '''
        self.get_data()
        self.inference()
        self.loss()
        self.optimize()
        # eval() must run before summary(), which reads self.accuracy.
        self.eval()
        self.summary()

    def train_one_epoch(self, sess, saver, init, writer, epoch, step):
        '''
        Run the training op until the dataset behind `init` is exhausted,
        write a summary per step, save a checkpoint, and return the
        updated step counter.
        '''
        start_time = time.time()
        sess.run(init)
        self.training = True
        total_loss = 0
        n_batches = 0
        # Enable dropout for every training step.
        train_feed = {self.is_training: True}
        try:
            while True:
                _, l, summaries = sess.run([self.opt, self.loss, self.summary_op],
                                           feed_dict=train_feed)
                writer.add_summary(summaries, global_step=step)
                if (step + 1) % self.skip_step == 0:
                    print('Loss at step {0}: {1}'.format(step, l))
                step += 1
                total_loss += l
                n_batches += 1
        except tf.errors.OutOfRangeError:
            # Raised by the iterator at the end of the epoch.
            pass
        saver.save(sess, 'checkpoints/convnet_layers/mnist-convnet', step)
        # Guard against an empty dataset instead of raising ZeroDivisionError.
        average_loss = total_loss / n_batches if n_batches else float('nan')
        print('Average loss at epoch {0}: {1}'.format(epoch, average_loss))
        print('Took: {0} seconds'.format(time.time() - start_time))
        return step

    def eval_once(self, sess, init, writer, epoch, step):
        '''
        Run one pass over the dataset behind `init` with dropout disabled
        and print the fraction of correct predictions (relative to
        self.n_test).
        '''
        start_time = time.time()
        sess.run(init)
        self.training = False
        total_correct_preds = 0
        try:
            while True:
                # is_training is not fed, so it takes its default (False).
                accuracy_batch, summaries = sess.run([self.accuracy, self.summary_op])
                writer.add_summary(summaries, global_step=step)
                total_correct_preds += accuracy_batch
        except tf.errors.OutOfRangeError:
            pass

        print('Accuracy at epoch {0}: {1} '.format(epoch, total_correct_preds/self.n_test))
        print('Took: {0} seconds'.format(time.time() - start_time))

    def train(self, n_epochs):
        '''
        The train function alternates between training one epoch and evaluating
        '''
        utils.safe_mkdir('checkpoints')
        utils.safe_mkdir('checkpoints/convnet_layers')
        writer = tf.summary.FileWriter('./graphs/convnet_layers', tf.get_default_graph())

        # Close the writer even if training is interrupted by an exception.
        try:
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                saver = tf.train.Saver()
                ckpt = tf.train.get_checkpoint_state(os.path.dirname('checkpoints/convnet_layers/checkpoint'))
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)

                # Resume the step counter from the (possibly restored) variable.
                step = self.gstep.eval()

                for epoch in range(n_epochs):
                    step = self.train_one_epoch(sess, saver, self.train_init, writer, epoch, step)
                    self.eval_once(sess, self.test_init, writer, epoch, step)
        finally:
            writer.close()


if __name__ == '__main__':
    model = ConvNet()
    model.build()
    model.train(n_epochs=15)