Pix2Pix
Goals
In this notebook, you will write a generative model based on the paper Image-to-Image Translation with Conditional Adversarial Networks by Isola et al. 2017, also known as Pix2Pix.
You will be training a model that can convert aerial satellite imagery ("input") into map routes ("output"), as was done in the original paper. Since the architecture for the generator is a U-Net, which you've already implemented (with minor changes), the emphasis of the assignment will be on the loss function. So that you can see outputs more quickly, your model will start training from a pre-trained checkpoint - but feel free to train it from scratch on your own too.
Learning Objectives
Implement the loss of a Pix2Pix model that differentiates it from a supervised U-Net.
Observe the change in generator priorities as the Pix2Pix generator trains, changing its emphasis from reconstruction to realism.
Getting Started
You will start by importing libraries, defining a visualization function, and getting the pre-trained Pix2Pix checkpoint. You will also be provided with the U-Net code for the Pix2Pix generator.
U-Net Code
The U-Net code will be much like the code you wrote for the last assignment, but with optional dropout and batchnorm. The structure is changed slightly for Pix2Pix, so that the final image is closer in size to the input image. Feel free to investigate the code if you're interested!
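As a point of reference, one contracting (downsampling) block in that style could look roughly like the sketch below; the class and argument names are illustrative, and the provided code may differ in details.

import torch
from torch import nn

class ContractingBlock(nn.Module):
    '''
    One downsampling step of a U-Net: a convolution that doubles the channel
    count, optional batchnorm and dropout, a LeakyReLU activation, and a
    max pool that halves the spatial size.
    '''
    def __init__(self, input_channels, use_dropout=False, use_bn=True):
        super().__init__()
        self.conv = nn.Conv2d(input_channels, input_channels * 2, kernel_size=3, padding=1)
        self.batchnorm = nn.BatchNorm2d(input_channels * 2) if use_bn else None
        self.dropout = nn.Dropout() if use_dropout else None
        self.activation = nn.LeakyReLU(0.2)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.conv(x)
        if self.batchnorm is not None:
            x = self.batchnorm(x)
        if self.dropout is not None:
            x = self.dropout(x)
        x = self.activation(x)
        return self.maxpool(x)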
PatchGAN Discriminator
Next, you will define a discriminator based on the contracting path of the U-Net to allow you to evaluate the realism of the generated images. Remember that the discriminator outputs a one-channel matrix of classifications instead of a single value. Your discriminator's final layer will simply map from the final number of hidden channels to a single prediction for every pixel of the layer before it.
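A plausible sketch of such a discriminator, reusing the ContractingBlock above, is shown here; the hidden channel count and the number of blocks are assumptions, not necessarily those of the assignment code.

class Discriminator(nn.Module):
    '''
    PatchGAN discriminator: the condition and the (real or fake) output image
    are concatenated channel-wise, passed through contracting blocks, and a
    final 1x1 convolution maps the hidden channels to a one-channel grid of
    per-patch real/fake predictions.
    '''
    def __init__(self, input_channels, hidden_channels=8):
        super().__init__()
        self.upfeature = nn.Conv2d(input_channels, hidden_channels, kernel_size=1)
        self.contract1 = ContractingBlock(hidden_channels, use_bn=False)
        self.contract2 = ContractingBlock(hidden_channels * 2)
        self.contract3 = ContractingBlock(hidden_channels * 4)
        self.contract4 = ContractingBlock(hidden_channels * 8)
        self.final = nn.Conv2d(hidden_channels * 16, 1, kernel_size=1)

    def forward(self, x, y):
        x = torch.cat([x, y], dim=1)  # condition and image stacked along channels
        x = self.upfeature(x)
        x = self.contract1(x)
        x = self.contract2(x)
        x = self.contract3(x)
        x = self.contract4(x)
        return self.final(x)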
Training Preparation
Now you can begin putting everything together for training. You start by defining some new parameters as well as the ones you are already familiar with (an illustrative set of definitions is sketched after this list):
real_dim: the number of channels of the real image and the number expected in the output image
adv_criterion: an adversarial loss function to keep track of how well the GAN is fooling the discriminator and how well the discriminator is catching the GAN
recon_criterion: a loss function that rewards outputs similar to the ground truth, which "reconstruct" the image
lambda_recon: a parameter for how heavily the reconstruction loss should be weighed
n_epochs: the number of times you iterate through the entire dataset when training
input_dim: the number of channels of the input image
display_step: how often to display/visualize the images
batch_size: the number of images per forward/backward pass
lr: the learning rate
target_shape: the size of the output image (in pixels)
device: the device type
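Putting those together, one plausible set of definitions is sketched below; the concrete values are illustrative and the notebook's actual settings may differ.

import torch
from torch import nn

adv_criterion = nn.BCEWithLogitsLoss()  # adversarial loss on the PatchGAN logits
recon_criterion = nn.L1Loss()           # pixel-wise reconstruction loss
lambda_recon = 200                      # weight on the reconstruction term

n_epochs = 20
input_dim = 3        # channels in the input (satellite) image
real_dim = 3         # channels in the target (map) image
display_step = 200
batch_size = 4
lr = 0.0002
target_shape = 256   # output image side length in pixels
device = 'cuda' if torch.cuda.is_available() else 'cpu'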
You will then pre-process the images of the dataset to make sure they're all the same size and that the size change due to U-Net layers is accounted for.
Next, you can initialize your generator (U-Net) and discriminator, as well as their optimizers. Finally, you will also load your pre-trained model.
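That initialization might look roughly like the following; UNet, Discriminator, the checkpoint filename, and the keys inside it are assumed names used for illustration.

gen = UNet(input_dim, real_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc = Discriminator(input_dim + real_dim).to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)

# Resume from the provided pre-trained weights so results appear quickly
loaded_state = torch.load("pix2pix_15000.pth", map_location=device)
gen.load_state_dict(loaded_state["gen"])
gen_opt.load_state_dict(loaded_state["gen_opt"])
disc.load_state_dict(loaded_state["disc"])
disc_opt.load_state_dict(loaded_state["disc_opt"])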
While there are some changes to the U-Net architecture for Pix2Pix, the most important distinguishing feature of Pix2Pix is its adversarial loss. You will be implementing that here!
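In rough terms, the generator loss combines the adversarial and reconstruction criteria defined above; a sketch of the idea (not the graded function itself) follows.

def get_gen_loss(gen, disc, real, condition, adv_criterion, recon_criterion, lambda_recon):
    '''
    Generator loss = adversarial term (push the discriminator to label the
    fake as real) + lambda_recon * reconstruction term (stay close to the
    ground-truth output).
    '''
    fake = gen(condition)
    disc_fake_hat = disc(fake, condition)
    adv_loss = adv_criterion(disc_fake_hat, torch.ones_like(disc_fake_hat))
    recon_loss = recon_criterion(real, fake)
    return adv_loss + lambda_recon * recon_loss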
Pix2Pix Training
Finally, you can train the model and see some of your maps!
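At a high level, each step splits a side-by-side image pair into the condition (satellite) half and the target (map) half, updates the discriminator, then updates the generator. A condensed sketch is below, assuming a dataloader over the paired dataset has already been built and omitting cropping/resizing, visualization, and checkpoint saving.

from tqdm.auto import tqdm

def train():
    for epoch in range(n_epochs):
        for image, _ in tqdm(dataloader):
            image_width = image.shape[3]
            condition = image[:, :, :, :image_width // 2].to(device)  # satellite half
            real = image[:, :, :, image_width // 2:].to(device)       # map half

            # Discriminator update
            disc_opt.zero_grad()
            with torch.no_grad():
                fake = gen(condition)
            disc_fake_hat = disc(fake.detach(), condition)
            disc_fake_loss = adv_criterion(disc_fake_hat, torch.zeros_like(disc_fake_hat))
            disc_real_hat = disc(real, condition)
            disc_real_loss = adv_criterion(disc_real_hat, torch.ones_like(disc_real_hat))
            disc_loss = (disc_fake_loss + disc_real_loss) / 2
            disc_loss.backward()
            disc_opt.step()

            # Generator update
            gen_opt.zero_grad()
            gen_loss = get_gen_loss(gen, disc, real, condition,
                                    adv_criterion, recon_criterion, lambda_recon)
            gen_loss.backward()
            gen_opt.step()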
(Output of the training cell: the run recorded in this notebook was stopped manually, so the cell ends with a KeyboardInterrupt traceback raised while the dataloader was fetching the next batch.)