# 📚 The CoCalc Library - books, templates and other resources
# cocalc-examples / stanford-tensorflow-tutorials / assignments / 02_style_transfer / style_transfer_sol.py
# 132928 views · License: OTHER
import os
import time

# Lower TensorFlow's C++ log verbosity; this must be set before tensorflow is
# imported to take effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import numpy as np
import tensorflow as tf

import load_vgg_sol
import utils

# Checkpoints are both saved to and restored from this single directory; the
# two must agree or an interrupted run can never resume.
CHECKPOINT_DIR = 'checkpoints/style_transfer'
# Generated images are written here as '<step>.png'.
OUTPUT_DIR = 'outputs'
# TensorBoard log directory (graph + loss summaries).
GRAPH_DIR = 'graphs/style_transfer'


def setup():
    """Create the directories that training writes checkpoints and images into."""
    # Parent before child, in case utils.safe_mkdir does not create
    # intermediate directories.
    utils.safe_mkdir('checkpoints')
    utils.safe_mkdir(CHECKPOINT_DIR)
    utils.safe_mkdir(OUTPUT_DIR)


class StyleTransfer(object):
    """Neural style transfer that directly optimizes the generated image's pixels.

    Usage: call ``build()`` to construct the graph, then ``train(n_iters)``.
    """

    def __init__(self, content_img, style_img, img_width, img_height):
        '''
        img_width and img_height are the dimensions we expect from the generated image.
        We will resize input content image and input style image to match this dimension.
        Feel free to alter any hyperparameter here and see how it affects your training.

        content_img and style_img are whatever utils.get_resized_image accepts
        (image file paths in __main__).
        '''
        self.img_width = img_width
        self.img_height = img_height
        self.content_img = utils.get_resized_image(content_img, img_width, img_height)
        self.style_img = utils.get_resized_image(style_img, img_width, img_height)
        # Built from the raw (not yet mean-centered) content image; load_vgg
        # centers it together with the content and style images.
        self.initial_img = utils.generate_noise_image(self.content_img, img_width, img_height)

        ###############################
        ## TO DO
        ## create global step (gstep) and hyperparameters for the model
        # Layer whose activations define the content representation.
        self.content_layer = 'conv4_2'
        # Layers whose gram matrices define the style representation.
        self.style_layers = ['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1']
        # Relative weights of the content and style terms in the total loss.
        self.content_w = 0.01
        self.style_w = 1
        # Per-layer style weights, aligned index-by-index with style_layers.
        self.style_layer_w = [0.5, 1.0, 1.5, 3.0, 4.0]
        # Incremented by the optimizer every step; it is a global variable, so
        # the Saver checkpoints it and a resumed run starts at the right step.
        self.gstep = tf.Variable(0, dtype=tf.int32,
                                 trainable=False, name='global_step')
        # Adam learning rate.
        self.lr = 2.0
        ###############################

    def create_input(self):
        '''
        We will use one input_img as a placeholder for the content image,
        style image, and generated image, because:
            1. they have the same dimension
            2. we have to extract the same set of features from them
        We use a variable instead of a placeholder because we're, at the same time,
        training the generated image to get the desirable result.

        Note: image height corresponds to number of rows, not columns.
        '''
        with tf.variable_scope('input'):
            # Shape [1, height, width, 3], float32, zero-initialized.
            self.input_img = tf.get_variable('in_img',
                                             shape=([1, self.img_height, self.img_width, 3]),
                                             dtype=tf.float32,
                                             initializer=tf.zeros_initializer())

    def load_vgg(self):
        '''
        Load the saved model parameters of VGG-19, using the input_img
        as the input to compute the output at each layer of vgg.

        During training, VGG-19 mean-centered all images and found the mean pixels
        to be [123.68, 116.779, 103.939] along RGB dimensions. We have to subtract
        this mean from our images.

        The initial (noise + content) image is centered as well: train() adds
        the mean back to every generated image it saves, so the starting point
        must live in the same centered space as the content and style images.
        '''
        self.vgg = load_vgg_sol.VGG(self.input_img)
        self.vgg.load()
        mean = self.vgg.mean_pixels
        # Rebind rather than subtract in place, so each image is centered
        # exactly once even if two of these attributes share one array.
        self.initial_img = self.initial_img - mean
        self.content_img = self.content_img - mean
        self.style_img = self.style_img - mean

    def _content_loss(self, P, F):
        ''' Calculate the loss between the feature representation of the
        content image and the generated image.

        Inputs:
            P: content representation of the content image (a numpy array
               here -- the result of sess.run in losses())
            F: content representation of the generated image (a tensor)
            Read the assignment handout for more details

            Note: Don't use the coefficient 0.5 as defined in the paper.
            Use the coefficient defined in the assignment handout.

        Sets self.content_loss = sum((F - P)^2) / (4 * number of elements of P).
        '''
        ###############################
        ## TO DO
        self.content_loss = tf.reduce_sum((F - P) ** 2) / (4.0 * P.size)
        ###############################

    def _gram_matrix(self, F, N, M):
        """ Create and return the gram matrix for tensor F.

        F is flattened to an (M, N) matrix -- M spatial positions by N
        filters -- and the N x N matrix F^T F is returned.
        """
        ###############################
        ## TO DO
        F = tf.reshape(F, (M, N))
        return tf.matmul(tf.transpose(F), F)
        ###############################

    def _single_style_loss(self, a, g):
        """ Calculate the style loss at a certain layer
        Inputs:
            a is the feature representation of the style image at that layer
              (a numpy array; its shape is read as [1, height, width, filters])
            g is the feature representation of the generated image at that layer
        Output:
            the style loss at a certain layer (which is E_l in the paper):
            sum((G - A)^2) / (2 * N * M)^2 with A, G the gram matrices.

        Hint: 1. you'll have to use the function _gram_matrix()
            2. we'll use the same coefficient for style loss as in the paper
            3. a and g are feature representation, not gram matrices
        """
        ###############################
        ## TO DO
        N = a.shape[3]  # number of filters
        M = a.shape[1] * a.shape[2]  # height times width of the feature map
        A = self._gram_matrix(a, N, M)
        G = self._gram_matrix(g, N, M)
        # Convert the (potentially huge) Python int to float explicitly before
        # it meets the float32 tensor.
        denom = float((2 * N * M) ** 2)
        return tf.reduce_sum((G - A) ** 2 / denom)
        ###############################

    def _style_loss(self, A):
        """ Calculate the total style loss as a weighted sum
        of style losses at all style layers.

        A holds the style image's features for each layer in
        self.style_layers, in the same order. Sets self.style_loss.
        """
        E = [self._single_style_loss(a, getattr(self.vgg, layer))
             for a, layer in zip(A, self.style_layers)]

        ###############################
        ## TO DO
        self.style_loss = sum(w * e for w, e in zip(self.style_layer_w, E))
        ###############################

    def losses(self):
        """Build the content, style and total loss tensors.

        The content and style targets are computed once, up front, by assigning
        the content image and then the style image to self.input_img and
        evaluating the VGG layers in short-lived sessions. The generated
        image's features remain symbolic, so the losses depend on input_img.
        Expects load_vgg() to have been called (it sets self.vgg).
        """
        with tf.variable_scope('losses'):
            with tf.Session() as sess:
                # assign content image to the input variable
                sess.run(self.input_img.assign(self.content_img))
                gen_img_content = getattr(self.vgg, self.content_layer)
                content_img_content = sess.run(gen_img_content)
            self._content_loss(content_img_content, gen_img_content)

            with tf.Session() as sess:
                sess.run(self.input_img.assign(self.style_img))
                style_layers = sess.run([getattr(self.vgg, layer) for layer in self.style_layers])
            self._style_loss(style_layers)

            ##########################################
            ## TO DO: create total loss.
            ## Hint: don't forget the weights for the content loss and style loss
            self.total_loss = self.content_w * self.content_loss + self.style_w * self.style_loss
            ##########################################

    def optimize(self):
        """Create the Adam training op; each run increments self.gstep."""
        ###############################
        ## TO DO: create optimizer
        self.opt = tf.train.AdamOptimizer(self.lr).minimize(self.total_loss,
                                                            global_step=self.gstep)
        ###############################

    def create_summary(self):
        """Create scalar summaries for all losses and merge them into summary_op."""
        ###############################
        ## TO DO: create summaries for all the losses
        ## Hint: don't forget to merge them
        with tf.name_scope('summaries'):
            # Underscores instead of spaces: TF rewrites spaces in summary
            # names anyway (producing the same tag) and logs a message each time.
            tf.summary.scalar('content_loss', self.content_loss)
            tf.summary.scalar('style_loss', self.style_loss)
            tf.summary.scalar('total_loss', self.total_loss)
            self.summary_op = tf.summary.merge_all()
        ###############################

    def build(self):
        """Construct the full graph: input variable, VGG, losses, optimizer, summaries."""
        self.create_input()
        self.load_vgg()
        self.losses()
        self.optimize()
        self.create_summary()

    def train(self, n_iters):
        """Optimize the generated image until the global step reaches n_iters.

        Resumes from the latest checkpoint in CHECKPOINT_DIR when one exists
        (the restored global step is the first iteration run). Every
        `skip_step` iterations it prints progress, writes summaries to
        GRAPH_DIR and saves the image to OUTPUT_DIR. Every 20 steps it
        also checkpoints all variables.
        """
        skip_step = 1
        with tf.Session() as sess:

            ###############################
            ## TO DO:
            ## 1. initialize your variables
            ## 2. create writer to write your graph
            sess.run(tf.global_variables_initializer())
            writer = tf.summary.FileWriter(GRAPH_DIR, sess.graph)
            ###############################
            try:
                sess.run(self.input_img.assign(self.initial_img))

                ###############################
                ## TO DO:
                ## 1. create a saver object
                ## 2. check if a checkpoint exists, restore the variables
                saver = tf.train.Saver()
                ckpt = tf.train.get_checkpoint_state(CHECKPOINT_DIR)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                ##############################

                initial_step = self.gstep.eval()

                start_time = time.time()
                for index in range(initial_step, n_iters):
                    # Report every step at first, then less and less often.
                    if 5 <= index < 20:
                        skip_step = 10
                    elif index >= 20:
                        skip_step = 20

                    sess.run(self.opt)
                    if (index + 1) % skip_step == 0:
                        ###############################
                        ## TO DO: obtain generated image, loss, and summary
                        gen_image, total_loss, summary = sess.run([self.input_img,
                                                                   self.total_loss,
                                                                   self.summary_op])
                        ###############################

                        # add back the mean pixels we subtracted before
                        gen_image = gen_image + self.vgg.mean_pixels
                        writer.add_summary(summary, global_step=index)
                        print('Step {}\n Sum: {:5.1f}'.format(index + 1, np.sum(gen_image)))
                        print(' Loss: {:5.1f}'.format(total_loss))
                        print(' Took: {} seconds'.format(time.time() - start_time))
                        start_time = time.time()

                        filename = os.path.join(OUTPUT_DIR, '%d.png' % index)
                        utils.save_image(filename, gen_image)

                        if (index + 1) % 20 == 0:
                            ###############################
                            ## TO DO: save the variables into a checkpoint
                            # Same directory the restore above reads from.
                            saver.save(sess, os.path.join(CHECKPOINT_DIR, 'style_transfer'), index)
                            ###############################
            finally:
                # Flush pending summaries and release the event file.
                writer.close()


if __name__ == '__main__':
    setup()
    machine = StyleTransfer('content/deadpool.jpg', 'styles/guernica.jpg', 333, 250)
    machine.build()
    machine.train(300)