GitHub Repository: pytorch/tutorials
Path: blob/main/intermediate_source/ensembling.py
# -*- coding: utf-8 -*-
"""
Model ensembling
================

This tutorial illustrates how to vectorize model ensembling using ``torch.vmap``.

What is model ensembling?
-------------------------
Model ensembling combines the predictions from multiple models together.
Traditionally this is done by running each model on some inputs separately
and then combining the predictions. However, if you're running models with
the same architecture, then it may be possible to combine them together
using ``torch.vmap``. ``vmap`` is a function transform that maps functions across
dimensions of the input tensors. One of its use cases is eliminating
for-loops and speeding up the computation through vectorization.

Let's demonstrate how to do this using an ensemble of simple MLPs.

.. note::

   This tutorial requires PyTorch 2.0.0 or later.
"""

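As a minimal illustration of what "mapping a function across a dimension" means, the toy sketch below (not part of the original tutorial; ``dot``, ``xs``, and ``ys`` are hypothetical names) computes per-row dot products once with a Python for-loop and once with ``torch.vmap``. The function itself is written for a single pair of vectors; ``vmap`` supplies the batch dimension.

```python
import torch
from torch import vmap

torch.manual_seed(0)


def dot(a, b):
    # Written for a single pair of 1-D vectors; it knows nothing about batches.
    return (a * b).sum()


xs = torch.randn(5, 3)  # a batch of 5 vectors of length 3
ys = torch.randn(5, 3)

# For-loop version: call ``dot`` once per row and stack the results.
looped = torch.stack([dot(x, y) for x, y in zip(xs, ys)])

# vmap version: ``dot`` is mapped over dimension 0 of both inputs in one call.
vectorized = vmap(dot)(xs, ys)

print(looped.shape, vectorized.shape)  # both are torch.Size([5])
assert torch.allclose(looped, vectorized)
```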
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)

# Here's a simple MLP
class SimpleMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.flatten(1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x

######################################################################
# Let’s generate some dummy data and pretend that we’re working with
# an MNIST dataset. Thus, the dummy images are 28 by 28, and we have 100
# minibatches of size 64. Furthermore, let’s say we want to combine the predictions
# from 10 different models.

# Fall back to the CPU so the script also runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_models = 10

data = torch.randn(100, 64, 1, 28, 28, device=device)
targets = torch.randint(10, (6400,), device=device)

models = [SimpleMLP().to(device) for _ in range(num_models)]

######################################################################
# We have a couple of options for generating predictions. Maybe we want to
# give each model a different randomized minibatch of data. Alternatively,
# maybe we want to run the same minibatch of data through each model (e.g.
# if we were testing the effect of different model initializations).

######################################################################
# Option 1: different minibatch for each model

minibatches = data[:num_models]
predictions_diff_minibatch_loop = [model(minibatch) for model, minibatch in zip(models, minibatches)]

######################################################################
# Option 2: Same minibatch

minibatch = data[0]
predictions2 = [model(minibatch) for model in models]

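Both options above produce one prediction tensor per model but do not yet combine them. A minimal sketch of two common combination rules, soft voting (averaging class probabilities) and hard majority voting, assuming each model outputs logits of shape ``[64, 10]`` as the MLPs above do. The random ``per_model_logits`` list is a hypothetical stand-in for those outputs, and neither rule is prescribed by the tutorial.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

num_models, batch_size, num_classes = 10, 64, 10

# Hypothetical stand-in for per-model logits: one [batch_size, num_classes]
# tensor per model, like the lists built by the for-loops above.
per_model_logits = [torch.randn(batch_size, num_classes) for _ in range(num_models)]

# Stack into [num_models, batch_size, num_classes] so the model axis is explicit.
stacked = torch.stack(per_model_logits)

# Soft voting: average the per-model class probabilities over the model axis.
avg_probs = F.softmax(stacked, dim=-1).mean(dim=0)  # [batch_size, num_classes]
ensemble_pred = avg_probs.argmax(dim=-1)            # [batch_size]

# Hard voting: each model votes for its argmax class; keep the most common vote.
votes = stacked.argmax(dim=-1)                      # [num_models, batch_size]
majority_pred = votes.mode(dim=0).values            # [batch_size]

print(avg_probs.shape, ensemble_pred.shape, majority_pred.shape)
```

Soft voting keeps each model's confidence, while hard voting only counts top-1 choices; which one works better depends on how well calibrated the individual models are.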
######################################################################
# Using ``vmap`` to vectorize the ensemble
# ----------------------------------------
#
# Let's use ``vmap`` to speed up the for-loop. We must first prepare the models
# for use with ``vmap``.
#
# First, let’s combine the states of the models by stacking each
# parameter. For example, ``models[i].fc1.weight`` has shape ``[128, 784]``
# (``nn.Linear`` stores its weight as ``[out_features, in_features]``); we are
# going to stack the ``.fc1.weight`` of each of the 10 models to produce a big
# weight of shape ``[10, 128, 784]``.
#
# PyTorch offers the ``torch.func.stack_module_state`` convenience function to do
# this.
from torch.func import stack_module_state

params, buffers = stack_module_state(models)

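To see what ``stack_module_state`` returns, here is a small sketch on three hypothetical ``nn.Linear(784, 128)`` layers (the same shape as ``fc1`` above, but not the tutorial's MLPs). It shows the ``[out_features, in_features]`` weight layout and the extra leading model dimension on every stacked entry.

```python
import torch
import torch.nn as nn
from torch.func import stack_module_state

# Three small, independently initialized linear layers (hypothetical ensemble).
layers = [nn.Linear(784, 128) for _ in range(3)]

# nn.Linear stores its weight as [out_features, in_features].
print(layers[0].weight.shape)  # torch.Size([128, 784])

lin_params, lin_buffers = stack_module_state(layers)

# Each stacked entry gains a leading dimension of size len(layers).
for name, p in lin_params.items():
    print(name, tuple(p.shape))
# weight (3, 128, 784)
# bias (3, 128)

# Slice i of a stacked tensor holds model i's values for that parameter.
assert torch.equal(lin_params['weight'][1], layers[1].weight)

print(len(lin_buffers))  # nn.Linear has no buffers, so this dict is empty: 0
```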
######################################################################
# Next, we need to define a function to ``vmap`` over. Given parameters,
# buffers, and inputs, the function should run the model using those
# parameters, buffers, and inputs. We'll use ``torch.func.functional_call``
# to help out:

from torch.func import functional_call
import copy

# Construct a "stateless" version of one of the models. It is "stateless" in
# the sense that the parameters are meta Tensors and do not have storage.
base_model = copy.deepcopy(models[0])
base_model = base_model.to('meta')

def fmodel(params, buffers, x):
    return functional_call(base_model, (params, buffers), (x,))

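To see why a meta-device copy is enough, here is a small sketch on a hypothetical ``nn.Linear(4, 2)`` rather than the tutorial's MLP. ``functional_call`` runs the module's ``forward`` with whatever tensors are passed in, so the template's own storage-less parameters are never read.

```python
import copy

import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)

model = nn.Linear(4, 2)
x = torch.randn(3, 4)

# "Stateless" template: its parameters live on the meta device (no storage).
template = copy.deepcopy(model).to('meta')
print(template.weight.is_meta)  # True

# Real values are supplied at call time, so the template only provides forward().
real_params = {name: p.detach() for name, p in model.named_parameters()}
out = functional_call(template, real_params, (x,))
assert torch.allclose(out, model(x))

# Passing different tensors changes the result without touching ``model``:
# a zero weight and a bias of ones make every output exactly 1.
ones = functional_call(template, {'weight': torch.zeros(2, 4), 'bias': torch.ones(2)}, (x,))
assert torch.equal(ones, torch.ones(3, 2))
```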
######################################################################
# Option 1: get predictions using a different minibatch for each model.
#
# By default, ``vmap`` maps a function across the first dimension of all inputs to
# the passed-in function. After using ``stack_module_state``, each of
# the ``params`` and ``buffers`` has an additional dimension of size ``num_models``
# at the front, and ``minibatches`` also has a leading dimension of size ``num_models``.

print([p.size(0) for p in params.values()])  # show the leading 'num_models' dimension

assert minibatches.shape == (num_models, 64, 1, 28, 28)  # verify minibatch has leading dimension of size 'num_models'

from torch import vmap

predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)

# verify the ``vmap`` predictions match the for-loop predictions
assert torch.allclose(predictions1_vmap, torch.stack(predictions_diff_minibatch_loop), atol=1e-3, rtol=1e-5)

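The same recipe can be checked end to end without a GPU. The sketch below is a scaled-down, CPU-only version of Option 1: the ``nn.Sequential`` models, the sizes, and all ``toy_*`` names are hypothetical stand-ins for ``SimpleMLP`` and the MNIST-shaped data. It stacks the states, builds a meta template, and compares the vmapped result with the for-loop.

```python
import copy

import torch
import torch.nn as nn
from torch import vmap
from torch.func import functional_call, stack_module_state

torch.manual_seed(0)

num_models, batch, in_features, out_features = 4, 8, 16, 3

# A small hypothetical ensemble with identical architectures.
toy_models = [
    nn.Sequential(nn.Linear(in_features, 32), nn.ReLU(), nn.Linear(32, out_features))
    for _ in range(num_models)
]
toy_params, toy_buffers = stack_module_state(toy_models)
toy_base = copy.deepcopy(toy_models[0]).to('meta')


def toy_fmodel(params, buffers, x):
    return functional_call(toy_base, (params, buffers), (x,))


# Option 1 layout: one minibatch per model, stacked along dimension 0.
toy_minibatches = torch.randn(num_models, batch, in_features)

loop = torch.stack([m(mb) for m, mb in zip(toy_models, toy_minibatches)])
# in_dims defaults to 0 for every argument, so model i sees minibatch i.
vmapped = vmap(toy_fmodel)(toy_params, toy_buffers, toy_minibatches)

print(vmapped.shape)  # torch.Size([4, 8, 3])
assert torch.allclose(vmapped, loop, atol=1e-4)
```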
######################################################################
# Option 2: get predictions using the same minibatch of data.
#
# ``vmap`` has an ``in_dims`` argument that specifies which dimension of each
# input to map over. Passing ``None`` for an input tells ``vmap`` not to map over
# it, so the same minibatch is used for all 10 models.

predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)

assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-3, rtol=1e-5)

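Here ``in_dims=(0, 0, None)`` means "index dimension 0 of ``params`` and ``buffers``, but hand the same ``minibatch`` to every call". A toy sketch with hypothetical names, using one weight matrix per model instead of a full MLP, shows that this matches giving the input an explicit model dimension with ``expand`` and mapping over dimension 0 of both inputs.

```python
import torch
from torch import vmap

torch.manual_seed(0)

num_models = 3
weights = torch.randn(num_models, 5, 4)  # one [5, 4] weight matrix per "model"
x = torch.randn(2, 4)                    # a single shared minibatch


def apply_one(w, x):
    # Runs one "model": sees w as [5, 4] and x as [2, 4], returns [2, 5].
    return x @ w.T


# in_dims=(0, None): map over dim 0 of ``weights``, pass ``x`` unchanged to every call.
shared = vmap(apply_one, in_dims=(0, None))(weights, x)

# Equivalent: expand x to [num_models, 2, 4] (a view, no copy) and map over both.
expanded = vmap(apply_one)(weights, x.expand(num_models, *x.shape))

print(shared.shape)  # torch.Size([3, 2, 5])
assert torch.allclose(shared, expanded)
```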
######################################################################
# A quick note: there are limitations around what types of functions can be
# transformed by ``vmap``. The best functions to transform are pure
# functions: functions whose outputs are determined only by their inputs
# and that have no side effects (e.g. mutation). ``vmap`` is unable to handle mutation
# of arbitrary Python data structures, but it is able to handle many in-place
# PyTorch operations.

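A small sketch of where that line falls, modeled on the in-place example in the ``torch.func`` notes on UX limitations (the function names here are hypothetical). An in-place op on a tensor that is itself batched works, while adding a batched value in place into a tensor that is not mapped over is rejected, because the result would have more elements than the tensor being written to.

```python
import torch
from torch import vmap


def scale_out_of_place(x):
    return x * 2


def scale_in_place(x):
    y = x.clone()  # y is batched like x, so mutating it in place is fine
    y.mul_(2)
    return y


def add_into_unbatched(x, y):
    x.add_(y)      # x is NOT mapped over (in_dims=None) but y is
    return x


xs = torch.randn(4, 3)
assert torch.equal(vmap(scale_out_of_place)(xs), vmap(scale_in_place)(xs))

rejected = False
try:
    vmap(add_into_unbatched, in_dims=(None, 0))(torch.zeros(3), xs)
except RuntimeError:
    rejected = True  # vmap refuses to write a batched value into an unbatched tensor
print("in-place write into unbatched tensor rejected:", rejected)
```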
######################################################################
# Performance
# -----------
# Curious about performance numbers? Here's how to measure them.

from torch.utils.benchmark import Timer
without_vmap = Timer(
    stmt="[model(minibatch) for model, minibatch in zip(models, minibatches)]",
    globals=globals())
with_vmap = Timer(
    stmt="vmap(fmodel)(params, buffers, minibatches)",
    globals=globals())
print(f'Predictions without vmap {without_vmap.timeit(100)}')
print(f'Predictions with vmap {with_vmap.timeit(100)}')

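Printing a ``Measurement`` gives a human-readable summary; to compare two runs numerically you can read its ``mean`` (or ``median``), which is in seconds per run. A CPU-only sketch, with a hypothetical batched matrix multiply standing in for the MLP ensemble. The actual ratio depends on hardware and tensor sizes, so no particular speedup is implied.

```python
import torch
from torch import vmap
from torch.utils.benchmark import Timer

torch.manual_seed(0)

ws = torch.randn(32, 64, 64)  # 32 hypothetical "models", one weight matrix each
xs = torch.randn(32, 16, 64)  # one minibatch per model


def f(w, x):
    return x @ w.T


env = {"f": f, "ws": ws, "xs": xs, "vmap": vmap}
loop_timer = Timer(stmt="[f(w, x) for w, x in zip(ws, xs)]", globals=env)
vmap_timer = Timer(stmt="vmap(f)(ws, xs)", globals=env)

# timeit(n) runs the statement n times and returns a Measurement.
loop_m = loop_timer.timeit(100)
vmap_m = vmap_timer.timeit(100)
print(f"loop: {loop_m.mean * 1e6:.1f} us, vmap: {vmap_m.mean * 1e6:.1f} us, "
      f"ratio: {loop_m.mean / vmap_m.mean:.2f}x")
```

``Timer`` runs single-threaded by default (``num_threads=1``), which keeps the two measurements comparable but may understate absolute throughput on a multi-core CPU.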
######################################################################
# There's a large speedup using ``vmap``!
#
# In general, vectorization with ``vmap`` should be faster than running a function
# in a for-loop and competitive with manual batching. There are some exceptions
# though, like if we haven’t implemented the ``vmap`` rule for a particular
# operation or if the underlying kernels weren’t optimized for older hardware
# (GPUs). If you see any of these cases, please let us know by opening an issue
# on GitHub.