GitHub Repository: pytorch/tutorials
Path: blob/main/intermediate_source/spatial_transformer_tutorial.py

# -*- coding: utf-8 -*-
"""
Spatial Transformer Networks Tutorial
=====================================
**Author**: `Ghassen HAMROUNI <https://github.com/GHamrouni>`_

.. figure:: /_static/img/stn/FSeq.png

In this tutorial, you will learn how to augment your network using
a visual attention mechanism called spatial transformer
networks. You can read more about spatial transformer
networks in the `DeepMind paper <https://arxiv.org/abs/1506.02025>`__.

Spatial transformer networks are a generalization of differentiable
attention to any spatial transformation. Spatial transformer networks
(STN for short) allow a neural network to learn how to perform spatial
transformations on the input image in order to enhance the geometric
invariance of the model.
For example, an STN can crop a region of interest, scale an image, and
correct its orientation. This can be a useful mechanism because CNNs
are not invariant to rotation, scale, or more general affine
transformations.

One of the best things about STNs is that they can simply be plugged into
an existing CNN with very little modification.
"""
# License: BSD
# Author: Ghassen Hamrouni

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

plt.ion()   # interactive mode

######################################################################
# Loading the data
# ----------------
#
# In this post we experiment with the classic MNIST dataset, using a
# standard convolutional network augmented with a spatial transformer
# network.

# Some download mirrors reject requests without a browser-like User-agent
# header, so install a global opener that sends one before fetching MNIST.
import urllib.request
opener = urllib.request.build_opener()
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
urllib.request.install_opener(opener)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training dataset
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])), batch_size=64, shuffle=True, num_workers=4)
# Test dataset
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])), batch_size=64, shuffle=True, num_workers=4)
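
# Aside: the constants passed to ``Normalize`` above are the widely quoted
# MNIST training-set mean and standard deviation. As a minimal sketch of
# where such numbers come from (the variable names below are illustrative),
# they can be recomputed from the raw pixels:
mnist_raw = datasets.MNIST(root='.', train=True, download=True,
                           transform=transforms.ToTensor())
mnist_pixels = torch.stack([img for img, _ in mnist_raw])
print('MNIST mean/std: {:.4f} / {:.4f}'.format(
    mnist_pixels.mean().item(), mnist_pixels.std().item()))  # ~0.1307 / ~0.3081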

######################################################################
# Depicting spatial transformer networks
# --------------------------------------
#
# Spatial transformer networks boil down to three main components:
#
# - The localization network is a regular CNN which regresses the
#   transformation parameters. The transformation is never learned
#   explicitly from this dataset; instead, the network automatically learns
#   the spatial transformations that enhance the global accuracy.
# - The grid generator generates a grid of coordinates in the input
#   image corresponding to each pixel from the output image.
# - The sampler uses the parameters of the transformation and applies
#   it to the input image. The sketch below shows the grid generator and
#   sampler in isolation.
#
# .. figure:: /_static/img/stn/stn-arch.png
#
# .. note::
#    We need a version of PyTorch that contains the
#    ``affine_grid`` and ``grid_sample`` functions.
#
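
# A minimal sketch of the grid generator and sampler in isolation (tensor
# shapes and names here are illustrative assumptions): an identity 2x3
# affine matrix passed through ``F.affine_grid`` and ``F.grid_sample``
# returns the input unchanged, whereas e.g. ``[[0.5, 0, 0], [0, 0.5, 0]]``
# would sample a 2x zoom into the center of the image.
identity_theta = torch.tensor([[[1., 0., 0.],
                                [0., 1., 0.]]])      # shape (N, 2, 3)
toy_batch = torch.rand(1, 1, 28, 28)                 # one fake 28x28 image
toy_grid = F.affine_grid(identity_theta, toy_batch.size(), align_corners=False)
toy_out = F.grid_sample(toy_batch, toy_grid, align_corners=False)
assert torch.allclose(toy_batch, toy_out, atol=1e-5)

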
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

        # Spatial transformer localization-network
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )

        # Regressor for the 3 * 2 affine matrix; the localization network
        # reduces a 28x28 input to 10 feature maps of size 3x3
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )

        # Initialize the weights/bias with identity transformation
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    # Spatial transformer network forward function
    def stn(self, x):
        xs = self.localization(x)
        xs = xs.view(-1, 10 * 3 * 3)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3)

        grid = F.affine_grid(theta, x.size(), align_corners=False)
        x = F.grid_sample(x, grid, align_corners=False)

        return x

    def forward(self, x):
        # transform the input
        x = self.stn(x)

        # Perform the usual forward pass
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = Net().to(device)
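
# A quick sanity check on the untrained model (the batch of zeros and its
# size of 4 are arbitrary illustrative choices): the output should be one
# row of 10 log-probabilities per image.
with torch.no_grad():
    dummy_input = torch.zeros(4, 1, 28, 28, device=device)
    print(model(dummy_input).shape)  # expected: torch.Size([4, 10])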

######################################################################
# Training the model
# ------------------
#
# Now, let's use the SGD algorithm to train the model. The network is
# learning the classification task in a supervised way. At the same time,
# the model is learning the STN automatically in an end-to-end fashion.

optimizer = optim.SGD(model.parameters(), lr=0.01)


def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 500 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

#
# A simple test procedure to measure the STN performance on MNIST.
#


def test():
    with torch.no_grad():
        model.eval()
        test_loss = 0
        correct = 0
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)

            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # get the index of the max log-probability
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(test_loader.dataset)
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
              .format(test_loss, correct, len(test_loader.dataset),
                      100. * correct / len(test_loader.dataset)))

######################################################################
# Visualizing the STN results
# ---------------------------
#
# Now, we will inspect the results of our learned visual attention
# mechanism.
#
# We define a small helper function in order to visualize the
# transformations while training.


def convert_image_np(inp):
    """Convert a Tensor to a numpy image."""
    inp = inp.numpy().transpose((1, 2, 0))
    # Undo the MNIST normalization applied in the data loaders
    mean = np.array([0.1307])
    std = np.array([0.3081])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    return inp

# We want to visualize the output of the spatial transformer layer
# after the training: we plot a batch of input images and
# the corresponding batch transformed by the STN.


def visualize_stn():
    with torch.no_grad():
        # Get a batch of data from the test loader
        data = next(iter(test_loader))[0].to(device)

        input_tensor = data.cpu()
        transformed_input_tensor = model.stn(data).cpu()

        in_grid = convert_image_np(
            torchvision.utils.make_grid(input_tensor))

        out_grid = convert_image_np(
            torchvision.utils.make_grid(transformed_input_tensor))

        # Plot the results side-by-side
        f, axarr = plt.subplots(1, 2)
        axarr[0].imshow(in_grid)
        axarr[0].set_title('Dataset Images')

        axarr[1].imshow(out_grid)
        axarr[1].set_title('Transformed Images')


for epoch in range(1, 20 + 1):
    train(epoch)
    test()

# Visualize the STN transformation on some input batch
visualize_stn()

plt.ioff()
plt.show()
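
# After training, the transformation predicted by the localization branch can
# also be inspected directly. A brief sketch (variable names are illustrative;
# it reuses the ``localization`` and ``fc_loc`` modules defined in ``Net``):
with torch.no_grad():
    sample = next(iter(test_loader))[0].to(device)
    xs = model.localization(sample)
    theta = model.fc_loc(xs.view(-1, 10 * 3 * 3)).view(-1, 2, 3)
    print(theta[0])  # the 2x3 affine matrix predicted for the first image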