# -*- coding: utf-8 -*-
"""
Automatic Mixed Precision
*************************
**Author**: `Michael Carilli <https://github.com/mcarilli>`_

`torch.cuda.amp <https://pytorch.org/docs/stable/amp.html>`_ provides convenience methods for mixed precision,
where some operations use the ``torch.float32`` (``float``) datatype and other operations
use ``torch.float16`` (``half``). Some ops, like linear layers and convolutions,
are much faster in ``float16`` or ``bfloat16``. Other ops, like reductions, often require the dynamic
range of ``float32``. Mixed precision tries to match each op to its appropriate datatype,
which can reduce your network's runtime and memory footprint.

Ordinarily, "automatic mixed precision training" uses `torch.autocast <https://pytorch.org/docs/stable/amp.html#torch.autocast>`_ and
`torch.cuda.amp.GradScaler <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler>`_ together.

This recipe measures the performance of a simple network in default precision,
then walks through adding ``autocast`` and ``GradScaler`` to run the same network in
mixed precision with improved performance.

You may download and run this recipe as a standalone Python script.
The only requirements are a CUDA-capable GPU and a recent version of PyTorch
(``torch.set_default_device``, used below, requires PyTorch 2.0 or later).

Mixed precision primarily benefits Tensor Core-enabled architectures (Volta, Turing, Ampere).
This recipe should show significant (2-3X) speedup on those architectures.
On earlier architectures (Kepler, Maxwell, Pascal), you may observe a modest speedup.
Run ``nvidia-smi`` to display your GPU's architecture.
"""

import torch, time, gc

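# The GPU's compute capability can also be checked from Python; this check is a small addition to
# the original recipe. Volta is (7, 0), Turing is (7, 5), and Ampere is (8, 0) or higher; anything
# below (7, 0) (Pascal, Maxwell, Kepler) lacks Tensor Cores.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(), torch.cuda.get_device_capability())
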
# Timing utilities
start_time = None

def start_timer():
    global start_time
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.synchronize()
    start_time = time.time()

def end_timer_and_print(local_msg):
    torch.cuda.synchronize()
    end_time = time.time()
    print("\n" + local_msg)
    print("Total execution time = {:.3f} sec".format(end_time - start_time))
    print("Max memory used by tensors = {} bytes".format(torch.cuda.max_memory_allocated()))

##########################################################
# A simple network
# ----------------
# The following sequence of linear layers and ReLUs should show a speedup with mixed precision.

def make_model(in_size, out_size, num_layers):
    layers = []
    for _ in range(num_layers - 1):
        layers.append(torch.nn.Linear(in_size, in_size))
        layers.append(torch.nn.ReLU())
    layers.append(torch.nn.Linear(in_size, out_size))
    return torch.nn.Sequential(*tuple(layers)).cuda()

##########################################################
# ``batch_size``, ``in_size``, ``out_size``, and ``num_layers`` are chosen to be large enough to saturate the GPU with work.
# Typically, mixed precision provides the greatest speedup when the GPU is saturated.
# Small networks may be CPU bound, in which case mixed precision won't improve performance.
# Sizes are also chosen such that linear layers' participating dimensions are multiples of 8,
# to permit Tensor Core usage on Tensor Core-capable GPUs (see :ref:`Troubleshooting<troubleshooting>` below).
#
# Exercise: Vary participating sizes and see how the mixed precision speedup changes.

batch_size = 512 # Try, for example, 128, 256, 513.
in_size = 4096
out_size = 4096
num_layers = 3
num_batches = 50
epochs = 3

device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_default_device(device)

# Creates data in default precision.
# The same data is used for both default and mixed precision trials below.
# You don't need to manually change inputs' ``dtype`` when enabling mixed precision.
data = [torch.randn(batch_size, in_size) for _ in range(num_batches)]
targets = [torch.randn(batch_size, out_size) for _ in range(num_batches)]

loss_fn = torch.nn.MSELoss().cuda()

##########################################################
# Default Precision
# -----------------
# Without ``torch.cuda.amp``, the following simple network executes all ops in default precision (``torch.float32``):

net = make_model(in_size, out_size, num_layers)
opt = torch.optim.SGD(net.parameters(), lr=0.001)

start_timer()
for epoch in range(epochs):
    for input, target in zip(data, targets):
        output = net(input)
        loss = loss_fn(output, target)
        loss.backward()
        opt.step()
        opt.zero_grad() # set_to_none=True here can modestly improve performance
end_timer_and_print("Default precision:")

##########################################################
# Adding ``torch.autocast``
# -------------------------
# Instances of `torch.autocast <https://pytorch.org/docs/stable/amp.html#autocasting>`_
# serve as context managers that allow regions of your script to run in mixed precision.
#
# In these regions, CUDA ops run in a ``dtype`` chosen by ``autocast``
# to improve performance while maintaining accuracy.
# See the `Autocast Op Reference <https://pytorch.org/docs/stable/amp.html#autocast-op-reference>`_
# for details on what precision ``autocast`` chooses for each op, and under what circumstances.

for epoch in range(0): # 0 epochs, this section is for illustration only
    for input, target in zip(data, targets):
        # Runs the forward pass under ``autocast``.
        with torch.autocast(device_type=device, dtype=torch.float16):
            output = net(input)
            # output is float16 because linear layers ``autocast`` to float16.
            assert output.dtype is torch.float16

            loss = loss_fn(output, target)
            # loss is float32 because ``mse_loss`` autocasts to float32.
            assert loss.dtype is torch.float32

        # Exits ``autocast`` before backward().
        # Backward passes under ``autocast`` are not recommended.
        # Backward ops run in the same ``dtype`` ``autocast`` chose for corresponding forward ops.
        loss.backward()
        opt.step()
        opt.zero_grad() # set_to_none=True here can modestly improve performance

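##########################################################
# The introduction noted that some ops are also fast in ``bfloat16``. As a brief sketch that is
# not part of the original recipe (and assuming a GPU for which ``torch.cuda.is_bf16_supported()``
# returns ``True``, such as Ampere), the same kind of region can run under ``autocast`` with
# ``dtype=torch.bfloat16``:
#
# .. code-block::
#
#    with torch.autocast(device_type=device, dtype=torch.bfloat16):
#        output = net(input)
#        loss = loss_fn(output, target)
#
# Because ``bfloat16`` has the same dynamic range as ``float32``, the gradient scaling introduced
# in the next section is typically unnecessary for it.
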
##########################################################
# Adding ``GradScaler``
# ---------------------
# `Gradient scaling <https://pytorch.org/docs/stable/amp.html#gradient-scaling>`_
# helps prevent gradients with small magnitudes from flushing to zero
# ("underflowing") when training with mixed precision.
#
# `torch.cuda.amp.GradScaler <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler>`_
# performs the steps of gradient scaling conveniently.

# Constructs a ``scaler`` once, at the beginning of the convergence run, using default arguments.
# If your network fails to converge with default ``GradScaler`` arguments, please file an issue.
# The same ``GradScaler`` instance should be used for the entire convergence run.
# If you perform multiple convergence runs in the same script, each run should use
# a dedicated fresh ``GradScaler`` instance. ``GradScaler`` instances are lightweight.
scaler = torch.cuda.amp.GradScaler()

for epoch in range(0): # 0 epochs, this section is for illustration only
    for input, target in zip(data, targets):
        with torch.autocast(device_type=device, dtype=torch.float16):
            output = net(input)
            loss = loss_fn(output, target)

        # Scales loss. Calls ``backward()`` on scaled loss to create scaled gradients.
        scaler.scale(loss).backward()

        # ``scaler.step()`` first unscales the gradients of the optimizer's assigned parameters.
        # If these gradients do not contain ``inf``s or ``NaN``s, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(opt)

        # Updates the scale for next iteration.
        scaler.update()

        opt.zero_grad() # set_to_none=True here can modestly improve performance

##########################################################
# All together: "Automatic Mixed Precision"
# ------------------------------------------
# (The following also demonstrates ``enabled``, an optional convenience argument to ``autocast`` and ``GradScaler``.
# If False, ``autocast`` and ``GradScaler``\ 's calls become no-ops.
# This allows switching between default precision and mixed precision without if/else statements.)

use_amp = True

net = make_model(in_size, out_size, num_layers)
opt = torch.optim.SGD(net.parameters(), lr=0.001)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

start_timer()
for epoch in range(epochs):
    for input, target in zip(data, targets):
        with torch.autocast(device_type=device, dtype=torch.float16, enabled=use_amp):
            output = net(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        opt.zero_grad() # set_to_none=True here can modestly improve performance
end_timer_and_print("Mixed precision:")

##########################################################
# Inspecting/modifying gradients (e.g., clipping)
# --------------------------------------------------------
# All gradients produced by ``scaler.scale(loss).backward()`` are scaled. If you wish to modify or inspect
# the parameters' ``.grad`` attributes between ``backward()`` and ``scaler.step(optimizer)``, you should
# unscale them first using `scaler.unscale_(optimizer) <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.unscale_>`_.

for epoch in range(0): # 0 epochs, this section is for illustration only
    for input, target in zip(data, targets):
        with torch.autocast(device_type=device, dtype=torch.float16):
            output = net(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()

        # Unscales the gradients of optimizer's assigned parameters in-place
        scaler.unscale_(opt)

        # Since the gradients of optimizer's assigned parameters are now unscaled, clips as usual.
        # You may use the same value for max_norm here as you would without gradient scaling.
        torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=0.1)

        scaler.step(opt)
        scaler.update()
        opt.zero_grad() # set_to_none=True here can modestly improve performance

##########################################################
# Saving/Resuming
# ----------------
# To save/resume Amp-enabled runs with bitwise accuracy, use
# `scaler.state_dict <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.state_dict>`_ and
# `scaler.load_state_dict <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.load_state_dict>`_.
#
# When saving, save the ``scaler`` state dict alongside the usual model and optimizer state ``dicts``.
# Do this either at the beginning of an iteration before any forward passes, or at the end of
# an iteration after ``scaler.update()``.

checkpoint = {"model": net.state_dict(),
              "optimizer": opt.state_dict(),
              "scaler": scaler.state_dict()}
# Write checkpoint as desired, e.g.,
# torch.save(checkpoint, "filename")

##########################################################
# When resuming, load the ``scaler`` state dict alongside the model and optimizer state ``dicts``.
# Read checkpoint as desired, for example:
#
# .. code-block::
#
#    dev = torch.cuda.current_device()
#    checkpoint = torch.load("filename",
#                            map_location = lambda storage, loc: storage.cuda(dev))
#
net.load_state_dict(checkpoint["model"])
opt.load_state_dict(checkpoint["optimizer"])
scaler.load_state_dict(checkpoint["scaler"])

##########################################################
# If a checkpoint was created from a run *without* Amp, and you want to resume training *with* Amp,
# load model and optimizer states from the checkpoint as usual. The checkpoint won't contain a saved ``scaler`` state, so
# use a fresh instance of ``GradScaler``.
#
# If a checkpoint was created from a run *with* Amp and you want to resume training *without* ``Amp``,
# load model and optimizer states from the checkpoint as usual, and ignore the saved ``scaler`` state.

##########################################################
# Inference/Evaluation
# --------------------
# ``autocast`` may be used by itself to wrap inference or evaluation forward passes. ``GradScaler`` is not necessary.
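#
# A minimal sketch, reusing the ``net`` and ``data`` defined above (``eval_output`` is just an
# illustrative name): wrap the forward pass in ``torch.no_grad()`` and ``autocast``; there is
# nothing to scale or unscale.

with torch.no_grad():
    with torch.autocast(device_type=device, dtype=torch.float16):
        eval_output = net(data[0])
        # As in the training examples, linear layers run in float16 under ``autocast``.
        assert eval_output.dtype is torch.float16
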
##########################################################
# .. _advanced-topics:
#
# Advanced topics
# ---------------
# See the `Automatic Mixed Precision Examples <https://pytorch.org/docs/stable/notes/amp_examples.html>`_ for advanced use cases including:
#
# * Gradient accumulation (sketched below)
# * Gradient penalty/double backward
# * Networks with multiple models, optimizers, or losses
# * Multiple GPUs (``torch.nn.DataParallel`` or ``torch.nn.parallel.DistributedDataParallel``)
# * Custom autograd functions (subclasses of ``torch.autograd.Function``)
#
# If you perform multiple convergence runs in the same script, each run should use
# a dedicated fresh ``GradScaler`` instance. ``GradScaler`` instances are lightweight.
#
# If you're registering a custom C++ op with the dispatcher, see the
# `autocast section <https://pytorch.org/tutorials/advanced/dispatcher.html#autocast>`_
# of the dispatcher tutorial.
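##########################################################
# The following is a condensed sketch of the first item above (gradient accumulation), adapted to
# this recipe's loop structure; ``iters_to_accumulate`` is an illustrative value, and the examples
# page linked above is the authoritative reference.

iters_to_accumulate = 4

for epoch in range(0): # 0 epochs, this section is for illustration only
    for i, (input, target) in enumerate(zip(data, targets)):
        with torch.autocast(device_type=device, dtype=torch.float16):
            output = net(input)
            # Divides the loss so the gradients summed over ``iters_to_accumulate`` iterations
            # match the gradient of one large batch.
            loss = loss_fn(output, target) / iters_to_accumulate
        scaler.scale(loss).backward()
        # Steps, updates the scale, and clears gradients only every ``iters_to_accumulate`` iterations.
        if (i + 1) % iters_to_accumulate == 0:
            scaler.step(opt)
            scaler.update()
            opt.zero_grad() # set_to_none=True here can modestly improve performance
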
##########################################################
# .. _troubleshooting:
#
# Troubleshooting
# ---------------
# Speedup with Amp is minor
# ~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. Your network may fail to saturate the GPU(s) with work, and is therefore CPU bound. In that case,
#    Amp's effect on GPU performance won't matter.
#
#    * A rough rule of thumb to saturate the GPU is to increase batch and/or network size(s)
#      as much as you can without running OOM.
#    * Try to avoid excessive CPU-GPU synchronization (``.item()`` calls, or printing values from CUDA tensors).
#    * Try to avoid sequences of many small CUDA ops (coalesce these into a few large CUDA ops if you can).
# 2. Your network may be GPU compute bound (lots of ``matmuls``/convolutions) but your GPU does not have Tensor Cores.
#    In this case a reduced speedup is expected.
# 3. The ``matmul`` dimensions are not Tensor Core-friendly. Make sure the participating ``matmul`` sizes are multiples of 8.
#    (For NLP models with encoders/decoders, this can be subtle. Also, convolutions used to have similar size constraints
#    for Tensor Core use, but for CuDNN versions 7.3 and later, no such constraints exist. See
#    `here <https://github.com/NVIDIA/apex/issues/221#issuecomment-478084841>`_ for guidance.)
#
# Loss is inf/NaN
# ~~~~~~~~~~~~~~~
# First, check if your network fits an :ref:`advanced use case<advanced-topics>`.
# See also `Prefer binary_cross_entropy_with_logits over binary_cross_entropy <https://pytorch.org/docs/stable/amp.html#prefer-binary-cross-entropy-with-logits-over-binary-cross-entropy>`_.
#
# If you're confident your Amp usage is correct, you may need to file an issue, but before doing so, it's helpful to gather the following information:
#
# 1. Disable ``autocast`` or ``GradScaler`` individually (by passing ``enabled=False`` to their constructor) and see if ``infs``/``NaNs`` persist.
# 2. If you suspect part of your network (e.g., a complicated loss function) overflows, run that forward region in ``float32``
#    and see if ``infs``/``NaN``s persist.
#    `The autocast docstring <https://pytorch.org/docs/stable/amp.html#torch.autocast>`_'s last code snippet
#    shows forcing a subregion to run in ``float32`` (by locally disabling ``autocast`` and casting the subregion's inputs).
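#
#    For example, a minimal sketch of running a suspect subregion in ``float32``
#    (``suspect_block`` is an illustrative name, not something defined in this recipe):
#
#    .. code-block::
#
#       with torch.autocast(device_type=device, dtype=torch.float16):
#           hidden = net(input)  # runs in mixed precision
#           with torch.autocast(device_type=device, enabled=False):
#               # Locally disables ``autocast`` and casts the subregion's inputs to float32.
#               loss = suspect_block(hidden.float(), target.float())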
#
# Type mismatch error (may manifest as ``CUDNN_STATUS_BAD_PARAM``)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ``Autocast`` tries to cover all ops that benefit from or require casting.
# `Ops that receive explicit coverage <https://pytorch.org/docs/stable/amp.html#autocast-op-reference>`_
# are chosen based on numerical properties, but also on experience.
# If you see a type mismatch error in an ``autocast``-enabled forward region or a backward pass following that region,
# it's possible ``autocast`` missed an op.
#
# Please file an issue with the error backtrace. Run ``export TORCH_SHOW_CPP_STACKTRACES=1`` before launching
# your script to get fine-grained information on which backend op is failing.