GitHub Repository: pytorch/tutorials
Path: blob/main/recipes_source/recipes/amp_recipe.py
# -*- coding: utf-8 -*-
"""
Automatic Mixed Precision
*************************
**Author**: `Michael Carilli <https://github.com/mcarilli>`_

`torch.cuda.amp <https://pytorch.org/docs/stable/amp.html>`_ provides convenience methods for mixed precision,
where some operations use the ``torch.float32`` (``float``) datatype and other operations
use ``torch.float16`` (``half``). Some ops, like linear layers and convolutions,
are much faster in ``float16`` or ``bfloat16``. Other ops, like reductions, often require the dynamic
range of ``float32``. Mixed precision tries to match each op to its appropriate datatype,
which can reduce your network's runtime and memory footprint.

Ordinarily, "automatic mixed precision training" uses `torch.autocast <https://pytorch.org/docs/stable/amp.html#torch.autocast>`_ and
`torch.cuda.amp.GradScaler <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler>`_ together.

This recipe measures the performance of a simple network in default precision,
then walks through adding ``autocast`` and ``GradScaler`` to run the same network in
mixed precision with improved performance.

You may download and run this recipe as a standalone Python script.
The only requirements are PyTorch 2.0 or later (the script calls ``torch.set_default_device``) and a CUDA-capable GPU.

Mixed precision primarily benefits Tensor Core-enabled architectures (Volta, Turing, Ampere, and newer).
This recipe should show a significant (2-3X) speedup on those architectures.
On earlier architectures (Kepler, Maxwell, Pascal), you may observe a modest speedup.
Run ``nvidia-smi`` to identify your GPU model and, from it, its architecture.
"""

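##########################################################
# You can observe the range and precision limits described above without a GPU. The sketch below
# emulates ``float16`` in pure Python using the standard-library ``struct`` module, whose ``"e"`` format
# packs IEEE 754 half-precision values. ``to_half`` is a hypothetical helper written for this illustration,
# not a PyTorch API. The sketch shows three effects: values above 65504 overflow, tiny values underflow to zero,
# and a ``float16`` accumulation eventually stops growing. The last effect illustrates why ``autocast`` runs
# reductions such as ``sum`` in ``float32``.

```python
import math
import struct


def to_half(x: float) -> float:
    """Round x to the nearest IEEE 754 half-precision value.

    struct raises OverflowError for values too large for float16;
    this helper reports them as +/-inf instead.
    """
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


# Range: 65504 is the largest finite float16 value, so 1e5 overflows to inf.
# 1e-8 is less than half the smallest float16 subnormal (about 6e-8), so it rounds to 0.
print(to_half(65504.0), to_half(1e5), to_half(1e-8))

# Precision: adding 1.0 to a float16 accumulator 4096 times stalls at 2048,
# because 2049 is not representable and rounds back down to 2048.
acc16 = 0.0
for _ in range(4096):
    acc16 = to_half(acc16 + 1.0)
print(acc16, math.fsum([1.0] * 4096))
```

# The emulation rounds after every single addition, so treat it as a model of the number format,
# not of how any particular GPU kernel accumulates.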

import gc
import time

import torch

# Timing utilities
start_time = None

def start_timer():
    global start_time
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()  # resets peak stats, including max_memory_allocated()
    torch.cuda.synchronize()  # finish pending GPU work before starting the clock
    start_time = time.perf_counter()

def end_timer_and_print(local_msg):
    torch.cuda.synchronize()  # GPU work runs asynchronously; wait for it before stopping the clock
    end_time = time.perf_counter()
    print("\n" + local_msg)
    print("Total execution time = {:.3f} sec".format(end_time - start_time))
    print("Max memory used by tensors = {} bytes".format(torch.cuda.max_memory_allocated()))

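##########################################################
# CUDA ops launch asynchronously. The host clock therefore measures GPU work only after
# ``torch.cuda.synchronize()`` has returned, which is why both helpers above call it before reading the time.
# The same pattern can be packaged as a context manager. The sketch below is hypothetical: ``timed`` is part of
# neither the recipe nor PyTorch. Its ``sync`` callback defaults to a no-op so the sketch also runs without a GPU.
# With PyTorch, you could pass ``sync=torch.cuda.synchronize``.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label, sync=lambda: None, results=None):
    """Time the enclosed block, calling ``sync`` before each clock read."""
    sync()  # finish previously queued work so it isn't billed to this block
    start = time.perf_counter()
    try:
        yield
    finally:
        sync()  # wait for work queued inside the block to finish
        elapsed = time.perf_counter() - start
        if results is not None:
            results[label] = elapsed
        print(f"{label}: {elapsed:.3f} sec")


# Usage with a stand-in sync function that just records its calls.
sync_calls = []
timings = {}
with timed("demo", sync=lambda: sync_calls.append("sync"), results=timings):
    total = sum(range(100_000))
```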
##########################################################
# A simple network
# ----------------
# The following sequence of linear layers and ReLUs should show a speedup with mixed precision.

def make_model(in_size, out_size, num_layers):
    layers = []
    for _ in range(num_layers - 1):
        layers.append(torch.nn.Linear(in_size, in_size))
        layers.append(torch.nn.ReLU())
    layers.append(torch.nn.Linear(in_size, out_size))
    return torch.nn.Sequential(*layers).cuda()

##########################################################
# ``batch_size``, ``in_size``, ``out_size``, and ``num_layers`` are chosen to be large enough to saturate the GPU with work.
# Typically, mixed precision provides the greatest speedup when the GPU is saturated.
# Small networks may be CPU bound, meaning CPU-side overhead such as kernel launches limits their speed rather than
# GPU math. In that case, mixed precision won't improve performance.
# Sizes are also chosen such that linear layers' participating dimensions are multiples of 8,
# to permit Tensor Core usage on Tensor Core-capable GPUs (see :ref:`Troubleshooting<troubleshooting>` below).
#
# Exercise: Vary participating sizes and see how the mixed precision speedup changes.

batch_size = 512  # Try, for example, 128, 256, 513.
in_size = 4096
out_size = 4096
num_layers = 3
num_batches = 50
epochs = 3
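##########################################################
# A back-of-the-envelope calculation shows how much work these sizes create. The sketch below restates
# the values above. It counts parameters and forward-pass FLOPs for the layers ``make_model`` builds, using the
# standard estimate of about ``2 * batch * in * out`` FLOPs per linear layer and ignoring bias additions and ReLUs.

```python
batch_size, in_size, out_size, num_layers = 512, 4096, 4096, 3

# make_model builds (num_layers - 1) square in_size x in_size layers plus one
# in_size x out_size output layer; each nn.Linear has a weight matrix and a bias.
layer_shapes = [(in_size, in_size)] * (num_layers - 1) + [(in_size, out_size)]
params = sum(i * o + o for i, o in layer_shapes)

# A (batch x i) @ (i x o) matrix product costs about 2 * batch * i * o FLOPs.
fwd_flops = sum(2 * batch_size * i * o for i, o in layer_shapes)

print(f"parameters: {params:,}")
print(f"weights in float32: {params * 4 / 2**20:.1f} MiB, in float16: {params * 2 / 2**20:.1f} MiB")
print(f"forward FLOPs per batch: {fwd_flops / 1e9:.1f} GFLOP")
print("all sizes multiples of 8:", all(d % 8 == 0 for d in (batch_size, in_size, out_size)))
```

# With these values, the model has 50,343,936 parameters, and one forward pass over a batch costs roughly
# 51.5 GFLOP. Changing ``batch_size`` to 513, as the exercise suggests, barely changes the FLOP count but breaks
# the multiple-of-8 condition.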

device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_default_device(device)

# Creates data in default precision.
# The same data is used for both default and mixed precision trials below.
# You don't need to manually change inputs' ``dtype`` when enabling mixed precision.
data = [torch.randn(batch_size, in_size) for _ in range(num_batches)]
targets = [torch.randn(batch_size, out_size) for _ in range(num_batches)]

loss_fn = torch.nn.MSELoss().cuda()

##########################################################
# Default Precision
# -----------------
# Without ``torch.cuda.amp``, the following simple network executes all ops in default precision (``torch.float32``):

net = make_model(in_size, out_size, num_layers)
opt = torch.optim.SGD(net.parameters(), lr=0.001)

start_timer()
for epoch in range(epochs):
    for input, target in zip(data, targets):
        output = net(input)
        loss = loss_fn(output, target)
        loss.backward()
        opt.step()
        opt.zero_grad()  # set_to_none=True (the default since PyTorch 2.0) can modestly improve performance
end_timer_and_print("Default precision:")

##########################################################
# Adding ``torch.autocast``
# -------------------------
# Instances of `torch.autocast <https://pytorch.org/docs/stable/amp.html#autocasting>`_
# serve as context managers that allow regions of your script to run in mixed precision.
#
# In these regions, CUDA ops run in a ``dtype`` chosen by ``autocast``
# to improve performance while maintaining accuracy.
# See the `Autocast Op Reference <https://pytorch.org/docs/stable/amp.html#autocast-op-reference>`_
# for details on what precision ``autocast`` chooses for each op, and under what circumstances.

for epoch in range(0):  # 0 epochs, this section is for illustration only
    for input, target in zip(data, targets):
        # Runs the forward pass under ``autocast``.
        with torch.autocast(device_type=device, dtype=torch.float16):
            output = net(input)
            # output is float16 because linear layers ``autocast`` to float16.
            assert output.dtype is torch.float16

            loss = loss_fn(output, target)
            # loss is float32 because ``mse_loss`` autocasts to float32.
            assert loss.dtype is torch.float32

        # Exits ``autocast`` before backward().
        # Backward passes under ``autocast`` are not recommended.
        # Backward ops run in the same ``dtype`` ``autocast`` chose for corresponding forward ops.
        loss.backward()
        opt.step()
        opt.zero_grad()  # set_to_none=True (the default since PyTorch 2.0) can modestly improve performance
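##########################################################
# You can picture ``autocast``'s choice as a per-op lookup. The sketch below is a deliberately simplified,
# hypothetical model, not how PyTorch implements ``autocast``. It maps a few op names taken from the Autocast
# Op Reference to ``float16`` or ``float32``. It models every other op as running in the widest dtype among its
# inputs, which is what ordinary type promotion does for simple binary ops like ``add``. The real op lists are
# longer and device-specific.

```python
# Partial, illustrative subsets of the CUDA lists in the Autocast Op Reference.
FLOAT16_OPS = {"linear", "matmul", "bmm", "conv2d"}
FLOAT32_OPS = {"mse_loss", "softmax", "log_softmax", "sum"}

_BITS = {"float16": 16, "float32": 32}


def toy_autocast_dtype(op: str, input_dtypes: list) -> str:
    """Return the dtype this toy model predicts for ``op`` inside autocast."""
    if op in FLOAT16_OPS:
        return "float16"
    if op in FLOAT32_OPS:
        return "float32"
    # Unlisted ops are not autocast; model them as promoting to the widest input.
    return max(input_dtypes, key=_BITS.__getitem__)


# Mirrors the asserts in the loop above.
print(toy_autocast_dtype("linear", ["float32"]))               # float16
print(toy_autocast_dtype("mse_loss", ["float16", "float32"]))  # float32
print(toy_autocast_dtype("add", ["float16", "float32"]))       # float32
```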

##########################################################
# Adding ``GradScaler``
# ---------------------
# `Gradient scaling <https://pytorch.org/docs/stable/amp.html#gradient-scaling>`_
# helps prevent gradients with small magnitudes from flushing to zero
# ("underflowing") when training with mixed precision. It works in two steps. First, it multiplies the loss
# by a scale factor before ``backward()``, so every gradient is scaled by the same factor. Then it divides
# the gradients by that factor again before the optimizer updates the parameters.
#
# `torch.cuda.amp.GradScaler <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler>`_
# performs the steps of gradient scaling conveniently.
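#
# You can reproduce the effect of scaling with a pure-Python ``float16`` emulation. The helper ``to_half`` is
# hypothetical and built on the standard-library ``struct`` module. A gradient of ``1e-8`` flushes to zero
# in ``float16``. Multiplied by ``2**16``, the default ``init_scale`` of ``GradScaler``, it becomes representable.
# Dividing by the same factor in higher precision then recovers it to within ``float16`` rounding error.

```python
import math
import struct


def to_half(x: float) -> float:
    """Round x to IEEE 754 half precision (overflow becomes +/-inf)."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


grad = 1e-8        # a small "true" gradient value
scale = 2.0 ** 16  # matches GradScaler's default init_scale

unscaled = to_half(grad)        # 0.0: the gradient is lost
scaled = to_half(grad * scale)  # about 6.55e-4: comfortably in float16 range
recovered = scaled / scale      # unscale in higher precision (float64 here)

print(unscaled, recovered)
```

# Too large a scale causes the opposite problem: ``to_half(1e-1 * 2.0**20)`` is ``inf``.
# This is why ``GradScaler`` adjusts the scale as training proceeds.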

# Constructs a ``scaler`` once, at the beginning of the convergence run, using default arguments.
# If your network fails to converge with default ``GradScaler`` arguments, please file an issue.
# The same ``GradScaler`` instance should be used for the entire convergence run.
# If you perform multiple convergence runs in the same script, each run should use
# a dedicated fresh ``GradScaler`` instance. ``GradScaler`` instances are lightweight.
scaler = torch.cuda.amp.GradScaler()

for epoch in range(0):  # 0 epochs, this section is for illustration only
    for input, target in zip(data, targets):
        with torch.autocast(device_type=device, dtype=torch.float16):
            output = net(input)
            loss = loss_fn(output, target)

        # Scales loss. Calls ``backward()`` on scaled loss to create scaled gradients.
        scaler.scale(loss).backward()

        # ``scaler.step()`` first unscales the gradients of the optimizer's assigned parameters.
        # If these gradients contain no ``inf`` or ``NaN`` values, optimizer.step() is then called;
        # otherwise, optimizer.step() is skipped.
        scaler.step(opt)

        # Updates the scale for next iteration.
        scaler.update()

        opt.zero_grad()  # set_to_none=True (the default since PyTorch 2.0) can modestly improve performance
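##########################################################
# ``scaler.update()`` makes the scale dynamic. It shrinks the scale when the gradients contained ``inf`` or
# ``NaN`` values, and grows it again after a run of clean iterations. The class below is a hypothetical,
# simplified model of that policy, not PyTorch's implementation. Its defaults mirror ``GradScaler``'s
# documented defaults (``init_scale=2.**16``, ``growth_factor=2.0``, ``backoff_factor=0.5``,
# ``growth_interval=2000``). Its single ``update`` method combines two real calls: the skip decision made by
# ``scaler.step`` and the scale adjustment made by ``scaler.update``.

```python
class ToyDynamicScaler:
    """Hypothetical, simplified model of dynamic loss scaling (not PyTorch's code)."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._clean_steps = 0  # consecutive iterations without inf/NaN

    def update(self, found_inf: bool) -> bool:
        """Adjust the scale after one iteration; return True if the step would run."""
        if found_inf:
            # Overflow: skip the optimizer step and back off.
            self.scale *= self.backoff_factor
            self._clean_steps = 0
            return False
        self._clean_steps += 1
        if self._clean_steps == self.growth_interval:
            # Enough clean iterations in a row: try a larger scale.
            self.scale *= self.growth_factor
            self._clean_steps = 0
        return True


toy = ToyDynamicScaler(growth_interval=3)  # short interval so growth is visible
history = [(toy.update(found_inf), toy.scale)
           for found_inf in [True, False, False, False, True]]
print(history)
```

# In this run, the first overflow halves the scale. Three clean steps then double it back,
# and the final overflow halves it again.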

##########################################################
# All together: "Automatic Mixed Precision"
# ------------------------------------------
# (The following also demonstrates ``enabled``, an optional convenience argument to ``autocast`` and ``GradScaler``.
# If ``False``, ``autocast`` has no effect and ``scaler.scale(loss)`` returns ``loss`` unchanged.
# In addition, ``scaler.step(opt)`` simply calls ``opt.step()``, and ``scaler.update()`` does nothing.
# This allows switching between default precision and mixed precision without if/else statements.)

use_amp = True

net = make_model(in_size, out_size, num_layers)
opt = torch.optim.SGD(net.parameters(), lr=0.001)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

start_timer()
for epoch in range(epochs):
    for input, target in zip(data, targets):
        with torch.autocast(device_type=device, dtype=torch.float16, enabled=use_amp):
            output = net(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        opt.zero_grad()  # set_to_none=True (the default since PyTorch 2.0) can modestly improve performance
end_timer_and_print("Mixed precision:")
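##########################################################
# Because ``enabled=False`` turns the scaler's calls into pass-throughs, the training loop above needs no
# ``if use_amp`` branches. The stand-in class below is hypothetical and only models that idea. Its ``unscale``
# method does not exist on ``GradScaler``, which unscales through ``unscale_`` or ``step``. The scale is a
# power of two, so scaling and unscaling are exact in binary floating point. As a result, both settings hand
# the optimizer identical gradients.

```python
class PassThroughScaler:
    """Hypothetical stand-in that mimics the effect of ``enabled``."""

    def __init__(self, scale=2.0**16, enabled=True):
        self._scale = scale
        self.enabled = enabled

    def scale(self, value):
        # Disabled: return the value untouched, as GradScaler.scale does.
        return value * self._scale if self.enabled else value

    def unscale(self, grad):
        # Hypothetical helper: divide the scale back out when enabled.
        return grad / self._scale if self.enabled else grad


def gradients_seen_by_optimizer(use_amp, raw_grads=(0.5, -0.125, 3.0)):
    scaler = PassThroughScaler(enabled=use_amp)  # same loop code either way
    seen = []
    for g in raw_grads:
        scaled = scaler.scale(g)  # stands in for backward() on a scaled loss
        seen.append(scaler.unscale(scaled))
    return seen


print(gradients_seen_by_optimizer(True), gradients_seen_by_optimizer(False))
```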

##########################################################
# Inspecting/modifying gradients (e.g., clipping)
# --------------------------------------------------------
# All gradients produced by ``scaler.scale(loss).backward()`` are scaled. If you wish to modify or inspect
# the parameters' ``.grad`` attributes between ``backward()`` and ``scaler.step(optimizer)``, you should
# unscale them first using `scaler.unscale_(optimizer) <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.unscale_>`_.

for epoch in range(0):  # 0 epochs, this section is for illustration only
    for input, target in zip(data, targets):
        with torch.autocast(device_type=device, dtype=torch.float16):
            output = net(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()

        # Unscales the gradients of the optimizer's assigned parameters in-place.
        scaler.unscale_(opt)

        # Since the gradients of the optimizer's assigned parameters are now unscaled, clips as usual.
        # You may use the same value for max_norm here as you would without gradient scaling.
        torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=0.1)

        scaler.step(opt)
        scaler.update()
        opt.zero_grad()  # set_to_none=True (the default since PyTorch 2.0) can modestly improve performance
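##########################################################
# The pure-Python sketch below shows why the order matters. It clips a pair of gradients by their global norm
# twice: once after unscaling (the order used above) and once, incorrectly, before unscaling.
# ``clip_by_global_norm`` is a hypothetical, simplified stand-in that follows the same rule as ``clip_grad_norm_``:
# it multiplies every gradient by ``max_norm / (total_norm + eps)`` when that factor is below 1. The scale of
# 1024 is arbitrary.

```python
import math


def global_norm(grads):
    """Euclidean norm of all gradient values taken together."""
    return math.sqrt(sum(g * g for g in grads))


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Hypothetical sketch of norm-based clipping, following clip_grad_norm_'s rule."""
    coef = max_norm / (global_norm(grads) + eps)
    return [g * coef for g in grads] if coef < 1.0 else list(grads)


true_grads = [0.03, 0.04]  # global norm 0.05, already below max_norm
scale = 1024.0             # arbitrary loss scale for the illustration
scaled_grads = [g * scale for g in true_grads]

# Order used above: unscale first, then clip.
right = clip_by_global_norm([g / scale for g in scaled_grads], max_norm=0.1)

# Wrong order: clip the still-scaled gradients, then unscale.
wrong = [g / scale for g in clip_by_global_norm(scaled_grads, max_norm=0.1)]

print(global_norm(right), global_norm(wrong))
```

# The true norm, 0.05, is already below ``max_norm=0.1``, so the correct order leaves the gradients unchanged.
# Clipping the scaled gradients instead shrinks them to a norm of about ``0.1 / 1024``. This silently cuts the
# update by a factor of about 512.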

##########################################################
# Saving/Resuming
# ----------------
# To save/resume Amp-enabled runs with bitwise accuracy, use
# `scaler.state_dict <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.state_dict>`_ and
# `scaler.load_state_dict <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.load_state_dict>`_.
#
# When saving, save the ``scaler`` state dict alongside the usual model and optimizer state ``dicts``.
# Do this either at the beginning of an iteration before any forward passes, or at the end of
# an iteration after ``scaler.update()``.

checkpoint = {"model": net.state_dict(),
              "optimizer": opt.state_dict(),
              "scaler": scaler.state_dict()}
# Write checkpoint as desired, e.g.,
# torch.save(checkpoint, "filename")

##########################################################
# When resuming, load the ``scaler`` state dict alongside the model and optimizer state ``dicts``.
# Read checkpoint as desired, for example:
#
# .. code-block:: python
#
#    dev = torch.cuda.current_device()
#    checkpoint = torch.load("filename",
#                            map_location=lambda storage, loc: storage.cuda(dev))
#
net.load_state_dict(checkpoint["model"])
opt.load_state_dict(checkpoint["optimizer"])
scaler.load_state_dict(checkpoint["scaler"])

##########################################################
# If a checkpoint was created from a run *without* Amp, and you want to resume training *with* Amp,
# load model and optimizer states from the checkpoint as usual. The checkpoint won't contain a saved ``scaler`` state,
# so use a fresh instance of ``GradScaler``.
#
# If a checkpoint was created from a run *with* Amp and you want to resume training *without* Amp,
# load model and optimizer states from the checkpoint as usual, and ignore the saved ``scaler`` state.
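#
# You can fold the three resume cases into one small helper. The sketch below is hypothetical:
# ``restore_scaler`` is not part of PyTorch. Its ``make_scaler`` argument is any callable that accepts
# ``enabled=...`` and returns an object with ``load_state_dict``, such as ``torch.cuda.amp.GradScaler``.
# ``FakeScaler`` stands in for it so the example runs without PyTorch, and the ``"scale"`` entry in
# the fake saved state is illustrative only.

```python
def restore_scaler(checkpoint, make_scaler, use_amp=True):
    """Return a scaler ready for resumed training."""
    scaler = make_scaler(enabled=use_amp)
    if use_amp and "scaler" in checkpoint:
        # Amp -> Amp: continue from the saved scaler state.
        scaler.load_state_dict(checkpoint["scaler"])
    # No Amp -> Amp: there is no "scaler" entry, so the fresh scaler is kept.
    # Amp -> no Amp: the saved state is ignored and the scaler is disabled.
    return scaler


class FakeScaler:
    """Hypothetical stand-in for GradScaler that records what it was given."""

    def __init__(self, enabled=True):
        self.enabled = enabled
        self.loaded_state = None

    def load_state_dict(self, state):
        self.loaded_state = state


amp_ckpt = {"model": {}, "optimizer": {}, "scaler": {"scale": 1024.0}}
plain_ckpt = {"model": {}, "optimizer": {}}

resumed = restore_scaler(amp_ckpt, FakeScaler, use_amp=True)
fresh = restore_scaler(plain_ckpt, FakeScaler, use_amp=True)
disabled = restore_scaler(amp_ckpt, FakeScaler, use_amp=False)
```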

##########################################################
# Inference/Evaluation
# --------------------
# ``autocast`` may be used by itself to wrap inference or evaluation forward passes.
# ``GradScaler`` is not necessary, because no backward pass runs and there are no gradients to scale.

##########################################################
# .. _advanced-topics:
#
# Advanced topics
# ---------------
# See the `Automatic Mixed Precision Examples <https://pytorch.org/docs/stable/notes/amp_examples.html>`_ for advanced use cases including:
#
# * Gradient accumulation
# * Gradient penalty/double backward
# * Networks with multiple models, optimizers, or losses
# * Multiple GPUs (``torch.nn.DataParallel`` or ``torch.nn.parallel.DistributedDataParallel``)
# * Custom autograd functions (subclasses of ``torch.autograd.Function``)
#
# If you're registering a custom C++ op with the dispatcher, see the
# `autocast section <https://pytorch.org/tutorials/advanced/dispatcher.html#autocast>`_
# of the dispatcher tutorial.

##########################################################
# .. _troubleshooting:
#
# Troubleshooting
# ---------------
# Speedup with Amp is minor
# ~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. Your network may fail to saturate the GPU(s) with work and therefore be CPU bound. In that case,
#    Amp's effect on GPU performance won't matter.
#
#    * A rough rule of thumb to saturate the GPU is to increase batch and/or network size(s)
#      as much as you can without running out of memory (OOM).
#    * Try to avoid excessive CPU-GPU synchronization (``.item()`` calls, or printing values from CUDA tensors).
#    * Try to avoid sequences of many small CUDA ops (coalesce these into a few large CUDA ops if you can).
#
# 2. Your network may be GPU compute bound (lots of ``matmuls``/convolutions) but your GPU does not have Tensor Cores.
#    In this case a reduced speedup is expected.
# 3. The ``matmul`` dimensions are not Tensor Core-friendly. Make sure the sizes participating in ``matmuls`` are multiples of 8.
#    (For NLP models with encoders/decoders, this can be subtle. Also, convolutions used to have similar size constraints
#    for Tensor Core use, but for cuDNN versions 7.3 and later, no such constraints exist. See
#    `here <https://github.com/NVIDIA/apex/issues/221#issuecomment-478084841>`_ for guidance.)
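#
# Sometimes the data fixes a dimension, such as a vocabulary size. One common workaround is to pad that dimension
# up to the next multiple of 8 and ignore or mask the padded entries; how to mask them is model-specific.
# The hypothetical helper below sketches only the rounding step. In the example, 513 is the odd batch size from
# the exercise above and 50257 is an arbitrary odd size.

```python
def round_up_to_multiple(n: int, multiple: int = 8) -> int:
    """Return the smallest integer >= n that is divisible by ``multiple``."""
    return ((n + multiple - 1) // multiple) * multiple


for size in (513, 50257, 4096):
    print(size, "->", round_up_to_multiple(size))
```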
#
# Loss is inf/NaN
# ~~~~~~~~~~~~~~~
# First, check if your network fits an :ref:`advanced use case<advanced-topics>`.
# See also `Prefer binary_cross_entropy_with_logits over binary_cross_entropy <https://pytorch.org/docs/stable/amp.html#prefer-binary-cross-entropy-with-logits-over-binary-cross-entropy>`_.
#
# If you're confident your Amp usage is correct, you may need to file an issue. Before doing so, it's helpful to gather the following information:
#
# 1. Disable ``autocast`` or ``GradScaler`` individually (by passing ``enabled=False`` to their constructor) and see if ``infs``/``NaNs`` persist.
# 2. If you suspect part of your network (e.g., a complicated loss function) overflows, run that forward region in ``float32``
#    and see if ``infs``/``NaNs`` persist.
#    `The autocast docstring <https://pytorch.org/docs/stable/amp.html#torch.autocast>`_'s last code snippet
#    shows forcing a subregion to run in ``float32`` (by locally disabling ``autocast`` and casting the subregion's inputs).
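#
# You can rehearse step 2 in pure Python. The sketch below runs a hypothetical suspect region,
# ``log(sigmoid(logit))``, twice: once with its intermediate probability rounded to emulated ``float16`` (via the
# hypothetical ``to_half`` helper) and once in full precision. If the non-finite value appears only in the
# low-precision run, that region is a candidate to force into ``float32``. The sketch also hints at a general
# numerical reason to prefer loss functions that take logits directly, such as
# ``binary_cross_entropy_with_logits``, over computing a probability first.

```python
import math
import struct


def to_half(x: float) -> float:
    """Round x to IEEE 754 half precision (overflow becomes +/-inf)."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def log_prob_region(logit: float, low_precision: bool) -> float:
    """Hypothetical suspect region: log(sigmoid(logit)).

    In low-precision mode the intermediate probability is rounded to
    float16, emulating a region that runs in half precision.
    """
    p = 1.0 / (1.0 + math.exp(-logit))
    if low_precision:
        p = to_half(p)  # about 2e-9 underflows to 0 in float16
    return math.log(p) if p > 0.0 else -math.inf


print(log_prob_region(-20.0, low_precision=True))   # -inf
print(log_prob_region(-20.0, low_precision=False))  # about -20.0
```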
#
# Type mismatch error (may manifest as ``CUDNN_STATUS_BAD_PARAM``)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ``Autocast`` tries to cover all ops that benefit from or require casting.
# `Ops that receive explicit coverage <https://pytorch.org/docs/stable/amp.html#autocast-op-reference>`_
# are chosen based on numerical properties, but also on experience.
# If you see a type mismatch error in an ``autocast``-enabled forward region or a backward pass following that region,
# it's possible ``autocast`` missed an op.
#
# Please file an issue with the error backtrace. Run ``export TORCH_SHOW_CPP_STACKTRACES=1`` before running your script
# to provide fine-grained information on which backend op is failing.