# -*- coding: utf-8 -*-
"""
DCGAN Tutorial
==============

**Author**: `Nathan Inkawhich <https://github.com/inkawhich>`__

"""

######################################################################
# Introduction
# ------------
#
# This tutorial will give an introduction to DCGANs through an example. We
# will train a generative adversarial network (GAN) to generate new
# celebrities after showing it pictures of many real celebrities. Most of
# the code here is from the DCGAN implementation in
# `pytorch/examples <https://github.com/pytorch/examples>`__, and this
# document will give a thorough explanation of the implementation and shed
# light on how and why this model works. Don’t worry: no prior knowledge
# of GANs is required, though a first-timer may need to spend some time
# reasoning about what is actually happening under the hood. Also, for the
# sake of time it will help to have a GPU, or two. Let’s start from the
# beginning.
#
# Generative Adversarial Networks
# -------------------------------
#
# What is a GAN?
# ~~~~~~~~~~~~~~
#
# GANs are a framework for teaching a deep learning model to capture the
# training data distribution so we can generate new data from that same
# distribution. GANs were invented by Ian Goodfellow in 2014 and first
# described in the paper `Generative Adversarial
# Nets <https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf>`__.
# They are made of two distinct models, a *generator* and a
# *discriminator*. The job of the generator is to spawn ‘fake’ images that
# look like the training images. The job of the discriminator is to look
# at an image and output whether it is a real training image or a
# fake image from the generator. During training, the generator is
# constantly trying to outsmart the discriminator by generating better and
# better fakes, while the discriminator is working to become a better
# detective and correctly classify the real and fake images. The
# equilibrium of this game is when the generator is generating perfect
# fakes that look as if they came directly from the training data, and the
# discriminator is left to always guess at 50% confidence whether the
# generator output is real or fake.
#
# Now, let’s define some notation to be used throughout the tutorial,
# starting with the discriminator. Let :math:`x` be data representing an
# image. :math:`D(x)` is the discriminator network which outputs the
# (scalar) probability that :math:`x` came from the training data rather
# than the generator. Here, since we are dealing with images, the input to
# :math:`D(x)` is an image of CHW size 3x64x64. Intuitively, :math:`D(x)`
# should be HIGH when :math:`x` comes from the training data and LOW when
# :math:`x` comes from the generator. :math:`D(x)` can also be thought of
# as a traditional binary classifier.
#
# For the generator’s notation, let :math:`z` be a latent space vector
# sampled from a standard normal distribution. :math:`G(z)` represents the
# generator function which maps the latent vector :math:`z` to data-space.
# The goal of :math:`G` is to estimate the distribution that the training
# data comes from (:math:`p_{data}`) so it can generate fake samples from
# that estimated distribution (:math:`p_g`).
#
# So, :math:`D(G(z))` is the probability (scalar) that the output of the
# generator :math:`G` is a real image. As described in `Goodfellow’s
# paper <https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf>`__,
# :math:`D` and :math:`G` play a minimax game in which :math:`D` tries to
# maximize the probability it correctly classifies reals and fakes
# (:math:`\log D(x)`), and :math:`G` tries to minimize the probability that
# :math:`D` will predict its outputs are fake (:math:`\log(1-D(G(z)))`).
# From the paper, the GAN loss function is
#
# .. math:: \underset{G}{\text{min}} \underset{D}{\text{max}}V(D,G) = \mathbb{E}_{x\sim p_{data}(x)}\big[\log D(x)\big] + \mathbb{E}_{z\sim p_{z}(z)}\big[\log(1-D(G(z)))\big]
#
# In theory, the solution to this minimax game is where
# :math:`p_g = p_{data}`, and the discriminator guesses randomly whether
# the inputs are real or fake. However, the convergence theory of GANs is
# still being actively researched and in reality models do not always
# train to this point.
#
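# To make the two expectation terms concrete, here is a tiny numeric
# sketch (an illustration added for this write-up, not code from the
# original tutorial): we estimate :math:`V(D,G)` from a handful of
# hypothetical discriminator outputs on real and fake samples.
#
# .. code-block:: python
#
#    import torch
#
#    # Hypothetical discriminator outputs: D(x) on real images and
#    # D(G(z)) on generated images (probabilities in (0, 1)).
#    d_real = torch.tensor([0.90, 0.80, 0.95])
#    d_fake = torch.tensor([0.10, 0.20, 0.05])
#
#    # Monte Carlo estimate of E[log D(x)] + E[log(1 - D(G(z)))]
#    v = torch.log(d_real).mean() + torch.log(1 - d_fake).mean()
#    print(v.item())  # near 0 when D is confident and correct; D wants this large
#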
# What is a DCGAN?
# ~~~~~~~~~~~~~~~~
#
# A DCGAN is a direct extension of the GAN described above, except that it
# explicitly uses convolutional and convolutional-transpose layers in the
# discriminator and generator, respectively. It was first described by
# Radford et al. in the paper `Unsupervised Representation Learning With
# Deep Convolutional Generative Adversarial
# Networks <https://arxiv.org/pdf/1511.06434.pdf>`__. The discriminator
# is made up of strided
# `convolution <https://pytorch.org/docs/stable/nn.html#torch.nn.Conv2d>`__
# layers, `batch
# norm <https://pytorch.org/docs/stable/nn.html#torch.nn.BatchNorm2d>`__
# layers, and
# `LeakyReLU <https://pytorch.org/docs/stable/nn.html#torch.nn.LeakyReLU>`__
# activations. The input is a 3x64x64 image and the output is a
# scalar probability that the input is from the real data distribution.
# The generator is comprised of
# `convolutional-transpose <https://pytorch.org/docs/stable/nn.html#torch.nn.ConvTranspose2d>`__
# layers, batch norm layers, and
# `ReLU <https://pytorch.org/docs/stable/nn.html#relu>`__ activations. The
# input is a latent vector, :math:`z`, that is drawn from a standard
# normal distribution and the output is a 3x64x64 RGB image. The strided
# conv-transpose layers allow the latent vector to be transformed into a
# volume with the same shape as an image. In the paper, the authors also
# give some tips about how to set up the optimizers, how to calculate the
# loss functions, and how to initialize the model weights, all of which
# will be explained in the coming sections.
#

#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.use_deterministic_algorithms(True) # Needed for reproducible results

######################################################################
# Inputs
# ------
#
# Let’s define some inputs for the run:
#
# -  ``dataroot`` - the path to the root of the dataset folder. We will
#    talk more about the dataset in the next section.
# -  ``workers`` - the number of worker threads for loading the data with
#    the ``DataLoader``.
# -  ``batch_size`` - the batch size used in training. The DCGAN paper
#    uses a batch size of 128.
# -  ``image_size`` - the spatial size of the images used for training.
#    This implementation defaults to 64x64. If another size is desired,
#    the structures of D and G must be changed. See
#    `here <https://github.com/pytorch/examples/issues/70>`__ for more
#    details.
# -  ``nc`` - number of color channels in the input images. For color
#    images this is 3.
# -  ``nz`` - length of the latent vector.
# -  ``ngf`` - relates to the depth of feature maps carried through the
#    generator.
# -  ``ndf`` - sets the depth of feature maps propagated through the
#    discriminator.
# -  ``num_epochs`` - number of training epochs to run. Training for
#    longer will probably lead to better results but will also take much
#    longer.
# -  ``lr`` - learning rate for training. As described in the DCGAN paper,
#    this number should be 0.0002.
# -  ``beta1`` - beta1 hyperparameter for Adam optimizers. As described in
#    the paper, this number should be 0.5.
# -  ``ngpu`` - number of GPUs available. If this is 0, code will run in
#    CPU mode. If this number is greater than 0 it will run on that number
#    of GPUs.
#

# Root directory for dataset
dataroot = "data/celeba"

# Number of workers for dataloader
workers = 2

# Batch size during training
batch_size = 128

# Spatial size of training images. All images will be resized to this
# size using a transformer.
image_size = 64

# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Size of feature maps in generator
ngf = 64

# Size of feature maps in discriminator
ndf = 64

# Number of training epochs
num_epochs = 5

# Learning rate for optimizers
lr = 0.0002

# Beta1 hyperparameter for Adam optimizers
beta1 = 0.5

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1
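
# The ``argparse`` import above is unused in this script. As a sketch of
# how the same inputs could be exposed as command-line flags (an
# assumption for script usage, not part of the original tutorial; left
# commented out so the code still runs in a notebook):
#
# parser = argparse.ArgumentParser(description="DCGAN training inputs")
# parser.add_argument("--dataroot", default="data/celeba")
# parser.add_argument("--batch_size", type=int, default=128)
# parser.add_argument("--num_epochs", type=int, default=5)
# parser.add_argument("--ngpu", type=int, default=1)
# args = parser.parse_args()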

######################################################################
# Data
# ----
#
# In this tutorial we will use the `Celeb-A Faces
# dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`__ which can
# be downloaded at the linked site, or in `Google
# Drive <https://drive.google.com/drive/folders/0B7EVK8r0v71pTUZsaXdaSnZBZzg>`__.
# The dataset will download as a file named ``img_align_celeba.zip``. Once
# downloaded, create a directory named ``celeba`` and extract the zip file
# into that directory. Then, set the ``dataroot`` input for this notebook to
# the ``celeba`` directory you just created. The resulting directory
# structure should be:
#
# .. code-block:: sh
#
#    /path/to/celeba
#        -> img_align_celeba
#            -> 188242.jpg
#            -> 173822.jpg
#            -> 284702.jpg
#            -> 537394.jpg
#               ...
#
# This is an important step because we will be using the ``ImageFolder``
# dataset class, which requires there to be subdirectories in the
# dataset root folder. Now, we can create the dataset, create the
# dataloader, set the device to run on, and finally visualize some of the
# training data.
#
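# Assuming ``img_align_celeba.zip`` sits in the current directory, a
# minimal sketch for creating this layout with the standard library (an
# illustration added here, not part of the original tutorial; adjust the
# paths to your setup):
#
# .. code-block:: python
#
#    import os
#    import zipfile
#
#    os.makedirs("data/celeba", exist_ok=True)
#    with zipfile.ZipFile("img_align_celeba.zip") as zf:
#        # the archive contains the img_align_celeba folder
#        zf.extractall("data/celeba")
#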

# We can use an image folder dataset the way we have it set up.
# Create the dataset
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Plot some training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
plt.show()
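
# A quick sanity check (an illustrative addition, not part of the original
# tutorial): after ``ToTensor`` and ``Normalize`` with mean/std of 0.5, the
# pixel values should lie in [-1, 1], matching the generator's tanh output
# range described later.
print(real_batch[0].shape)                       # torch.Size([128, 3, 64, 64])
print(real_batch[0].min(), real_batch[0].max())  # approximately -1 and 1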


######################################################################
# Implementation
# --------------
#
# With our input parameters set and the dataset prepared, we can now get
# into the implementation. We will start with the weight initialization
# strategy, then talk about the generator, discriminator, loss functions,
# and training loop in detail.
#
# Weight Initialization
# ~~~~~~~~~~~~~~~~~~~~~
#
# From the DCGAN paper, the authors specify that all model weights shall
# be randomly initialized from a Normal distribution with ``mean=0``,
# ``stdev=0.02``. The ``weights_init`` function takes an initialized model as
# input and reinitializes all convolutional, convolutional-transpose, and
# batch normalization layers to meet these criteria. This function is
# applied to the models immediately after initialization.
#

# custom weights initialization called on ``netG`` and ``netD``
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
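
# A small sketch (an illustrative addition, not part of the original
# tutorial) of what ``weights_init`` does to a single layer: the
# reinitialized weights should have mean near 0 and std near 0.02.
_conv = nn.Conv2d(3, 64, 4)
weights_init(_conv)
print(_conv.weight.data.mean().item(), _conv.weight.data.std().item())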


######################################################################
# Generator
# ~~~~~~~~~
#
# The generator, :math:`G`, is designed to map the latent space vector
# (:math:`z`) to data-space. Since our data are images, converting
# :math:`z` to data-space means ultimately creating an RGB image with the
# same size as the training images (i.e. 3x64x64). In practice, this is
# accomplished through a series of strided two dimensional convolutional
# transpose layers, each paired with a 2d batch norm layer and a ReLU
# activation. The output of the generator is fed through a tanh function
# to return it to the input data range of :math:`[-1,1]`. It is worth
# noting the existence of the batch norm functions after the
# conv-transpose layers, as this is a critical contribution of the DCGAN
# paper. These layers help with the flow of gradients during training. An
# image of the generator from the DCGAN paper is shown below.
#
# .. figure:: /_static/img/dcgan_generator.png
#    :alt: dcgan_generator
#
# Notice how the inputs we set in the input section (``nz``, ``ngf``, and
# ``nc``) influence the generator architecture in code. ``nz`` is the length
# of the z input vector, ``ngf`` relates to the size of the feature maps
# that are propagated through the generator, and ``nc`` is the number of
# channels in the output image (set to 3 for RGB images). Below is the
# code for the generator.
#

# Generator Code

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. ``(ngf*8) x 4 x 4``
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. ``(ngf*4) x 8 x 8``
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. ``(ngf*2) x 16 x 16``
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. ``(ngf) x 32 x 32``
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. ``(nc) x 64 x 64``
        )

    def forward(self, input):
        return self.main(input)
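
# To see how the latent vector grows into a 64x64 image, here is a short
# sketch (an illustrative addition, not part of the original tutorial) of
# the ``ConvTranspose2d`` output-size formula applied to the layers above,
# out = (in - 1) * stride - 2 * padding + kernel_size:
def _convtranspose2d_out(in_size, kernel_size, stride, padding):
    return (in_size - 1) * stride - 2 * padding + kernel_size

_size = 1  # the latent vector is treated as a 1x1 spatial map
for _k, _s, _p in [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]:
    _size = _convtranspose2d_out(_size, _k, _s, _p)
    print(_size)  # prints 4, 8, 16, 32, 64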


######################################################################
# Now, we can instantiate the generator and apply the ``weights_init``
# function. Check out the printed model to see how the generator object is
# structured.
#

# Create the generator
netG = Generator(ngpu).to(device)

# Handle multi-GPU if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# Apply the ``weights_init`` function to randomly initialize all weights
# to ``mean=0``, ``stdev=0.02``.
netG.apply(weights_init)

# Print the model
print(netG)


######################################################################
# Discriminator
# ~~~~~~~~~~~~~
#
# As mentioned, the discriminator, :math:`D`, is a binary classification
# network that takes an image as input and outputs a scalar probability
# that the input image is real (as opposed to fake). Here, :math:`D` takes
# a 3x64x64 input image, processes it through a series of Conv2d,
# BatchNorm2d, and LeakyReLU layers, and outputs the final probability
# through a Sigmoid activation function. This architecture can be extended
# with more layers if necessary for the problem, but there is significance
# to the use of the strided convolution, BatchNorm, and LeakyReLUs. The
# DCGAN paper mentions it is a good practice to use strided convolution
# rather than pooling to downsample because it lets the network learn its
# own pooling function. Also, batch norm and LeakyReLU functions promote
# healthy gradient flow, which is critical for the learning process of both
# :math:`G` and :math:`D`.
#

#########################################################################
# Discriminator Code

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is ``(nc) x 64 x 64``
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf) x 32 x 32``
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*2) x 16 x 16``
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*4) x 8 x 8``
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*8) x 4 x 4``
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)


######################################################################
# Now, as with the generator, we can create the discriminator, apply the
# ``weights_init`` function, and print the model’s structure.
#

# Create the Discriminator
netD = Discriminator(ngpu).to(device)

# Handle multi-GPU if desired
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))

# Apply the ``weights_init`` function to randomly initialize all weights
# to ``mean=0``, ``stdev=0.02``.
netD.apply(weights_init)

# Print the model
print(netD)
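
# A quick shape check (an illustrative addition, not part of the original
# tutorial): G should map a latent vector to a 3x64x64 image, and D should
# map that image to a single probability.
with torch.no_grad():
    _z = torch.randn(1, nz, 1, 1, device=device)
    _img = netG(_z)
    print(_img.shape)        # torch.Size([1, 3, 64, 64])
    print(netD(_img).shape)  # torch.Size([1, 1, 1, 1])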


######################################################################
# Loss Functions and Optimizers
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# With :math:`D` and :math:`G` set up, we can specify how they learn
# through the loss functions and optimizers. We will use the Binary Cross
# Entropy loss
# (`BCELoss <https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html#torch.nn.BCELoss>`__)
# function which is defined in PyTorch as:
#
# .. math:: \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad l_n = - \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right]
#
# Notice how this function provides the calculation of both log components
# in the objective function (i.e. :math:`\log(D(x))` and
# :math:`\log(1-D(G(z)))`). We can specify which part of the BCE equation to
# use with the :math:`y` input. This is accomplished in the training loop
# which is coming up soon, but it is important to understand how we can
# choose which component we wish to calculate just by changing :math:`y`
# (i.e. GT labels); a small numeric demonstration follows the code below.
#
# Next, we define our real label as 1 and the fake label as 0. These
# labels will be used when calculating the losses of :math:`D` and
# :math:`G`, and this is also the convention used in the original GAN
# paper. Finally, we set up two separate optimizers, one for :math:`D` and
# one for :math:`G`. As specified in the DCGAN paper, both are Adam
# optimizers with learning rate 0.0002 and Beta1 = 0.5. For keeping track
# of the generator’s learning progression, we will generate a fixed batch
# of latent vectors that are drawn from a Gaussian distribution
# (i.e. ``fixed_noise``). In the training loop, we will periodically input
# this ``fixed_noise`` into :math:`G`, and over the iterations we will see
# images form out of the noise.
#

# Initialize the ``BCELoss`` function
criterion = nn.BCELoss()

# Create batch of latent vectors that we will use to visualize
# the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
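
# A small numeric demonstration (an illustrative addition, not part of the
# original tutorial) of the label trick described above: with a target of 1,
# ``BCELoss`` reduces to -log(x); with a target of 0, it reduces to -log(1 - x).
_x = torch.tensor([0.9])
print(criterion(_x, torch.ones(1)))   # -log(0.9)  ~= 0.105
print(criterion(_x, torch.zeros(1)))  # -log(0.1)  ~= 2.303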


######################################################################
# Training
# ~~~~~~~~
#
# Finally, now that we have all of the parts of the GAN framework defined,
# we can train it. Be mindful that training GANs is somewhat of an art
# form, as incorrect hyperparameter settings lead to mode collapse with
# little explanation of what went wrong. Here, we will closely follow
# Algorithm 1 from `Goodfellow’s paper <https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf>`__,
# while abiding by some of the best
# practices shown in `ganhacks <https://github.com/soumith/ganhacks>`__.
# Namely, we will “construct different mini-batches for real and fake”
# images, and also adjust G’s objective function to maximize
# :math:`\log(D(G(z)))`. Training is split up into two main parts. Part 1
# updates the Discriminator and Part 2 updates the Generator.
#
# **Part 1 - Train the Discriminator**
#
# Recall, the goal of training the discriminator is to maximize the
# probability of correctly classifying a given input as real or fake. In
# terms of Goodfellow, we wish to “update the discriminator by ascending
# its stochastic gradient”. Practically, we want to maximize
# :math:`\log(D(x)) + \log(1-D(G(z)))`. Due to the separate mini-batch
# suggestion from `ganhacks <https://github.com/soumith/ganhacks>`__,
# we will calculate this in two steps. First, we
# will construct a batch of real samples from the training set, forward
# pass through :math:`D`, calculate the loss (:math:`\log(D(x))`), then
# calculate the gradients in a backward pass. Second, we will construct
# a batch of fake samples with the current generator, forward pass this
# batch through :math:`D`, calculate the loss (:math:`\log(1-D(G(z)))`),
# and *accumulate* the gradients with a backward pass. Now, with the
# gradients accumulated from both the all-real and all-fake batches, we
# call a step of the Discriminator’s optimizer.
#
# **Part 2 - Train the Generator**
#
# As stated in the original paper, we want to train the Generator by
# minimizing :math:`\log(1-D(G(z)))` in an effort to generate better fakes.
# As mentioned, this was shown by Goodfellow to not provide sufficient
# gradients, especially early in the learning process. As a fix, we
# instead wish to maximize :math:`\log(D(G(z)))`. In the code we accomplish
# this by: classifying the Generator output from Part 1 with the
# Discriminator, computing G’s loss *using real labels as GT*, computing
# G’s gradients in a backward pass, and finally updating G’s parameters
# with an optimizer step. It may seem counter-intuitive to use the real
# labels as GT labels for the loss function, but this allows us to use the
# :math:`\log(x)` part of the ``BCELoss`` (rather than the :math:`\log(1-x)`
# part) which is exactly what we want.
#
# Finally, we will do some statistic reporting, and periodically we will
# push our ``fixed_noise`` batch through the generator to visually track
# the progress of G’s training. The training statistics reported are:
#
# -  **Loss_D** - discriminator loss calculated as the sum of losses for
#    the all-real and all-fake batches (:math:`\log(D(x)) + \log(1 - D(G(z)))`).
# -  **Loss_G** - generator loss calculated as :math:`\log(D(G(z)))`.
# -  **D(x)** - the average output (across the batch) of the discriminator
#    for the all-real batch. This should start close to 1 then
#    theoretically converge to 0.5 when G gets better. Think about why
#    this is.
# -  **D(G(z))** - average discriminator outputs for the all-fake batch.
#    The first number is before D is updated and the second number is
#    after D is updated. These numbers should start near 0 and converge to
#    0.5 as G gets better. Think about why this is.
#
# **Note:** This step might take a while, depending on how many epochs you
# run and if you removed some data from the dataset.
#

# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Compute error of D as sum over the fake and the real batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1
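
# A common practice (an addition, not shown in the original tutorial) is to
# checkpoint the models once training finishes, so they can be reloaded
# without retraining:
torch.save(netG.state_dict(), "netG.pth")
torch.save(netD.state_dict(), "netD.pth")
# Reload later with, e.g., netG.load_state_dict(torch.load("netG.pth"))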


######################################################################
# Results
# -------
#
# Finally, let’s check out how we did. Here, we will look at three
# different results. First, we will see how D and G’s losses changed
# during training. Second, we will visualize G’s output on the
# ``fixed_noise`` batch saved over the course of training. And third, we
# will look at a batch of real data next to a batch of fake data from G.
#
# **Loss versus training iteration**
#
# Below is a plot of D & G’s losses versus training iterations.
#

plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G")
plt.plot(D_losses, label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
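
# The raw losses are noisy from batch to batch; a simple moving average
# (an illustrative addition, not part of the original tutorial) makes the
# trend easier to read.
window = 50
if len(G_losses) >= window:
    G_smooth = np.convolve(G_losses, np.ones(window) / window, mode='valid')
    D_smooth = np.convolve(D_losses, np.ones(window) / window, mode='valid')
    plt.figure(figsize=(10,5))
    plt.title("Smoothed Loss (window=50)")
    plt.plot(G_smooth, label="G")
    plt.plot(D_smooth, label="D")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()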


######################################################################
# **Visualization of G’s progression**
#
# Remember how we saved the generator’s output on the ``fixed_noise`` batch
# periodically during training. Now, we can visualize the training
# progression of G with an animation. Press the play button to start the
# animation.
#

#%%capture
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())
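
# When running as a plain script, the animation can also be written to disk
# (an illustrative addition, not part of the original tutorial; the
# "pillow" writer requires the Pillow package):
# ani.save("dcgan_progress.gif", writer="pillow", fps=1)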


######################################################################
# **Real Images vs. Fake Images**
#
# Finally, let’s take a look at some real images and fake images side by
# side.
#

# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))

# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()


######################################################################
# Where to Go Next
# ----------------
#
# We have reached the end of our journey, but there are several places you
# could go from here. You could:
#
# -  Train for longer to see how good the results get
# -  Modify this model to take a different dataset and possibly change the
#    size of the images and the model architecture
# -  Check out some other cool GAN projects
#    `here <https://github.com/nashory/gans-awesome-applications>`__
# -  Create GANs that generate
#    `music <https://www.deepmind.com/blog/wavenet-a-generative-model-for-raw-audio/>`__
#