GitHub Repository: pytorch/tutorials
Path: blob/main/intermediate_source/custom_function_conv_bn_tutorial.py
# -*- coding: utf-8 -*-
"""
Fusing Convolution and Batch Norm using Custom Function
=======================================================

Fusing adjacent convolution and batch norm layers together is typically an
inference-time optimization to improve run-time. It is usually achieved
by eliminating the batch norm layer entirely and updating the weight
and bias of the preceding convolution [0]. However, this technique is not
applicable for training models.

In this tutorial, we will show a different technique to fuse the two layers
that can be applied during training. Rather than improved runtime, the
objective of this optimization is to reduce memory usage.

The idea behind this optimization is that both convolution and
batch norm (as well as many other ops) need to save a copy of their input
during the forward pass for use in the backward pass. For large
batch sizes, these saved inputs are responsible for most of your memory usage,
so being able to avoid allocating another input tensor for every
convolution/batch norm pair can yield a significant reduction.

In this tutorial, we avoid this extra allocation by combining convolution
and batch norm into a single layer (as a custom function). In the forward
of this combined layer, we perform normal convolution and batch norm as-is,
with the only difference being that we only save the input to the convolution.
To obtain the input of batch norm, which is necessary to backward through
it, we recompute the convolution forward again during the backward pass.

It is important to note that the usage of this optimization is situational.
Though (by avoiding one saved buffer) we always reduce the memory allocated at
the end of the forward pass, there are cases when the *peak* memory allocated
may not actually be reduced. See the final section for more details.

For simplicity, in this tutorial we hardcode ``bias=False``, ``stride=1``, ``padding=0``, ``dilation=1``,
and ``groups=1`` for Conv2D. For BatchNorm2D, we hardcode ``eps=1e-3``, ``momentum=0.1``,
``affine=False``, and ``track_running_stats=False``. Another small difference
is that we add epsilon in the denominator outside of the square root in the computation
of batch norm.

[0] https://nenadmarkus.com/p/fusing-batchnorm-and-conv/
"""

######################################################################
# Backward Formula Implementation for Convolution
# -------------------------------------------------------------------
# Implementing a custom function requires us to implement the backward
# ourselves. In this case, we need both the backward formulas for Conv2D
# and BatchNorm2D. Eventually we'd chain them together in our unified
# backward function, but below we first implement them as their own
# custom functions so we can validate their correctness individually.
import torch
from torch.autograd.function import once_differentiable
import torch.nn.functional as F

def convolution_backward(grad_out, X, weight):
    # Gradient w.r.t. the weight: cross-correlate the input with the upstream gradient
    grad_input = F.conv2d(X.transpose(0, 1), grad_out.transpose(0, 1)).transpose(0, 1)
    # Gradient w.r.t. the input: transposed convolution of the upstream gradient with the weight
    grad_X = F.conv_transpose2d(grad_out, weight)
    return grad_X, grad_input

class Conv2D(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, weight):
        ctx.save_for_backward(X, weight)
        return F.conv2d(X, weight)

    # Use @once_differentiable by default unless we intend to double backward
    @staticmethod
    @once_differentiable
    def backward(ctx, grad_out):
        X, weight = ctx.saved_tensors
        return convolution_backward(grad_out, X, weight)

######################################################################
# When testing with ``gradcheck``, it is important to use double precision
weight = torch.rand(5, 3, 3, 3, requires_grad=True, dtype=torch.double)
X = torch.rand(10, 3, 7, 7, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(Conv2D.apply, (X, weight))
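
######################################################################
# As an extra, optional sanity check (not part of the original tutorial), we can
# also compare ``convolution_backward`` against the reference helpers in
# ``torch.nn.grad``, assuming the default stride/padding/dilation/groups that we
# hardcode throughout this tutorial.
grad_out = torch.rand(10, 5, 5, 5, dtype=torch.double)  # shape of the conv output for X above
grad_X, grad_weight = convolution_backward(grad_out, X, weight)
assert torch.allclose(grad_X, torch.nn.grad.conv2d_input(X.shape, weight, grad_out))
assert torch.allclose(grad_weight, torch.nn.grad.conv2d_weight(X, weight.shape, grad_out))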

######################################################################
# Backward Formula Implementation for Batch Norm
# -------------------------------------------------------------------
# Batch Norm has two modes: training and ``eval`` mode. In training mode
# the sample statistics are a function of the inputs. In ``eval`` mode,
# we use the saved running statistics, which are not a function of the inputs.
# This makes non-training mode's backward significantly simpler. Below
# we implement and test only the training mode case.
def unsqueeze_all(t):
    # Helper function to ``unsqueeze`` all the dimensions that we reduce over
    return t[None, :, None, None]

def batch_norm_backward(grad_out, X, sum, sqrt_var, N, eps):
    # We use the formula: ``out = (X - mean(X)) / (sqrt(var(X)) + eps)``
    # in batch norm 2D forward. To simplify our derivation, we follow the
    # chain rule and compute the gradients as follows before accumulating
    # them all into a final grad_input.
    # 1) ``grad of out wrt var(X)`` * ``grad of var(X) wrt X``
    # 2) ``grad of out wrt mean(X)`` * ``grad of mean(X) wrt X``
    # 3) ``grad of out wrt X in the numerator`` * ``grad of X wrt X``
    # We then rewrite the formulas to use as few extra buffers as possible
    tmp = ((X - unsqueeze_all(sum) / N) * grad_out).sum(dim=(0, 2, 3))
    tmp *= -1
    d_denom = tmp / (sqrt_var + eps)**2  # ``d_denom = -num / denom**2``
    # It is useful to delete tensors when you no longer need them with ``del``
    # For example, we could've done ``del tmp`` here because we won't use it later
    # In this case, it's not a big difference because ``tmp`` only has size (C,)
    # The important thing is to avoid allocating NCHW-sized tensors unnecessarily
    d_var = d_denom / (2 * sqrt_var)  # ``denom = torch.sqrt(var) + eps``
    # Compute ``d_mean_dx`` before allocating the final NCHW-sized grad_input buffer
    d_mean_dx = grad_out / unsqueeze_all(sqrt_var + eps)
    d_mean_dx = unsqueeze_all(-d_mean_dx.sum(dim=(0, 2, 3)) / N)
    # ``d_mean_dx`` has already been reassigned to a C-sized buffer so no need to worry

    # (1) ``unbiased_var(x) = ((X - unsqueeze_all(mean))**2).sum(dim=(0, 2, 3)) / (N - 1)``
    grad_input = X * unsqueeze_all(d_var * N)
    grad_input += unsqueeze_all(-d_var * sum)
    grad_input *= 2 / ((N - 1) * N)
    # (2) mean (see above)
    grad_input += d_mean_dx
    # (3) Add 'grad_out / <factor>' without allocating an extra buffer
    grad_input *= unsqueeze_all(sqrt_var + eps)
    grad_input += grad_out
    grad_input /= unsqueeze_all(sqrt_var + eps)  # ``sqrt_var + eps > 0!``
    return grad_input

class BatchNorm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, eps=1e-3):
        # Don't save ``keepdim`` values for backward
        sum = X.sum(dim=(0, 2, 3))
        var = X.var(unbiased=True, dim=(0, 2, 3))
        N = X.numel() / X.size(1)
        sqrt_var = torch.sqrt(var)
        ctx.save_for_backward(X)
        ctx.eps = eps
        ctx.sum = sum
        ctx.N = N
        ctx.sqrt_var = sqrt_var
        mean = sum / N
        denom = sqrt_var + eps
        out = X - unsqueeze_all(mean)
        out /= unsqueeze_all(denom)
        return out

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_out):
        X, = ctx.saved_tensors
        return batch_norm_backward(grad_out, X, ctx.sum, ctx.sqrt_var, ctx.N, ctx.eps)

######################################################################
# Testing with ``gradcheck``
a = torch.rand(1, 2, 3, 4, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(BatchNorm.apply, (a,), fast_mode=False)
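
######################################################################
# A quick, optional numerical check (not part of the original tutorial): the
# forward of our ``BatchNorm`` should match the formula it implements,
# ``(X - mean(X)) / (sqrt(var(X)) + eps)``, with the unbiased variance.
with torch.no_grad():
    expected = (a - a.mean(dim=(0, 2, 3), keepdim=True)) / (
        a.var(dim=(0, 2, 3), unbiased=True, keepdim=True).sqrt() + 1e-3)
    assert torch.allclose(BatchNorm.apply(a), expected)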

######################################################################
# Fusing Convolution and BatchNorm
# -------------------------------------------------------------------
# Now that the bulk of the work has been done, we can combine
# them together. Note that in (1) we only save a single buffer
# for backward, but this also means we recompute convolution forward
# in (5). Also see that in (2), (3), (4), and (6), it's exactly the
# same code as in the examples above.
class FusedConvBN2DFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, conv_weight, eps=1e-3):
        assert X.ndim == 4  # N, C, H, W
        # (1) Only need to save this single buffer for backward!
        ctx.save_for_backward(X, conv_weight)

        # (2) Exact same Conv2D forward from the example above
        X = F.conv2d(X, conv_weight)
        # (3) Exact same BatchNorm2D forward from the example above
        sum = X.sum(dim=(0, 2, 3))
        var = X.var(unbiased=True, dim=(0, 2, 3))
        N = X.numel() / X.size(1)
        sqrt_var = torch.sqrt(var)
        ctx.eps = eps
        ctx.sum = sum
        ctx.N = N
        ctx.sqrt_var = sqrt_var
        mean = sum / N
        denom = sqrt_var + eps
        # Try to do as many things in-place as possible
        # Instead of ``out = (X - a) / b``, doing ``out = X - a; out /= b``
        # avoids allocating one extra NCHW-sized buffer here
        out = X - unsqueeze_all(mean)
        out /= unsqueeze_all(denom)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        X, conv_weight = ctx.saved_tensors
        # (4) Batch norm backward
        # (5) We need to recompute the convolution
        X_conv_out = F.conv2d(X, conv_weight)
        grad_out = batch_norm_backward(grad_out, X_conv_out, ctx.sum, ctx.sqrt_var,
                                       ctx.N, ctx.eps)
        # (6) Conv2D backward
        grad_X, grad_input = convolution_backward(grad_out, X, conv_weight)
        # One gradient per forward input: X, conv_weight, and eps (non-tensor, so None)
        return grad_X, grad_input, None

######################################################################
# The next step is to wrap our functional variant in a stateful
# ``nn.Module``
import torch.nn as nn
import math

class FusedConvBN(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, exp_avg_factor=0.1,
                 eps=1e-3, device=None, dtype=None):
        super(FusedConvBN, self).__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        # Conv parameters
        weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
        self.conv_weight = nn.Parameter(torch.empty(*weight_shape, **factory_kwargs))
        # Batch norm parameters
        num_features = out_channels
        self.num_features = num_features
        self.eps = eps
        # Initialize
        self.reset_parameters()

    def forward(self, X):
        return FusedConvBN2DFunction.apply(X, self.conv_weight, self.eps)

    def reset_parameters(self) -> None:
        nn.init.kaiming_uniform_(self.conv_weight, a=math.sqrt(5))
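
######################################################################
# A minimal usage sketch (not part of the original tutorial): the module is a
# drop-in replacement for a ``Conv2d(..., bias=False)`` followed by a
# parameter-free ``BatchNorm2d``, under the hardcoded settings listed at the
# top. The names ``layer`` and ``out`` below are just for illustration.
layer = FusedConvBN(3, 8, kernel_size=3)
out = layer(torch.randn(4, 3, 16, 16))
print(out.shape)  # expected: torch.Size([4, 8, 14, 14])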

######################################################################
# Use ``gradcheck`` to validate the correctness of our backward formula
weight = torch.rand(5, 3, 3, 3, requires_grad=True, dtype=torch.double)
X = torch.rand(2, 3, 4, 4, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(FusedConvBN2DFunction.apply, (X, weight))
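
######################################################################
# Another optional sanity check (not part of the original tutorial): since the
# fused forward simply chains the two reference forwards, its output should
# match running our ``Conv2D`` and ``BatchNorm`` functions back to back.
assert torch.allclose(FusedConvBN2DFunction.apply(X, weight),
                      BatchNorm.apply(Conv2D.apply(X, weight)))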

######################################################################
# Testing out our new Layer
# -------------------------------------------------------------------
# Use ``FusedConvBN`` to train a basic network
# The code below is adapted, with light modifications, from the example here:
# https://github.com/pytorch/examples/tree/master/mnist
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

# Record memory allocated at the end of the forward pass
memory_allocated = [[], []]

class Net(nn.Module):
    def __init__(self, fused=True):
        super(Net, self).__init__()
        self.fused = fused
        if fused:
            self.convbn1 = FusedConvBN(1, 32, 3)
            self.convbn2 = FusedConvBN(32, 64, 3)
        else:
            self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=False)
            self.bn1 = nn.BatchNorm2d(32, affine=False, track_running_stats=False)
            self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=False)
            self.bn2 = nn.BatchNorm2d(64, affine=False, track_running_stats=False)
        self.fc1 = nn.Linear(9216, 128)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        if self.fused:
            x = self.convbn1(x)
        else:
            x = self.conv1(x)
            x = self.bn1(x)
        F.relu_(x)
        if self.fused:
            x = self.convbn2(x)
        else:
            x = self.conv2(x)
            x = self.bn2(x)
        F.relu_(x)
        x = F.max_pool2d(x, 2)
        F.relu_(x)
        x = x.flatten(1)
        x = self.fc1(x)
        x = self.dropout(x)
        F.relu_(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        if self.fused:
            memory_allocated[0].append(torch.cuda.memory_allocated())
        else:
            memory_allocated[1].append(torch.cuda.memory_allocated())
        return output

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 2 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    # Use inference mode instead of no_grad for a free improvement in test-time performance
    with torch.inference_mode():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
train_kwargs = {'batch_size': 2048}
test_kwargs = {'batch_size': 2048}

if use_cuda:
    cuda_kwargs = {'num_workers': 1,
                   'pin_memory': True,
                   'shuffle': True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
dataset1 = datasets.MNIST('../data', train=True, download=True,
                          transform=transform)
dataset2 = datasets.MNIST('../data', train=False,
                          transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

######################################################################
# A Comparison of Memory Usage
# -------------------------------------------------------------------
# If CUDA is enabled, print out memory usage for both ``fused=True`` and ``fused=False``.
# For an example run on an NVIDIA GeForce RTX 3070 with NVIDIA CUDA® Deep Neural Network
# library (cuDNN) 8.0.5: fused peak memory: 1.56GB, unfused peak memory: 2.68GB
#
# It is important to note that the *peak* memory usage for this model may vary depending
# on the specific cuDNN convolution algorithm used. For shallower models, it
# may be possible for the peak memory allocated of the fused model to exceed
# that of the unfused model! This is because the memory allocated to compute
# certain cuDNN convolution algorithms can be high enough to "hide" the typical peak
# you would expect to be near the start of the backward pass.
#
# For this reason, we also record and display the memory allocated at the end
# of the forward pass as an approximation, and to demonstrate that we indeed
# allocate one fewer buffer per fused ``conv-bn`` pair.
from statistics import mean

torch.backends.cudnn.enabled = True

if use_cuda:
    peak_memory_allocated = []

    for fused in (True, False):
        torch.manual_seed(123456)

        model = Net(fused=fused).to(device)
        optimizer = optim.Adadelta(model.parameters(), lr=1.0)
        scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

        for epoch in range(1):
            train(model, device, train_loader, optimizer, epoch)
            test(model, device, test_loader)
            scheduler.step()

        peak_memory_allocated.append(torch.cuda.max_memory_allocated())
        torch.cuda.reset_peak_memory_stats()
    print("cuDNN version:", torch.backends.cudnn.version())
    print()
    print("Peak memory allocated:")
    print(f"fused: {peak_memory_allocated[0]/1024**3:.2f}GB, unfused: {peak_memory_allocated[1]/1024**3:.2f}GB")
    print("Memory allocated at end of forward pass:")
    print(f"fused: {mean(memory_allocated[0])/1024**3:.2f}GB, unfused: {mean(memory_allocated[1])/1024**3:.2f}GB")