📚 The CoCalc Library - books, templates and other resources
License: OTHER
""" Using convolutional net on MNIST dataset of handwritten digits1MNIST dataset: http://yann.lecun.com/exdb/mnist/2CS 20: "TensorFlow for Deep Learning Research"3cs20.stanford.edu4Chip Huyen ([email protected])5Lecture 076"""7import os8os.environ['TF_CPP_MIN_LOG_LEVEL']='2'9import time1011import tensorflow as tf1213import utils1415def conv_relu(inputs, filters, k_size, stride, padding, scope_name):16'''17A method that does convolution + relu on inputs18'''19with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE) as scope:20in_channels = inputs.shape[-1]21kernel = tf.get_variable('kernel',22[k_size, k_size, in_channels, filters],23initializer=tf.truncated_normal_initializer())24biases = tf.get_variable('biases',25[filters],26initializer=tf.random_normal_initializer())27conv = tf.nn.conv2d(inputs, kernel, strides=[1, stride, stride, 1], padding=padding)28return tf.nn.relu(conv + biases, name=scope.name)2930def maxpool(inputs, ksize, stride, padding='VALID', scope_name='pool'):31'''A method that does max pooling on inputs'''32with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE) as scope:33pool = tf.nn.max_pool(inputs,34ksize=[1, ksize, ksize, 1],35strides=[1, stride, stride, 1],36padding=padding)37return pool3839def fully_connected(inputs, out_dim, scope_name='fc'):40'''41A fully connected linear layer on inputs42'''43with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE) as scope:44in_dim = inputs.shape[-1]45w = tf.get_variable('weights', [in_dim, out_dim],46initializer=tf.truncated_normal_initializer())47b = tf.get_variable('biases', [out_dim],48initializer=tf.constant_initializer(0.0))49out = tf.matmul(inputs, w) + b50return out5152class ConvNet(object):53def __init__(self):54self.lr = 0.00155self.batch_size = 12856self.keep_prob = tf.constant(0.75)57self.gstep = tf.Variable(0, dtype=tf.int32,58trainable=False, name='global_step')59self.n_classes = 1060self.skip_step = 2061self.n_test = 1000062self.training = True6364def get_data(self):65with tf.name_scope('data'):66train_data, test_data = utils.get_mnist_dataset(self.batch_size)67iterator = tf.data.Iterator.from_structure(train_data.output_types,68train_data.output_shapes)69img, self.label = iterator.get_next()70self.img = tf.reshape(img, shape=[-1, 28, 28, 1])71# reshape the image to make it work with tf.nn.conv2d7273self.train_init = iterator.make_initializer(train_data) # initializer for train_data74self.test_init = iterator.make_initializer(test_data) # initializer for train_data7576def inference(self):77conv1 = conv_relu(inputs=self.img,78filters=32,79k_size=5,80stride=1,81padding='SAME',82scope_name='conv1')83pool1 = maxpool(conv1, 2, 2, 'VALID', 'pool1')84conv2 = conv_relu(inputs=pool1,85filters=64,86k_size=5,87stride=1,88padding='SAME',89scope_name='conv2')90pool2 = maxpool(conv2, 2, 2, 'VALID', 'pool2')91feature_dim = pool2.shape[1] * pool2.shape[2] * pool2.shape[3]92pool2 = tf.reshape(pool2, [-1, feature_dim])93fc = fully_connected(pool2, 1024, 'fc')94dropout = tf.nn.dropout(tf.nn.relu(fc), self.keep_prob, name='relu_dropout')95self.logits = fully_connected(dropout, self.n_classes, 'logits')9697def loss(self):98'''99define loss function100use softmax cross entropy with logits as the loss function101compute mean cross entropy, softmax is applied internally102'''103#104with tf.name_scope('loss'):105entropy = tf.nn.softmax_cross_entropy_with_logits(labels=self.label, logits=self.logits)106self.loss = tf.reduce_mean(entropy, name='loss')107108def optimize(self):109'''110Define training op111using Adam Gradient Descent to minimize cost112'''113self.opt = tf.train.AdamOptimizer(self.lr).minimize(self.loss,114global_step=self.gstep)115116def summary(self):117'''118Create summaries to write on TensorBoard119'''120with tf.name_scope('summaries'):121tf.summary.scalar('loss', self.loss)122tf.summary.scalar('accuracy', self.accuracy)123tf.summary.histogram('histogram loss', self.loss)124self.summary_op = tf.summary.merge_all()125126def eval(self):127'''128Count the number of right predictions in a batch129'''130with tf.name_scope('predict'):131preds = tf.nn.softmax(self.logits)132correct_preds = tf.equal(tf.argmax(preds, 1), tf.argmax(self.label, 1))133self.accuracy = tf.reduce_sum(tf.cast(correct_preds, tf.float32))134135def build(self):136'''137Build the computation graph138'''139self.get_data()140self.inference()141self.loss()142self.optimize()143self.eval()144self.summary()145146def train_one_epoch(self, sess, saver, init, writer, epoch, step):147start_time = time.time()148sess.run(init)149self.training = True150total_loss = 0151n_batches = 0152try:153while True:154_, l, summaries = sess.run([self.opt, self.loss, self.summary_op])155writer.add_summary(summaries, global_step=step)156if (step + 1) % self.skip_step == 0:157print('Loss at step {0}: {1}'.format(step, l))158step += 1159total_loss += l160n_batches += 1161except tf.errors.OutOfRangeError:162pass163saver.save(sess, 'checkpoints/convnet_mnist/mnist-convnet', step)164print('Average loss at epoch {0}: {1}'.format(epoch, total_loss/n_batches))165print('Took: {0} seconds'.format(time.time() - start_time))166return step167168def eval_once(self, sess, init, writer, epoch, step):169start_time = time.time()170sess.run(init)171self.training = False172total_correct_preds = 0173try:174while True:175accuracy_batch, summaries = sess.run([self.accuracy, self.summary_op])176writer.add_summary(summaries, global_step=step)177total_correct_preds += accuracy_batch178except tf.errors.OutOfRangeError:179pass180181print('Accuracy at epoch {0}: {1} '.format(epoch, total_correct_preds/self.n_test))182print('Took: {0} seconds'.format(time.time() - start_time))183184def train(self, n_epochs):185'''186The train function alternates between training one epoch and evaluating187'''188utils.safe_mkdir('checkpoints')189utils.safe_mkdir('checkpoints/convnet_mnist')190writer = tf.summary.FileWriter('./graphs/convnet', tf.get_default_graph())191192with tf.Session() as sess:193sess.run(tf.global_variables_initializer())194saver = tf.train.Saver()195ckpt = tf.train.get_checkpoint_state(os.path.dirname('checkpoints/convnet_mnist/checkpoint'))196if ckpt and ckpt.model_checkpoint_path:197saver.restore(sess, ckpt.model_checkpoint_path)198199step = self.gstep.eval()200201for epoch in range(n_epochs):202step = self.train_one_epoch(sess, saver, self.train_init, writer, epoch, step)203self.eval_once(sess, self.test_init, writer, epoch, step)204writer.close()205206if __name__ == '__main__':207model = ConvNet()208model.build()209model.train(n_epochs=30)210211212