# -*- coding: utf-8 -*-
"""
Model ensembling
================

This tutorial illustrates how to vectorize model ensembling using ``torch.vmap``.

What is model ensembling?
-------------------------
Model ensembling combines the predictions from multiple models together.
Traditionally this is done by running each model on some inputs separately
and then combining the predictions. However, if you're running models with
the same architecture, then it may be possible to combine them together
using ``torch.vmap``. ``vmap`` is a function transform that maps functions across
dimensions of the input tensors. One of its use cases is eliminating
for-loops and speeding them up through vectorization.

Let's demonstrate how to do this using an ensemble of simple MLPs.

.. note::

   This tutorial requires PyTorch 2.0.0 or later.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)

# Here's a simple MLP
class SimpleMLP(nn.Module):
    def __init__(self):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.flatten(1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x

######################################################################
# Let's generate a batch of dummy data and pretend that we're working with
# an MNIST dataset. Thus, the dummy images are 28 by 28, and we have a
# minibatch of size 64. Furthermore, let's say we want to combine the predictions
# from 10 different models.

device = 'cuda'
num_models = 10

data = torch.randn(100, 64, 1, 28, 28, device=device)
targets = torch.randint(10, (6400,), device=device)

models = [SimpleMLP().to(device) for _ in range(num_models)]

######################################################################
# We have a couple of options for generating predictions. Maybe we want to
# give each model a different randomized minibatch of data. Alternatively,
# maybe we want to run the same minibatch of data through each model (e.g.
# if we were testing the effect of different model initializations).

######################################################################
# Option 1: different minibatch for each model

minibatches = data[:num_models]
predictions_diff_minibatch_loop = [model(minibatch) for model, minibatch in zip(models, minibatches)]

######################################################################
# Option 2: Same minibatch

minibatch = data[0]
predictions2 = [model(minibatch) for model in models]

######################################################################
# Using ``vmap`` to vectorize the ensemble
# ----------------------------------------
#
# Let's use ``vmap`` to speed up the for-loop. We must first prepare the models
# for use with ``vmap``.
#
# First, let's combine the states of the models together by stacking each
# parameter. For example, ``model[i].fc1.weight`` has shape ``[128, 784]``; we are
# going to stack the ``.fc1.weight`` of each of the 10 models to produce a big
# weight of shape ``[10, 128, 784]``.
#
# PyTorch offers the ``torch.func.stack_module_state`` convenience function to do
# this.
from torch.func import stack_module_state

params, buffers = stack_module_state(models)
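######################################################################
# As a quick sanity check (an illustrative addition, not part of the original
# tutorial text): ``stack_module_state`` returns ``params`` and ``buffers`` as
# dictionaries mapping names to stacked tensors, each with a leading dimension
# of size ``num_models``.

print(list(params.keys()))         # parameter names, e.g. 'fc1.weight', 'fc1.bias', ...
print(params['fc1.weight'].shape)  # expected: torch.Size([10, 128, 784])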
######################################################################
# Next, we need to define a function to ``vmap`` over. Given parameters,
# buffers, and inputs, the function should run the model using those
# parameters, buffers, and inputs. We'll use ``torch.func.functional_call``
# to help out:

from torch.func import functional_call
import copy

# Construct a "stateless" version of one of the models. It is "stateless" in
# the sense that the parameters are meta Tensors and do not have storage.
base_model = copy.deepcopy(models[0])
base_model = base_model.to('meta')

def fmodel(params, buffers, x):
    return functional_call(base_model, (params, buffers), (x,))

######################################################################
# Option 1: get predictions using a different minibatch for each model.
#
# By default, ``vmap`` maps a function across the first dimension of all inputs
# to the passed-in function. After using ``stack_module_state``, each of the
# ``params`` and buffers has an additional dimension of size ``num_models`` at
# the front, and ``minibatches`` has a leading dimension of size ``num_models``.

print([p.size(0) for p in params.values()])  # show the leading 'num_models' dimension

assert minibatches.shape == (num_models, 64, 1, 28, 28)  # verify the minibatch has a leading dimension of size 'num_models'

from torch import vmap

predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)

# verify the ``vmap`` predictions match the predictions from the for-loop
assert torch.allclose(predictions1_vmap, torch.stack(predictions_diff_minibatch_loop), atol=1e-3, rtol=1e-5)

######################################################################
# Option 2: get predictions using the same minibatch of data.
#
# ``vmap`` has an ``in_dims`` argument that specifies which dimensions to map
# over. By using ``None``, we tell ``vmap`` we want the same minibatch to apply
# for all of the 10 models.

predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)

assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-3, rtol=1e-5)

######################################################################
# A quick note: there are limitations around what types of functions can be
# transformed by ``vmap``. The best functions to transform are pure functions:
# functions whose outputs are determined only by their inputs and that have no
# side effects (e.g. mutation). ``vmap`` is unable to handle mutation of
# arbitrary Python data structures, but it is able to handle many in-place
# PyTorch operations.
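######################################################################
# To make that last point concrete, here is a minimal illustrative sketch (an
# addition to the original text): in-place PyTorch operations on tensors
# created inside the function are generally fine under ``vmap``, whereas
# mutating outside Python state (for example, appending to a global list) is a
# side effect that ``vmap`` cannot vectorize the way you might expect.

def scale_in_place(x, scale):
    out = x.clone()
    out.mul_(scale)  # in-place PyTorch op on a locally-created tensor
    return out

# Map over the first dimension of ``x``; broadcast the Python scalar ``scale``.
print(vmap(scale_in_place, in_dims=(0, None))(torch.randn(4, 3), 2.0).shape)  # torch.Size([4, 3])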
######################################################################
# Performance
# -----------
# Curious about performance numbers? Here's how the numbers look.

from torch.utils.benchmark import Timer
without_vmap = Timer(
    stmt="[model(minibatch) for model, minibatch in zip(models, minibatches)]",
    globals=globals())
with_vmap = Timer(
    stmt="vmap(fmodel)(params, buffers, minibatches)",
    globals=globals())
print(f'Predictions without vmap {without_vmap.timeit(100)}')
print(f'Predictions with vmap {with_vmap.timeit(100)}')

######################################################################
# There's a large speedup using ``vmap``!
#
# In general, vectorization with ``vmap`` should be faster than running a
# function in a for-loop and competitive with manual batching. There are some
# exceptions though, like if we haven't implemented the ``vmap`` rule for a
# particular operation or if the underlying kernels weren't optimized for
# older hardware (GPUs). If you see any of these cases, please let us know by
# opening an issue on GitHub.
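######################################################################
# As a closing illustrative sketch (an addition, not part of the original
# tutorial): once we have the stacked per-model predictions, the ensemble's
# output is typically formed by combining them, for example by averaging the
# logits (or the softmax probabilities) across the model dimension.

ensemble_logits = predictions2_vmap.mean(dim=0)                    # shape: [64, 10]
ensemble_probs = F.softmax(predictions2_vmap, dim=-1).mean(dim=0)  # averaged class probabilities
print(ensemble_logits.shape, ensemble_probs.shape)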