GitHub Repository: pytorch/tutorials
Path: blob/main/intermediate_source/model_parallel_tutorial.py
# -*- coding: utf-8 -*-
"""
Single-Machine Model Parallel Best Practices
============================================
**Author**: `Shen Li <https://mrshenli.github.io/>`_

Model parallel is a widely used technique in distributed training.
Previous posts have explained how to use
`DataParallel <https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html>`_
to train a neural network on multiple GPUs; this feature replicates the
same model to all GPUs, where each GPU consumes a different partition of the
input data. Although it can significantly accelerate the training process, it
does not work for some use cases where the model is too large to fit into a
single GPU. This post shows how to solve that problem by using **model parallel**,
which, in contrast to ``DataParallel``, splits a single model onto different GPUs,
rather than replicating the entire model on each GPU (to be concrete, say a model
``m`` contains 10 layers: when using ``DataParallel``, each GPU holds a replica
of all 10 layers, whereas when using model parallel on two GPUs, each GPU could
host 5 layers).
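
For example, a minimal sketch of that 10-layer split might look like the
following (the layer shapes here are arbitrary and chosen only for
illustration; it assumes two visible GPUs):

.. code:: python

    import torch
    import torch.nn as nn

    # ten identically shaped layers, purely for illustration
    layers = [nn.Linear(10, 10) for _ in range(10)]

    first_half = nn.Sequential(*layers[:5]).to('cuda:0')   # layers 0-4 on GPU 0
    second_half = nn.Sequential(*layers[5:]).to('cuda:1')  # layers 5-9 on GPU 1

    x = torch.randn(4, 10, device='cuda:0')
    # move the intermediate activations across GPUs by hand
    out = second_half(first_half(x).to('cuda:1'))

``DataParallel``, by contrast, would place a full replica of all ten layers on
every GPU and split the input batch instead.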

The high-level idea of model parallel is to place different sub-networks of a
model onto different devices, and implement the ``forward`` method accordingly
to move intermediate outputs across devices. As only part of a model operates
on any individual device, a set of devices can collectively serve a larger
model. In this post, we will not try to construct huge models and squeeze them
into a limited number of GPUs. Instead, this post focuses on showing the idea
of model parallel. It is up to the readers to apply the ideas to real-world
applications.

.. note::

   For distributed model parallel training where a model spans multiple
   servers, please refer to
   `Getting Started With Distributed RPC Framework <rpc_tutorial.html>`__
   for examples and details.

Basic Usage
-----------
"""

######################################################################
# Let us start with a toy model that contains two linear layers. To run this
# model on two GPUs, simply put each linear layer on a different GPU, and move
# inputs and intermediate outputs to match the layer devices accordingly.
#

import torch
import torch.nn as nn
import torch.optim as optim


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = torch.nn.Linear(10, 10).to('cuda:0')
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5).to('cuda:1')

    def forward(self, x):
        x = self.relu(self.net1(x.to('cuda:0')))
        return self.net2(x.to('cuda:1'))

######################################################################
# Note that the above ``ToyModel`` looks very similar to how one would
# implement it on a single GPU, except for the four ``to(device)`` calls which
# place the linear layers and tensors on the proper devices. That is the only
# part of the model that requires changes. ``backward()`` and ``torch.optim``
# will automatically take care of gradients as if the model were on one GPU.
# You only need to make sure that the labels are on the same device as the
# outputs when calling the loss function.


model = ToyModel()
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)

optimizer.zero_grad()
outputs = model(torch.randn(20, 10))
labels = torch.randn(20, 5).to('cuda:1')
loss_fn(outputs, labels).backward()
optimizer.step()

######################################################################
# Apply Model Parallel to Existing Modules
# ----------------------------------------
#
# It is also possible to run an existing single-GPU module on multiple GPUs
# with just a few lines of changes. The code below shows how to decompose
# ``torchvision.models.resnet50()`` across two GPUs. The idea is to inherit
# from the existing ``ResNet`` module, and split the layers across two GPUs
# during construction. Then, override the ``forward`` method to stitch the two
# sub-networks together by moving the intermediate outputs accordingly.


from torchvision.models.resnet import ResNet, Bottleneck

num_classes = 1000


class ModelParallelResNet50(ResNet):
    def __init__(self, *args, **kwargs):
        super(ModelParallelResNet50, self).__init__(
            Bottleneck, [3, 4, 6, 3], num_classes=num_classes, *args, **kwargs)

        self.seq1 = nn.Sequential(
            self.conv1,
            self.bn1,
            self.relu,
            self.maxpool,

            self.layer1,
            self.layer2
        ).to('cuda:0')

        self.seq2 = nn.Sequential(
            self.layer3,
            self.layer4,
            self.avgpool,
        ).to('cuda:1')

        self.fc.to('cuda:1')

    def forward(self, x):
        x = self.seq2(self.seq1(x).to('cuda:1'))
        return self.fc(x.view(x.size(0), -1))


######################################################################
# The above implementation solves the problem for cases where the model is too
# large to fit into a single GPU. However, you might have already noticed that
# it will be slower than running on a single GPU if your model fits. This is
# because, at any point in time, only one of the two GPUs is working, while
# the other one sits idle. The performance further deteriorates as the
# intermediate outputs need to be copied from ``cuda:0`` to ``cuda:1`` between
# ``layer2`` and ``layer3``.
#
# Let us run an experiment to get a more quantitative view of the execution
# time. In this experiment, we train ``ModelParallelResNet50`` and the existing
# ``torchvision.models.resnet50()`` by running random inputs and labels through
# them. After the training, the models will not produce any useful predictions,
# but we can get a reasonable understanding of the execution times.


import torchvision.models as models

num_batches = 3
batch_size = 120
image_w = 128
image_h = 128


def train(model):
    model.train(True)
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001)

    one_hot_indices = torch.LongTensor(batch_size) \
                           .random_(0, num_classes) \
                           .view(batch_size, 1)

    for _ in range(num_batches):
        # generate random inputs and labels
        inputs = torch.randn(batch_size, 3, image_w, image_h)
        labels = torch.zeros(batch_size, num_classes) \
                      .scatter_(1, one_hot_indices, 1)

        # run forward pass
        optimizer.zero_grad()
        outputs = model(inputs.to('cuda:0'))

        # run backward pass
        labels = labels.to(outputs.device)
        loss_fn(outputs, labels).backward()
        optimizer.step()


######################################################################
# The ``train(model)`` function above uses ``nn.MSELoss`` as the loss function,
# and ``optim.SGD`` as the optimizer. It mimics training on ``128 x 128``
# images which are organized into 3 batches where each batch contains 120
# images. Then, we use ``timeit`` to run the ``train(model)`` function 10 times
# and plot the execution times with standard deviations.


import matplotlib.pyplot as plt
plt.switch_backend('Agg')
import numpy as np
import timeit

num_repeat = 10

stmt = "train(model)"

setup = "model = ModelParallelResNet50()"
mp_run_times = timeit.repeat(
    stmt, setup, number=1, repeat=num_repeat, globals=globals())
mp_mean, mp_std = np.mean(mp_run_times), np.std(mp_run_times)

setup = "import torchvision.models as models;" + \
        "model = models.resnet50(num_classes=num_classes).to('cuda:0')"
rn_run_times = timeit.repeat(
    stmt, setup, number=1, repeat=num_repeat, globals=globals())
rn_mean, rn_std = np.mean(rn_run_times), np.std(rn_run_times)


def plot(means, stds, labels, fig_name):
    fig, ax = plt.subplots()
    ax.bar(np.arange(len(means)), means, yerr=stds,
           align='center', alpha=0.5, ecolor='red', capsize=10, width=0.6)
    ax.set_ylabel('ResNet50 Execution Time (Second)')
    ax.set_xticks(np.arange(len(means)))
    ax.set_xticklabels(labels)
    ax.yaxis.grid(True)
    plt.tight_layout()
    plt.savefig(fig_name)
    plt.close(fig)


plot([mp_mean, rn_mean],
     [mp_std, rn_std],
     ['Model Parallel', 'Single GPU'],
     'mp_vs_rn.png')


######################################################################
#
# .. figure:: /_static/img/model-parallel-images/mp_vs_rn.png
#    :alt:
#
# The result shows that the execution time of the model parallel implementation
# is ``4.02/3.75-1=7%`` longer than the existing single-GPU implementation. So
# we can conclude there is roughly 7% overhead in copying tensors back and
# forth across the GPUs. There is room for improvement, as we know one of the
# two GPUs is sitting idle throughout the execution. One option is to further
# divide each batch into a pipeline of splits, such that when one split reaches
# the second sub-network, the following split can be fed into the first
# sub-network. In this way, two consecutive splits can run concurrently on two
# GPUs.

######################################################################
# Speed Up by Pipelining Inputs
# -----------------------------
#
# In the following experiments, we further divide each 120-image batch into
# 20-image splits. As PyTorch launches CUDA operations asynchronously, the
# implementation does not need to spawn multiple threads to achieve
# concurrency.


class PipelineParallelResNet50(ModelParallelResNet50):
    def __init__(self, split_size=20, *args, **kwargs):
        super(PipelineParallelResNet50, self).__init__(*args, **kwargs)
        self.split_size = split_size

    def forward(self, x):
        splits = iter(x.split(self.split_size, dim=0))
        s_next = next(splits)
        s_prev = self.seq1(s_next).to('cuda:1')
        ret = []

        for s_next in splits:
            # A. ``s_prev`` runs on ``cuda:1``
            s_prev = self.seq2(s_prev)
            ret.append(self.fc(s_prev.view(s_prev.size(0), -1)))

            # B. ``s_next`` runs on ``cuda:0``, which can run concurrently with A
            s_prev = self.seq1(s_next).to('cuda:1')

        s_prev = self.seq2(s_prev)
        ret.append(self.fc(s_prev.view(s_prev.size(0), -1)))

        return torch.cat(ret)


setup = "model = PipelineParallelResNet50()"
pp_run_times = timeit.repeat(
    stmt, setup, number=1, repeat=num_repeat, globals=globals())
pp_mean, pp_std = np.mean(pp_run_times), np.std(pp_run_times)

plot([mp_mean, rn_mean, pp_mean],
     [mp_std, rn_std, pp_std],
     ['Model Parallel', 'Single GPU', 'Pipelining Model Parallel'],
     'mp_vs_rn_vs_pp.png')

######################################################################
# Please note that device-to-device tensor copy operations are synchronized on
# the current streams of the source and the destination devices. If you create
# multiple streams, you have to make sure that copy operations are properly
# synchronized. Writing the source tensor or reading/writing the destination
# tensor before the copy operation finishes can lead to undefined behavior.
# The above implementation only uses default streams on both the source and
# destination devices, hence it is not necessary to enforce additional
# synchronization.
#
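# If you do use a side stream for the copy, a minimal sketch of a properly
# synchronized ``cuda:0`` to ``cuda:1`` copy could look like the following
# (illustration only; it assumes a tensor ``x`` that was just produced on
# ``cuda:0``'s current stream):
#
# .. code:: python
#
#     copy_stream = torch.cuda.Stream(device='cuda:0')
#     # the copy must not start before the producer of ``x`` has finished
#     copy_stream.wait_stream(torch.cuda.current_stream('cuda:0'))
#     with torch.cuda.stream(copy_stream):
#         y = x.to('cuda:1', non_blocking=True)
#     # consumers of ``y`` on cuda:1's current stream must wait for the copy
#     torch.cuda.current_stream('cuda:1').wait_stream(copy_stream)
#
# If ``x`` or ``y`` may be freed while the copy is still in flight, also look
# at ``Tensor.record_stream`` to keep their memory from being reused too early.
#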
# .. figure:: /_static/img/model-parallel-images/mp_vs_rn_vs_pp.png
#    :alt:
#
# The experiment result shows that pipelining inputs to model parallel
# ResNet50 speeds up the training process by roughly ``3.75/2.51-1=49%``. It is
# still quite far away from the ideal 100% speedup. As we have introduced a new
# parameter ``split_size`` in our pipeline parallel implementation, it is
# unclear how the new parameter affects the overall training time. Intuitively
# speaking, using a small ``split_size`` leads to many tiny CUDA kernel
# launches, while using a large ``split_size`` results in relatively long idle
# times during the first and last splits. Neither is optimal. There might be an
# optimal ``split_size`` configuration for this specific experiment. Let us try
# to find it by running experiments using several different ``split_size``
# values.


means = []
stds = []
split_sizes = [1, 3, 5, 8, 10, 12, 20, 40, 60]

for split_size in split_sizes:
    setup = "model = PipelineParallelResNet50(split_size=%d)" % split_size
    pp_run_times = timeit.repeat(
        stmt, setup, number=1, repeat=num_repeat, globals=globals())
    means.append(np.mean(pp_run_times))
    stds.append(np.std(pp_run_times))

fig, ax = plt.subplots()
ax.plot(split_sizes, means)
ax.errorbar(split_sizes, means, yerr=stds, ecolor='red', fmt='ro')
ax.set_ylabel('ResNet50 Execution Time (Second)')
ax.set_xlabel('Pipeline Split Size')
ax.set_xticks(split_sizes)
ax.yaxis.grid(True)
plt.tight_layout()
plt.savefig("split_size_tradeoff.png")
plt.close(fig)

######################################################################
#
# .. figure:: /_static/img/model-parallel-images/split_size_tradeoff.png
#    :alt:
#
# The result shows that setting ``split_size`` to 12 achieves the fastest
# training speed, which leads to a ``3.75/2.43-1=54%`` speedup. There are
# still opportunities to further accelerate the training process. For example,
# all operations on ``cuda:0`` are placed on its default stream. It means that
# computations on the next split cannot overlap with the copy operation of the
# ``prev`` split. However, as the ``prev`` and ``next`` splits are different
# tensors, there is no problem overlapping one's computation with the other
# one's copy. The implementation needs to use multiple streams on both GPUs,
# and different sub-network structures require different stream management
# strategies. As no general multi-stream solution works for all model parallel
# use cases, we will not discuss it in this tutorial.
#
# **Note:**
#
# This post shows several performance measurements. You might see different
# numbers when running the same code on your own machine, because the result
# depends on the underlying hardware and software. To get the best performance
# for your environment, a proper approach is to first generate the curve to
# figure out the best split size, and then use that split size to pipeline
# inputs, as in the sketch below.
#
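# For instance, reusing the ``split_sizes`` and ``means`` measured above, one
# could pick the best split size and build the pipelined model with it (an
# illustrative sketch, not part of the measurements above):
#
# .. code:: python
#
#     best_split_size = split_sizes[np.argmin(means)]
#     model = PipelineParallelResNet50(split_size=best_split_size)
#     train(model)
#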