Path: blob/main/intermediate_source/custom_function_conv_bn_tutorial.py
# -*- coding: utf-8 -*-
"""
Fusing Convolution and Batch Norm using Custom Function
=======================================================

Fusing adjacent convolution and batch norm layers together is typically an
inference-time optimization to improve run-time. It is usually achieved
by eliminating the batch norm layer entirely and updating the weight
and bias of the preceding convolution [0]. However, this technique is not
applicable for training models.

In this tutorial, we will show a different technique to fuse the two layers
that can be applied during training. Rather than improved runtime, the
objective of this optimization is to reduce memory usage.

The idea behind this optimization is that both convolution and
batch norm (as well as many other ops) need to save a copy of their input
during forward for the backward pass. For large
batch sizes, these saved inputs are responsible for most of your memory usage,
so being able to avoid allocating another input tensor for every
convolution/batch norm pair can yield a significant reduction.

In this tutorial, we avoid this extra allocation by combining convolution
and batch norm into a single layer (as a custom function). In the forward
of this combined layer, we perform normal convolution and batch norm as-is,
with the only difference being that we will only save the inputs to the convolution.
To obtain the input of batch norm, which is necessary to backward through
it, we recompute convolution forward again during the backward pass.

It is important to note that the usage of this optimization is situational.
Though (by avoiding one buffer saved) we always reduce the memory allocated at
the end of the forward pass, there are cases when the *peak* memory allocated
may not actually be reduced. See the final section for more details.

For simplicity, in this tutorial we hardcode `bias=False`, `stride=1`, `padding=0`, `dilation=1`,
and `groups=1` for Conv2D. For BatchNorm2D, we hardcode `eps=1e-3`, `momentum=0.1`,
`affine=False`, and `track_running_stats=False`. Another small difference
is that we add epsilon in the denominator outside of the square root in the computation
of batch norm.

[0] https://nenadmarkus.com/p/fusing-batchnorm-and-conv/
"""
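
######################################################################
# As an aside, the inference-time fusion mentioned above [0] folds the batch
# norm statistics directly into the convolution's weight and bias, so that
# the batch norm layer can be dropped entirely at ``eval`` time. The helper
# below is a minimal sketch of that idea (the name ``fold_conv_bn_eval`` is
# our own, and it is not used in the rest of this tutorial); it assumes the
# batch norm was created with ``affine=True`` and ``track_running_stats=True``:
import torch
import torch.nn as nn

@torch.no_grad()
def fold_conv_bn_eval(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    # Per-output-channel scale: gamma / sqrt(running_var + eps)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    # Build a new conv that absorbs the batch norm, for eval-time use only
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      conv.stride, conv.padding, conv.dilation, conv.groups,
                      bias=True)
    fused.weight.copy_(conv.weight * scale[:, None, None, None])
    conv_bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    fused.bias.copy_((conv_bias - bn.running_mean) * scale + bn.bias)
    return fused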

######################################################################
# Backward Formula Implementation for Convolution
# -------------------------------------------------------------------
# Implementing a custom function requires us to implement the backward
# ourselves. In this case, we need both the backward formulas for Conv2D
# and BatchNorm2D. Eventually we'd chain them together in our unified
# backward function, but below we first implement them as their own
# custom functions so we can validate their correctness individually.
import torch
from torch.autograd.function import once_differentiable
import torch.nn.functional as F

def convolution_backward(grad_out, X, weight):
    # Gradient with respect to ``weight`` (note: despite its name,
    # ``grad_input`` fills the ``weight`` slot of the return value)
    grad_input = F.conv2d(X.transpose(0, 1), grad_out.transpose(0, 1)).transpose(0, 1)
    # Gradient with respect to the input ``X``
    grad_X = F.conv_transpose2d(grad_out, weight)
    return grad_X, grad_input

class Conv2D(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, weight):
        ctx.save_for_backward(X, weight)
        return F.conv2d(X, weight)

    # Use @once_differentiable by default unless we intend to double backward
    @staticmethod
    @once_differentiable
    def backward(ctx, grad_out):
        X, weight = ctx.saved_tensors
        return convolution_backward(grad_out, X, weight)

######################################################################
# When testing with ``gradcheck``, it is important to use double precision
weight = torch.rand(5, 3, 3, 3, requires_grad=True, dtype=torch.double)
X = torch.rand(10, 3, 7, 7, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(Conv2D.apply, (X, weight))
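
######################################################################
# As an extra sanity check of our own (beyond ``gradcheck``), we can also
# compare the gradients from our custom function against the ones autograd
# derives for ``F.conv2d`` directly:
gX, gW = torch.autograd.grad(Conv2D.apply(X, weight).sum(), (X, weight))
gX_ref, gW_ref = torch.autograd.grad(F.conv2d(X, weight).sum(), (X, weight))
assert torch.allclose(gX, gX_ref)
assert torch.allclose(gW, gW_ref)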

######################################################################
# Backward Formula Implementation for Batch Norm
# -------------------------------------------------------------------
# Batch Norm has two modes: training and ``eval`` mode. In training mode
# the sample statistics are a function of the inputs. In ``eval`` mode,
# we use the saved running statistics, which are not a function of the inputs.
# This makes non-training mode's backward significantly simpler. Below
# we implement and test only the training mode case.
def unsqueeze_all(t):
    # Helper function to ``unsqueeze`` all the dimensions that we reduce over
    return t[None, :, None, None]

def batch_norm_backward(grad_out, X, sum, sqrt_var, N, eps):
    # We use the formula: ``out = (X - mean(X)) / (sqrt(var(X)) + eps)``
    # in batch norm 2D forward. To simplify our derivation, we follow the
    # chain rule and compute the gradients as follows before accumulating
    # them all into a final grad_input.
    #  1) ``grad of out wrt var(X)`` * ``grad of var(X) wrt X``
    #  2) ``grad of out wrt mean(X)`` * ``grad of mean(X) wrt X``
    #  3) ``grad of out wrt X in the numerator`` * ``grad of X wrt X``
    # We then rewrite the formulas to use as few extra buffers as possible
    tmp = ((X - unsqueeze_all(sum) / N) * grad_out).sum(dim=(0, 2, 3))
    tmp *= -1
    d_denom = tmp / (sqrt_var + eps)**2  # ``d_denom = -num / denom**2``
    # It is useful to delete tensors when you no longer need them with ``del``.
    # For example, we could've done ``del tmp`` here because we won't use it later.
    # In this case, it's not a big difference because ``tmp`` only has size (C,).
    # The important thing is to avoid allocating NCHW-sized tensors unnecessarily.
    d_var = d_denom / (2 * sqrt_var)  # ``denom = torch.sqrt(var) + eps``
    # Compute ``d_mean_dx`` before allocating the final NCHW-sized grad_input buffer
    d_mean_dx = grad_out / unsqueeze_all(sqrt_var + eps)
    d_mean_dx = unsqueeze_all(-d_mean_dx.sum(dim=(0, 2, 3)) / N)
    # ``d_mean_dx`` has already been reassigned to a C-sized buffer so no need to worry

    # ``(1) unbiased_var(x) = ((X - unsqueeze_all(mean))**2).sum(dim=(0, 2, 3)) / (N - 1)``
    grad_input = X * unsqueeze_all(d_var * N)
    grad_input += unsqueeze_all(-d_var * sum)
    grad_input *= 2 / ((N - 1) * N)
    # (2) mean (see above)
    grad_input += d_mean_dx
    # (3) Add ``grad_out / <factor>`` without allocating an extra buffer
    grad_input *= unsqueeze_all(sqrt_var + eps)
    grad_input += grad_out
    grad_input /= unsqueeze_all(sqrt_var + eps)  # ``sqrt_var + eps > 0!``
    return grad_input

class BatchNorm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, eps=1e-3):
        # Don't save ``keepdim`` values for backward
        sum = X.sum(dim=(0, 2, 3))
        var = X.var(unbiased=True, dim=(0, 2, 3))
        N = X.numel() / X.size(1)
        sqrt_var = torch.sqrt(var)
        ctx.save_for_backward(X)
        ctx.eps = eps
        ctx.sum = sum
        ctx.N = N
        ctx.sqrt_var = sqrt_var
        mean = sum / N
        denom = sqrt_var + eps
        out = X - unsqueeze_all(mean)
        out /= unsqueeze_all(denom)
        return out

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_out):
        X, = ctx.saved_tensors
        return batch_norm_backward(grad_out, X, ctx.sum, ctx.sqrt_var, ctx.N, ctx.eps)

######################################################################
# Testing with ``gradcheck``
a = torch.rand(1, 2, 3, 4, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(BatchNorm.apply, (a,), fast_mode=False)
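
######################################################################
# As another check of our own, we can compare both forward and backward
# against a plain autograd reference that uses the same
# ``(X - mean(X)) / (sqrt(var(X)) + eps)`` formula. (Comparing against
# ``F.batch_norm`` directly would show small differences, because we add
# ``eps`` outside of the square root.)
def batch_norm_ref(X, eps=1e-3):
    mean = X.mean(dim=(0, 2, 3))
    var = X.var(unbiased=True, dim=(0, 2, 3))
    return (X - unsqueeze_all(mean)) / unsqueeze_all(torch.sqrt(var) + eps)

out_custom = BatchNorm.apply(a)
out_ref = batch_norm_ref(a)
assert torch.allclose(out_custom, out_ref)
g_custom, = torch.autograd.grad(out_custom.sum(), a)
g_ref, = torch.autograd.grad(out_ref.sum(), a)
assert torch.allclose(g_custom, g_ref)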

######################################################################
# Fusing Convolution and BatchNorm
# -------------------------------------------------------------------
# Now that the bulk of the work has been done, we can combine
# them together. Note that in (1) we only save a single buffer
# for backward, but this also means we recompute convolution forward
# in (5). Also see that in (2), (3), (4), and (6), it's the same
# exact code as the examples above.
class FusedConvBN2DFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, conv_weight, eps=1e-3):
        assert X.ndim == 4  # N, C, H, W
        # (1) Only need to save this single buffer for backward!
        ctx.save_for_backward(X, conv_weight)

        # (2) Exact same Conv2D forward from example above
        X = F.conv2d(X, conv_weight)
        # (3) Exact same BatchNorm2D forward from example above
        sum = X.sum(dim=(0, 2, 3))
        var = X.var(unbiased=True, dim=(0, 2, 3))
        N = X.numel() / X.size(1)
        sqrt_var = torch.sqrt(var)
        ctx.eps = eps
        ctx.sum = sum
        ctx.N = N
        ctx.sqrt_var = sqrt_var
        mean = sum / N
        denom = sqrt_var + eps
        # Try to do as many things in-place as possible
        # Instead of ``out = (X - a) / b``, doing ``out = X - a; out /= b``
        # avoids allocating one extra NCHW-sized buffer here
        out = X - unsqueeze_all(mean)
        out /= unsqueeze_all(denom)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        X, conv_weight = ctx.saved_tensors
        # (4) Batch norm backward
        # (5) We need to recompute conv
        X_conv_out = F.conv2d(X, conv_weight)
        grad_out = batch_norm_backward(grad_out, X_conv_out, ctx.sum, ctx.sqrt_var,
                                       ctx.N, ctx.eps)
        # (6) Conv2d backward
        grad_X, grad_input = convolution_backward(grad_out, X, conv_weight)
        return grad_X, grad_input, None, None, None, None, None

######################################################################
# The next step is to wrap our functional variant in a stateful
# ``nn.Module``
import torch.nn as nn
import math

class FusedConvBN(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, exp_avg_factor=0.1,
                 eps=1e-3, device=None, dtype=None):
        super(FusedConvBN, self).__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        # Conv parameters
        weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
        self.conv_weight = nn.Parameter(torch.empty(*weight_shape, **factory_kwargs))
        # Batch norm parameters
        num_features = out_channels
        self.num_features = num_features
        self.eps = eps
        # Initialize
        self.reset_parameters()

    def forward(self, X):
        return FusedConvBN2DFunction.apply(X, self.conv_weight, self.eps)

    def reset_parameters(self) -> None:
        nn.init.kaiming_uniform_(self.conv_weight, a=math.sqrt(5))

######################################################################
# Use ``gradcheck`` to validate the correctness of our backward formula
weight = torch.rand(5, 3, 3, 3, requires_grad=True, dtype=torch.double)
X = torch.rand(2, 3, 4, 4, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(FusedConvBN2DFunction.apply, (X, weight))
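
######################################################################
# We can also verify (again, an extra check of our own) that the fused
# function matches the composition of the two unfused custom functions
# above, both in its output and in its gradients:
out_fused = FusedConvBN2DFunction.apply(X, weight)
out_composed = BatchNorm.apply(Conv2D.apply(X, weight))
assert torch.allclose(out_fused, out_composed)
gX, gW = torch.autograd.grad(out_fused.sum(), (X, weight))
gX_ref, gW_ref = torch.autograd.grad(out_composed.sum(), (X, weight))
assert torch.allclose(gX, gX_ref)
assert torch.allclose(gW, gW_ref)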

######################################################################
# Testing out our new Layer
# -------------------------------------------------------------------
# Use ``FusedConvBN`` to train a basic network
# The code below is after some light modifications to the example here:
# https://github.com/pytorch/examples/tree/master/mnist
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

# Record memory allocated at the end of the forward pass
memory_allocated = [[], []]

class Net(nn.Module):
    def __init__(self, fused=True):
        super(Net, self).__init__()
        self.fused = fused
        if fused:
            self.convbn1 = FusedConvBN(1, 32, 3)
            self.convbn2 = FusedConvBN(32, 64, 3)
        else:
            self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=False)
            self.bn1 = nn.BatchNorm2d(32, affine=False, track_running_stats=False)
            self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=False)
            self.bn2 = nn.BatchNorm2d(64, affine=False, track_running_stats=False)
        self.fc1 = nn.Linear(9216, 128)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        if self.fused:
            x = self.convbn1(x)
        else:
            x = self.conv1(x)
            x = self.bn1(x)
        F.relu_(x)
        if self.fused:
            x = self.convbn2(x)
        else:
            x = self.conv2(x)
            x = self.bn2(x)
        F.relu_(x)
        x = F.max_pool2d(x, 2)
        F.relu_(x)
        x = x.flatten(1)
        x = self.fc1(x)
        x = self.dropout(x)
        F.relu_(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        if self.fused:
            memory_allocated[0].append(torch.cuda.memory_allocated())
        else:
            memory_allocated[1].append(torch.cuda.memory_allocated())
        return output

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 2 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    # Use inference mode instead of no_grad, for free improved test-time performance
    with torch.inference_mode():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
train_kwargs = {'batch_size': 2048}
test_kwargs = {'batch_size': 2048}

if use_cuda:
    cuda_kwargs = {'num_workers': 1,
                   'pin_memory': True,
                   'shuffle': True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
dataset1 = datasets.MNIST('../data', train=True, download=True,
                          transform=transform)
dataset2 = datasets.MNIST('../data', train=False,
                          transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

######################################################################
# A Comparison of Memory Usage
# -------------------------------------------------------------------
# If CUDA is enabled, print out memory usage for both `fused=True` and `fused=False`.
# For an example run on NVIDIA GeForce RTX 3070, NVIDIA CUDA® Deep Neural Network
# library (cuDNN) 8.0.5: fused peak memory: 1.56GB, unfused peak memory: 2.68GB
#
# It is important to note that the *peak* memory usage for this model may vary
# depending on the specific cuDNN convolution algorithm used. For shallower models,
# it may be possible for the peak memory allocated of the fused model to exceed
# that of the unfused model! This is because the memory allocated to compute
# certain cuDNN convolution algorithms can be high enough to "hide" the typical peak
# you would expect to be near the start of the backward pass.
#
# For this reason, we also record and display the memory allocated at the end
# of the forward pass as an approximation, and to demonstrate that we indeed
# allocate one fewer buffer per fused ``conv-bn`` pair.
from statistics import mean

torch.backends.cudnn.enabled = True

if use_cuda:
    peak_memory_allocated = []

    for fused in (True, False):
        torch.manual_seed(123456)

        model = Net(fused=fused).to(device)
        optimizer = optim.Adadelta(model.parameters(), lr=1.0)
        scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

        for epoch in range(1):
            train(model, device, train_loader, optimizer, epoch)
            test(model, device, test_loader)
            scheduler.step()
        peak_memory_allocated.append(torch.cuda.max_memory_allocated())
        torch.cuda.reset_peak_memory_stats()
    print("cuDNN version:", torch.backends.cudnn.version())
    print()
    print("Peak memory allocated:")
    print(f"fused: {peak_memory_allocated[0]/1024**3:.2f}GB, unfused: {peak_memory_allocated[1]/1024**3:.2f}GB")
    print("Memory allocated at end of forward pass:")
    print(f"fused: {mean(memory_allocated[0])/1024**3:.2f}GB, unfused: {mean(memory_allocated[1])/1024**3:.2f}GB")
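
######################################################################
# Finally, an illustrative micro-benchmark of our own (not part of the
# original recipe): for a single ``conv-bn`` pair, the unfused version keeps
# the conv output alive for backward (it is the input of batch norm), while
# the fused function saves only the conv input. The memory held after the
# forward pass should therefore differ by roughly one intermediate-sized
# buffer:
if use_cuda:
    x = torch.rand(256, 32, 28, 28, device=device, requires_grad=True)
    w = torch.rand(32, 32, 3, 3, device=device, requires_grad=True)

    base = torch.cuda.memory_allocated()
    out = FusedConvBN2DFunction.apply(x, w)
    fused_bytes = torch.cuda.memory_allocated() - base
    del out  # also frees the buffers held by the autograd graph

    base = torch.cuda.memory_allocated()
    out = BatchNorm.apply(Conv2D.apply(x, w))
    unfused_bytes = torch.cuda.memory_allocated() - base
    del out

    print(f"Held after forward: fused {fused_bytes/1024**2:.1f} MiB, "
          f"unfused {unfused_bytes/1024**2:.1f} MiB")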