"""
Performance Tuning Guide
*************************
**Author**: `Szymon Migacz <https://github.com/szmigacz>`_

Performance Tuning Guide is a set of optimizations and best practices which can
accelerate training and inference of deep learning models in PyTorch. Presented
techniques often can be implemented by changing only a few lines of code and can
be applied to a wide range of deep learning models across all domains.

General optimizations
---------------------
"""

###############################################################################
# Enable asynchronous data loading and augmentation
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# `torch.utils.data.DataLoader <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`_
# supports asynchronous data loading and data augmentation in separate worker
# subprocesses. The default setting for ``DataLoader`` is ``num_workers=0``,
# which means that the data loading is synchronous and done in the main process.
# As a result the main training process has to wait for the data to be available
# to continue the execution.
#
# Setting ``num_workers > 0`` enables asynchronous data loading and overlap
# between the training and data loading. ``num_workers`` should be tuned
# depending on the workload, CPU, GPU, and location of training data.
#
# ``DataLoader`` accepts a ``pin_memory`` argument, which defaults to ``False``.
# When using a GPU it's better to set ``pin_memory=True``; this instructs
# ``DataLoader`` to use pinned memory and enables faster and asynchronous memory
# copy from the host to the GPU.

###############################################################################
# Disable gradient calculation for validation or inference
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# PyTorch saves intermediate buffers from all operations which involve tensors
# that require gradients.
# Typically gradients aren't needed for validation or
# inference.
# `torch.no_grad() <https://pytorch.org/docs/stable/generated/torch.no_grad.html#torch.no_grad>`_
# context manager can be applied to disable gradient calculation within a
# specified block of code; this accelerates execution and reduces the amount of
# required memory.
# `torch.no_grad() <https://pytorch.org/docs/stable/generated/torch.no_grad.html#torch.no_grad>`_
# can also be used as a function decorator.

###############################################################################
# Disable bias for convolutions directly followed by a batch norm
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# `torch.nn.Conv2d() <https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d>`_
# has a ``bias`` parameter which defaults to ``True`` (the same is true for
# `Conv1d <https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv1d>`_
# and
# `Conv3d <https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html#torch.nn.Conv3d>`_
# ).
#
# If a ``nn.Conv2d`` layer is directly followed by a ``nn.BatchNorm2d`` layer,
# then the bias in the convolution is not needed; instead use
# ``nn.Conv2d(..., bias=False, ....)``.
# Bias is not needed because in the first
# step ``BatchNorm`` subtracts the mean, which effectively cancels out the
# effect of bias.
#
# This is also applicable to 1d and 3d convolutions as long as ``BatchNorm`` (or
# other normalization layer) normalizes on the same dimension as convolution's
# bias.
#
# Models available from `torchvision <https://github.com/pytorch/vision>`_
# already implement this optimization.

###############################################################################
# Use parameter.grad = None instead of model.zero_grad() or optimizer.zero_grad()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Instead of calling:
model.zero_grad()
# or
optimizer.zero_grad()

###############################################################################
# to zero out gradients, use the following method instead:

for param in model.parameters():
    param.grad = None

###############################################################################
# The second code snippet does not zero the memory of each individual parameter.
# In addition, the subsequent backward pass uses assignment instead of addition
# to store gradients, which reduces the number of memory operations.
#
# Setting gradient to ``None`` has a slightly different numerical behavior than
# setting it to zero; for more details refer to the
# `documentation <https://pytorch.org/docs/master/optim.html#torch.optim.Optimizer.zero_grad>`_.
#
# Alternatively, starting from PyTorch 1.7, call
# ``model.zero_grad(set_to_none=True)`` or
# ``optimizer.zero_grad(set_to_none=True)``.

###############################################################################
# Fuse operations
# ~~~~~~~~~~~~~~~~~~~~~~~~~
# Pointwise operations such as elementwise addition, multiplication, and math
# functions like ``sin()``, ``cos()``, ``sigmoid()``, etc., can be combined into a
# single kernel.
# This fusion helps reduce memory access and kernel launch times.
# Typically, pointwise operations are memory-bound; PyTorch eager-mode initiates
# a separate kernel for each operation, which involves loading data from memory,
# executing the operation (often not the most time-consuming step), and writing
# the results back to memory.
#
# By using a fused operator, only one kernel is launched for multiple pointwise
# operations, and data is loaded and stored just once. This efficiency is
# particularly beneficial for activation functions, optimizers, and custom RNN cells.
#
# PyTorch 2 introduces a compile-mode facilitated by TorchInductor, an underlying compiler
# that automatically fuses kernels. TorchInductor extends its capabilities beyond simple
# element-wise operations, enabling advanced fusion of eligible pointwise and reduction
# operations for optimized performance.
#
# In the simplest case fusion can be enabled by applying the
# `torch.compile <https://pytorch.org/docs/stable/generated/torch.compile.html>`_
# decorator to the function definition, for example:

@torch.compile
def gelu(x):
    return x * 0.5 * (1.0 + torch.erf(x / 1.41421))

###############################################################################
# Refer to
# `Introduction to torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
# for more advanced use cases.

###############################################################################
# Enable channels_last memory format for computer vision models
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# PyTorch 1.5 introduced support for ``channels_last`` memory format for
# convolutional networks.
# This format is meant to be used in conjunction with
# `AMP <https://pytorch.org/docs/stable/amp.html>`_ to further accelerate
# convolutional neural networks with
# `Tensor Cores <https://www.nvidia.com/en-us/data-center/tensor-cores/>`_.
#
# Support for ``channels_last`` is experimental, but it's expected to work for
# standard computer vision models (e.g. ResNet-50, SSD). To convert models to
# ``channels_last`` format follow
# `Channels Last Memory Format Tutorial <https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html>`_.
# The tutorial includes a section on
# `converting existing models <https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html#converting-existing-models>`_.

###############################################################################
# Checkpoint intermediate buffers
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Buffer checkpointing is a technique to mitigate the memory capacity burden of
# model training. Instead of storing inputs of all layers to compute upstream
# gradients in backward propagation, it stores the inputs of a few layers and
# the others are recomputed during backward pass. The reduced memory
# requirements enable increasing the batch size, which can improve utilization.
#
# Checkpointing targets should be selected carefully. It is best not to store
# large layer outputs that have small re-computation cost. Example target
# layers are activation functions (e.g.
# ``ReLU``, ``Sigmoid``, ``Tanh``),
# up/down sampling and matrix-vector operations with small accumulation depth.
#
# PyTorch supports a native
# `torch.utils.checkpoint <https://pytorch.org/docs/stable/checkpoint.html>`_
# API to automatically perform checkpointing and recomputation.

###############################################################################
# Disable debugging APIs
# ~~~~~~~~~~~~~~~~~~~~~~
# Many PyTorch APIs are intended for debugging and should be disabled for
# regular training runs:
#
# * anomaly detection:
#   `torch.autograd.detect_anomaly <https://pytorch.org/docs/stable/autograd.html#torch.autograd.detect_anomaly>`_
#   or
#   `torch.autograd.set_detect_anomaly(True) <https://pytorch.org/docs/stable/autograd.html#torch.autograd.set_detect_anomaly>`_
# * profiler related:
#   `torch.autograd.profiler.emit_nvtx <https://pytorch.org/docs/stable/autograd.html#torch.autograd.profiler.emit_nvtx>`_,
#   `torch.autograd.profiler.profile <https://pytorch.org/docs/stable/autograd.html#torch.autograd.profiler.profile>`_
# * autograd ``gradcheck``:
#   `torch.autograd.gradcheck <https://pytorch.org/docs/stable/autograd.html#torch.autograd.gradcheck>`_
#   or
#   `torch.autograd.gradgradcheck <https://pytorch.org/docs/stable/autograd.html#torch.autograd.gradgradcheck>`_
#

###############################################################################
# CPU specific optimizations
# --------------------------

###############################################################################
# Utilize Non-Uniform Memory Access (NUMA) Controls
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# NUMA or non-uniform memory access is a memory layout design used in data center machines meant to take advantage of locality of memory in multi-socket machines with multiple memory controllers and blocks.
# Generally speaking, all deep learning workloads, training or inference, get better performance without accessing hardware resources across NUMA nodes. Thus, inference can be run with multiple instances, each instance running on one socket, to raise throughput. For training tasks on a single node, distributed training is recommended to make each training process run on one socket.
#
# In general cases the following command executes a PyTorch script on cores on the Nth node only, and avoids cross-socket memory access to reduce memory access overhead.
#
# .. code-block:: sh
#
#    numactl --cpunodebind=N --membind=N python <pytorch_script>

###############################################################################
# More detailed descriptions can be found `here <https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html>`_.

###############################################################################
# Utilize OpenMP
# ~~~~~~~~~~~~~~
# OpenMP is utilized to bring better performance for parallel computation tasks.
# ``OMP_NUM_THREADS`` is the easiest switch that can be used to accelerate computations. It determines the number of threads used for OpenMP computations.
# The CPU affinity setting controls how workloads are distributed over multiple cores. It affects communication overhead, cache line invalidation overhead, and page thrashing, so a proper CPU affinity setting brings performance benefits. ``GOMP_CPU_AFFINITY`` or ``KMP_AFFINITY`` determines how to bind OpenMP* threads to physical processing units. Detailed information can be found `here <https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html>`_.

###############################################################################
# With the following command, PyTorch runs the task on N OpenMP threads.
#
# .. code-block:: sh
#
#    export OMP_NUM_THREADS=N

###############################################################################
# Typically, the following environment variables are used to set CPU affinity with the GNU OpenMP implementation. ``OMP_PROC_BIND`` specifies whether threads may be moved between processors. Setting it to CLOSE keeps OpenMP threads close to the primary thread in contiguous place partitions. ``OMP_SCHEDULE`` determines how OpenMP threads are scheduled. ``GOMP_CPU_AFFINITY`` binds threads to specific CPUs.
# An important tuning parameter is core pinning, which prevents threads from migrating between multiple CPUs, enhancing data locality and minimizing inter-core communication.
#
# .. code-block:: sh
#
#    export OMP_SCHEDULE=STATIC
#    export OMP_PROC_BIND=CLOSE
#    export GOMP_CPU_AFFINITY="N-M"

###############################################################################
# Intel OpenMP Runtime Library (``libiomp``)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# By default, PyTorch uses GNU OpenMP (GNU ``libgomp``) for parallel computation. On Intel platforms, Intel OpenMP Runtime Library (``libiomp``) provides OpenMP API specification support. It sometimes brings more performance benefits compared to ``libgomp``. The environment variable ``LD_PRELOAD`` can be used to switch the OpenMP library to ``libiomp``:
#
# .. code-block:: sh
#
#    export LD_PRELOAD=<path>/libiomp5.so:$LD_PRELOAD

###############################################################################
# Similar to CPU affinity settings in GNU OpenMP, environment variables are provided in ``libiomp`` to control CPU affinity settings.
# ``KMP_AFFINITY`` binds OpenMP threads to physical processing units. ``KMP_BLOCKTIME`` sets the time, in milliseconds, that a thread should wait, after completing the execution of a parallel region, before sleeping.
# In most cases, setting ``KMP_BLOCKTIME`` to 1 or 0 yields good performance.
# The following commands show common settings with Intel OpenMP Runtime Library.
#
# .. code-block:: sh
#
#    export KMP_AFFINITY=granularity=fine,compact,1,0
#    export KMP_BLOCKTIME=1

###############################################################################
# Switch Memory allocator
# ~~~~~~~~~~~~~~~~~~~~~~~
# For deep learning workloads, ``Jemalloc`` or ``TCMalloc`` can achieve better performance than the default ``malloc`` function by reusing memory as much as possible. `Jemalloc <https://github.com/jemalloc/jemalloc>`_ is a general purpose ``malloc`` implementation that emphasizes fragmentation avoidance and scalable concurrency support. `TCMalloc <https://google.github.io/tcmalloc/overview.html>`_ also features a couple of optimizations to speed up program executions. One of them is holding memory in caches to speed up access of commonly-used objects. Holding such caches even after deallocation also helps avoid costly system calls if such memory is later re-allocated.
# Use environment variable ``LD_PRELOAD`` to take advantage of one of them.
#
# .. code-block:: sh
#
#    export LD_PRELOAD=<jemalloc.so/tcmalloc.so>:$LD_PRELOAD

###############################################################################
# Use oneDNN Graph with TorchScript for inference
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# oneDNN Graph can significantly boost inference performance.
# It fuses some compute-intensive operations, such as convolution and matmul, with their neighbor operations.
# In PyTorch 2.0, it is supported as a beta feature for ``Float32`` & ``BFloat16`` data-types.
# oneDNN Graph receives the model’s graph and identifies candidates for operator-fusion with respect to the shape of the example input.
# A model should be JIT-traced using an example input.
# Speed-up would then be observed after a couple of warm-up iterations for inputs with the same shape as the example input.
# The example code-snippets below are for resnet50, but they can very well be extended to use oneDNN Graph with custom models as well.

# Only this extra line of code is required to use oneDNN Graph
torch.jit.enable_onednn_fusion(True)

###############################################################################
# Using the oneDNN Graph API requires just one extra line of code for inference with Float32.
# If you are using oneDNN Graph, please avoid calling ``torch.jit.optimize_for_inference``.

# sample input should be of the same shape as expected inputs
sample_input = [torch.rand(32, 3, 224, 224)]
# Using resnet50 from torchvision in this example for illustrative purposes,
# but the line below can indeed be modified to use custom models as well.
model = getattr(torchvision.models, "resnet50")().eval()
# Tracing the model with example input
traced_model = torch.jit.trace(model, sample_input)
# Invoking torch.jit.freeze
traced_model = torch.jit.freeze(traced_model)

###############################################################################
# Once a model is JIT-traced with a sample input, it can then be used for inference after a couple of warm-up runs.

with torch.no_grad():
    # a couple of warm-up runs
    traced_model(*sample_input)
    traced_model(*sample_input)
    # speedup would be observed after warm-up runs
    traced_model(*sample_input)

###############################################################################
# While the JIT fuser for oneDNN Graph also supports inference with ``BFloat16`` datatype,
# performance benefit with oneDNN Graph is only exhibited by machines with AVX512_BF16
# instruction set architecture (ISA).
# The following code snippet serves as an example of using ``BFloat16`` datatype for inference with oneDNN Graph:

# AMP for JIT mode is enabled by default, and is divergent with its eager mode counterpart
torch._C._jit_set_autocast_mode(False)

# ``example_input`` is assumed here to be a tensor of the expected input shape,
# e.g. ``sample_input[0]`` from the snippet above
with torch.no_grad(), torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16):
    # Conv-BatchNorm folding for CNN-based Vision Models should be done with ``torch.fx.experimental.optimization.fuse`` when AMP is used
    import torch.fx.experimental.optimization as optimization
    # Please note that optimization.fuse need not be called when AMP is not used
    model = optimization.fuse(model)
    model = torch.jit.trace(model, (example_input))
    model = torch.jit.freeze(model)
    # a couple of warm-up runs
    model(example_input)
    model(example_input)
    # speedup would be observed in subsequent runs.
    model(example_input)


###############################################################################
# Train a model on CPU with PyTorch ``DistributedDataParallel``(DDP) functionality
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# For small scale models or memory-bound models, such as DLRM, training on CPU is also a good choice. On a machine with multiple sockets, distributed training makes highly efficient use of hardware resources to accelerate the training process.
# `Torch-ccl <https://github.com/intel/torch-ccl>`_, optimized with Intel(R) ``oneCCL`` (a collective communications library for efficient distributed deep learning training that implements collectives such as ``allreduce``, ``allgather``, and ``alltoall``), implements the PyTorch C10D ``ProcessGroup`` API and can be dynamically loaded as an external ``ProcessGroup``. Building on the optimizations implemented in the PyTorch DDP module, ``torch-ccl`` accelerates communication operations. Besides the optimizations made to communication kernels, ``torch-ccl`` also features simultaneous computation-communication functionality.

###############################################################################
# GPU specific optimizations
# --------------------------

###############################################################################
# Enable Tensor cores
# ~~~~~~~~~~~~~~~~~~~~~~~
# Tensor cores are specialized hardware designed to compute matrix-matrix multiplication
# operations, primarily utilized in deep learning and AI workloads. Tensor cores have
# specific precision requirements which can be adjusted manually or via the Automatic
# Mixed Precision API.
#
# In particular, tensor operations take advantage of lower precision workloads,
# which can be controlled via ``torch.set_float32_matmul_precision``.
# The default format is set to ``'highest'``, which utilizes the tensor data type.
# However, PyTorch offers alternative precision settings: ``'high'`` and ``'medium'``.
# These options prioritize computational speed over numerical precision.

###############################################################################
# Use CUDA Graphs
# ~~~~~~~~~~~~~~~~~~~~~~~
# When using a GPU, work must first be launched from the CPU, and
# in some cases the context switch between CPU and GPU can lead to bad resource
# utilization.
# CUDA graphs are a way to keep computation within the GPU without
# paying the extra cost of kernel launches and host synchronization.

# It can be enabled using (``mode`` is a keyword-only argument of ``torch.compile``)
torch.compile(m, mode="reduce-overhead")
# or
torch.compile(m, mode="max-autotune")

###############################################################################
# Support for CUDA graphs is in development; their usage can incur increased
# device memory consumption, and some models might not compile.

###############################################################################
# Enable cuDNN auto-tuner
# ~~~~~~~~~~~~~~~~~~~~~~~
# `NVIDIA cuDNN <https://developer.nvidia.com/cudnn>`_ supports many algorithms
# to compute a convolution. Autotuner runs a short benchmark and selects the
# kernel with the best performance on a given hardware for a given input size.
#
# For convolutional networks (other types currently not supported), enable cuDNN
# autotuner before launching the training loop by setting:

torch.backends.cudnn.benchmark = True
###############################################################################
#
# * the auto-tuner decisions may be non-deterministic; a different algorithm may
#   be selected for different runs.
#   For more details see
#   `PyTorch: Reproducibility <https://pytorch.org/docs/stable/notes/randomness.html?highlight=determinism>`_
# * in some rare cases, such as with highly variable input sizes, it's better
#   to run convolutional networks with autotuner disabled to avoid the overhead
#   associated with algorithm selection for each input size.
#

###############################################################################
# Avoid unnecessary CPU-GPU synchronization
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Avoid unnecessary synchronizations, to let the CPU run ahead of the
# accelerator as much as possible to make sure that the accelerator work queue
# contains many operations.
#
# When possible, avoid operations which require synchronizations, for example:
#
# * ``print(cuda_tensor)``
# * ``cuda_tensor.item()``
# * memory copies: ``tensor.cuda()``, ``cuda_tensor.cpu()`` and equivalent
#   ``tensor.to(device)`` calls
# * ``cuda_tensor.nonzero()``
# * python control flow which depends on results of operations performed on CUDA
#   tensors e.g.
#   ``if (cuda_tensor != 0).all()``
#

###############################################################################
# Create tensors directly on the target device
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Instead of calling ``torch.rand(size).cuda()`` to generate a random tensor,
# produce the output directly on the target device:
# ``torch.rand(size, device='cuda')``.
#
# This is applicable to all functions which create new tensors and accept
# ``device`` argument:
# `torch.rand() <https://pytorch.org/docs/stable/generated/torch.rand.html#torch.rand>`_,
# `torch.zeros() <https://pytorch.org/docs/stable/generated/torch.zeros.html#torch.zeros>`_,
# `torch.full() <https://pytorch.org/docs/stable/generated/torch.full.html#torch.full>`_
# and similar.

###############################################################################
# Use mixed precision and AMP
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Mixed precision leverages
# `Tensor Cores <https://www.nvidia.com/en-us/data-center/tensor-cores/>`_
# and offers up to 3x overall speedup on Volta and newer GPU architectures. To
# use Tensor Cores AMP should be enabled and matrix/tensor dimensions should
# satisfy requirements for calling kernels that use Tensor Cores.
#
# To use Tensor Cores:
#
# * set sizes to multiples of 8 (to map onto dimensions of Tensor Cores)
#
#   * see
#     `Deep Learning Performance Documentation
#     <https://docs.nvidia.com/deeplearning/performance/index.html#optimizing-performance>`_
#     for more details and guidelines specific to layer type
#   * if layer size is derived from other parameters rather than fixed, it can
#     still be explicitly padded e.g.
#     vocabulary size in NLP models
#
# * enable AMP
#
#   * Introduction to Mixed Precision Training and AMP:
#     `video <https://www.youtube.com/watch?v=jF4-_ZK_tyc&feature=youtu.be>`_,
#     `slides <https://nvlabs.github.io/eccv2020-mixed-precision-tutorial/files/dusan_stosic-training-neural-networks-with-tensor-cores.pdf>`_
#   * native PyTorch AMP is available starting from PyTorch 1.6:
#     `documentation <https://pytorch.org/docs/stable/amp.html>`_,
#     `examples <https://pytorch.org/docs/stable/notes/amp_examples.html#amp-examples>`_,
#     `tutorial <https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html>`_
#
#

###############################################################################
# Preallocate memory in case of variable input length
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Models for speech recognition or for NLP are often trained on input tensors
# with variable sequence length. Variable length can be problematic for PyTorch
# caching allocator and can lead to reduced performance or to unexpected
# out-of-memory errors. If a batch with a short sequence length is followed by
# another batch with longer sequence length, then PyTorch is forced to
# release intermediate buffers from previous iteration and to re-allocate new
# buffers. This process is time consuming and causes fragmentation in the
# caching allocator which may result in out-of-memory errors.
#
# A typical solution is to implement preallocation. It consists of the
# following steps:
#
# #. generate a (usually random) batch of inputs with maximum sequence length
#    (either corresponding to max length in the training dataset or to some
#    predefined threshold)
# #. execute a forward and a backward pass with the generated batch, do not
#    execute an optimizer or a learning rate scheduler, this step preallocates
#    buffers of maximum size, which can be reused in subsequent
#    training iterations
# #. zero out gradients
# #. proceed to regular training
#

###############################################################################
# Distributed optimizations
# -------------------------

###############################################################################
# Use efficient data-parallel backend
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# PyTorch has two ways to implement data-parallel training:
#
# * `torch.nn.DataParallel <https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html#torch.nn.DataParallel>`_
# * `torch.nn.parallel.DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel>`_
#
# ``DistributedDataParallel`` offers much better performance and scaling to
# multiple-GPUs. For more information refer to the
# `relevant section of CUDA Best Practices <https://pytorch.org/docs/stable/notes/cuda.html#use-nn-parallel-distributeddataparallel-instead-of-multiprocessing-or-nn-dataparallel>`_
# from PyTorch documentation.

###############################################################################
# Skip unnecessary all-reduce if training with ``DistributedDataParallel`` and gradient accumulation
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# By default
# `torch.nn.parallel.DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel>`_
# executes gradient all-reduce after every backward pass to compute the average
# gradient over all workers participating in the training.
# If training uses
# gradient accumulation over N steps, then all-reduce is not necessary after
# every training step; it's only required to perform all-reduce after the last
# call to backward, just before the execution of the optimizer.
#
# ``DistributedDataParallel`` provides the
# `no_sync() <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel.no_sync>`_
# context manager which disables gradient all-reduce for a particular iteration.
# ``no_sync()`` should be applied to the first ``N-1`` iterations of gradient
# accumulation; the last iteration should follow the default execution and
# perform the required gradient all-reduce.

###############################################################################
# Match the order of layers in constructors and during the execution if using ``DistributedDataParallel(find_unused_parameters=True)``
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# `torch.nn.parallel.DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel>`_
# with ``find_unused_parameters=True`` uses the order of layers and parameters
# from model constructors to build buckets for ``DistributedDataParallel``
# gradient all-reduce. ``DistributedDataParallel`` overlaps all-reduce with the
# backward pass. All-reduce for a particular bucket is asynchronously triggered
# only when all gradients for parameters in a given bucket are available.
#
# To maximize the amount of overlap, the order in model constructors should
# roughly match the order during the execution.
# If the order doesn't match, then
# all-reduce for the entire bucket waits for the gradient which is the last to
# arrive, this may reduce the overlap between backward pass and all-reduce,
# all-reduce may end up being exposed, which slows down the training.
#
# ``DistributedDataParallel`` with ``find_unused_parameters=False`` (which is
# the default setting) relies on automatic bucket formation based on order of
# operations encountered during the backward pass. With
# ``find_unused_parameters=False`` it's not necessary to reorder layers or
# parameters to achieve optimal performance.

###############################################################################
# Load-balance workload in a distributed setting
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Load imbalance typically may happen for models processing sequential data
# (speech recognition, translation, language models etc.). If one device
# receives a batch of data with sequence length longer than sequence lengths for
# the remaining devices, then all devices wait for the worker which finishes
# last. Backward pass functions as an implicit synchronization point in a
# distributed setting with
# `DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel>`_
# backend.
#
# There are multiple ways to solve the load balancing problem. The core idea is
# to distribute workload over all workers as uniformly as possible within each
# global batch. For example Transformer solves imbalance by forming batches with
# approximately constant number of tokens (and variable number of sequences in a
# batch), other models solve imbalance by bucketing samples with similar
# sequence length or even by sorting dataset by sequence length.
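
###############################################################################
# The token-based batching idea described above can be sketched in plain Python.
# This is a minimal illustration under stated assumptions, not code taken from
# Transformer or from any PyTorch API: ``batch_by_token_budget``, ``lengths``
# and ``max_tokens`` are hypothetical names, and the cost of a batch is assumed
# to be its padded size (number of sequences times the longest sequence in it).

```python
from typing import List, Sequence


def batch_by_token_budget(lengths: Sequence[int], max_tokens: int) -> List[List[int]]:
    """Group sample indices into batches with a bounded padded token count.

    Hypothetical helper for illustration only. ``lengths[i]`` is the sequence
    length of sample ``i``; each sample is assumed to fit within ``max_tokens``.
    """
    # Sort indices by length so that each batch holds similar-length samples.
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches: List[List[int]] = []
    current: List[int] = []
    for idx in order:
        # ``order`` is ascending, so the new sample would be the longest one
        # in the batch and therefore determines the padded length.
        padded_cost = (len(current) + 1) * lengths[idx]
        if current and padded_cost > max_tokens:
            batches.append(current)
            current = []
        current.append(idx)
    if current:
        batches.append(current)
    return batches


# Made-up sequence lengths mixing short and long samples.
lengths = [5, 50, 7, 48, 6, 52, 8, 49]
batches = batch_by_token_budget(lengths, max_tokens=100)
for batch in batches:
    print(batch, [lengths[i] for i in batch])
```

###############################################################################
# Short sequences end up grouped into larger batches and long sequences into
# smaller ones, so the padded token count per batch stays under the budget and
# workers that each receive one batch per step do more comparable amounts of
# work. A real input pipeline would likely also shuffle the order of the
# resulting batches between epochs.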