# -*- coding: utf-8 -*-
"""
A guide on good usage of ``non_blocking`` and ``pin_memory()`` in PyTorch
=========================================================================

**Author**: `Vincent Moens <https://github.com/vmoens>`_

Introduction
------------

Transferring data from the CPU to the GPU is fundamental in many PyTorch applications.
It's crucial for users to understand the most effective tools and options available for moving data between devices.
This tutorial examines two key methods for host-to-device (CPU to GPU) data transfer in PyTorch:
:meth:`~torch.Tensor.pin_memory` and :meth:`~torch.Tensor.to` with the ``non_blocking=True`` option.

What you will learn
~~~~~~~~~~~~~~~~~~~

Optimizing the transfer of tensors from the CPU to the GPU can be achieved through asynchronous transfers and memory
pinning. However, there are important considerations:

- Using ``tensor.pin_memory().to(device, non_blocking=True)`` can be up to twice as slow as a straightforward ``tensor.to(device)``.
- Generally, ``tensor.to(device, non_blocking=True)`` is an effective choice for enhancing transfer speed.
- While ``cpu_tensor.to("cuda", non_blocking=True).mean()`` executes correctly, attempting
  ``cuda_tensor.to("cpu", non_blocking=True).mean()`` without synchronizing first can produce erroneous outputs.

Preamble
~~~~~~~~

The performance figures reported in this tutorial depend on the system used to build the tutorial.
Although the conclusions are applicable across different systems, the specific observations may vary slightly
depending on the hardware available, especially on older hardware.
The primary objective of this tutorial is to offer a theoretical framework for understanding CPU to GPU data transfers.
However, any design decisions should be tailored to individual cases and guided by benchmarked throughput measurements,
as well as the specific requirements of the task at hand.

"""

import torch

assert torch.cuda.is_available(), "A cuda device is required to run this tutorial"


######################################################################
#
# This tutorial requires tensordict to be installed. If you don't have tensordict in your environment yet, install it
# by running the following command in a separate cell:
#
# .. code-block:: bash
#
#    # Install tensordict with the following command
#    !pip3 install tensordict
#
# We start by outlining the theory surrounding these concepts, and then move to concrete test examples of the features.
#
#
# Background
# ----------
#
# .. _pinned_memory_background:
#
# Memory management basics
# ~~~~~~~~~~~~~~~~~~~~~~~~
#
# .. _pinned_memory_memory:
#
# When one creates a CPU tensor in PyTorch, the content of this tensor needs to be placed
# in memory. The memory we talk about here is a rather complex concept worth looking at carefully.
# We distinguish two types of memory that are handled by the Memory Management Unit: the RAM (for simplicity)
# and the swap space on disk (which may or may not be the hard drive). Together, the RAM (physical memory) and the
# swap space on disk make up the virtual memory, which is an abstraction of the total resources available.
# In short, the virtual memory makes the available space larger than what can be found in RAM alone
# and creates the illusion that the main memory is larger than it actually is.
#
# In normal circumstances, a regular CPU tensor is pageable, which means that it is divided into blocks called pages that
# can live anywhere in the virtual memory (either in RAM or on disk).
# As mentioned earlier, this has the advantage that
# the memory seems larger than what the main memory actually is.
#
# Typically, when a program accesses a page that is not in RAM, a "page fault" occurs and the operating system (OS) then brings
# this page back into RAM ("swap in" or "page in").
# In turn, the OS may have to swap out (or "page out") another page to make room for the new page.
#
# In contrast to pageable memory, pinned (or page-locked, or non-pageable) memory is a type of memory that cannot
# be swapped out to disk.
# It allows for faster and more predictable access times, but has the downside that it is more limited than
# pageable memory: it must reside in the main memory (RAM), which is smaller than the virtual memory.
#
# .. figure:: /_static/img/pinmem/pinmem.png
#    :alt:
#
# CUDA and (non-)pageable memory
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# .. _pinned_memory_cuda_pageable_memory:
#
# To understand how CUDA copies a tensor from CPU to CUDA, let's consider the two scenarios above:
#
# - If the memory is page-locked, the device can access the memory directly in the main memory. The memory addresses are well
#   defined and functions that need to read these data can be significantly accelerated.
# - If the memory is pageable, all the pages will have to be brought to the main memory before being sent to the GPU.
#   This operation may take time and is less predictable than when executed on page-locked tensors.
#
# More precisely, when CUDA sends pageable data from CPU to GPU, it must first create a page-locked copy of that data
# before making the transfer.
#
# Asynchronous vs. Synchronous Operations with ``non_blocking=True`` (CUDA ``cudaMemcpyAsync``)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# .. _pinned_memory_async_sync:
#
# When executing a copy from a host (e.g., CPU) to a device (e.g., GPU), the CUDA toolkit offers modalities to do these
# operations synchronously or asynchronously with respect to the host.
#
# In practice, when calling :meth:`~torch.Tensor.to`, PyTorch always makes a call to
# `cudaMemcpyAsync <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY.html#group__CUDART__MEMORY_1g85073372f776b4c4d5f89f7124b7bf79>`_.
# If ``non_blocking=False`` (default), a ``cudaStreamSynchronize`` will be called after each and every ``cudaMemcpyAsync``, making
# the call to :meth:`~torch.Tensor.to` blocking in the main thread.
# If ``non_blocking=True``, no synchronization is triggered, and the main thread on the host is not blocked.
# Therefore, from the host perspective, multiple tensors can be sent to the device simultaneously,
# as the thread does not need to wait for one transfer to be completed to initiate the next one.
#
# .. note:: In general, the transfer is blocking on the device side (even if it isn't on the host side):
#    the copy on the device cannot occur while another operation is being executed.
#    However, in some advanced scenarios, a copy and a kernel execution can be done simultaneously on the GPU side.
#    As the following example will show, three requirements must be met to enable this:
#
#    1. The device must have at least one free DMA (Direct Memory Access) engine. Modern GPU architectures such as Volta,
#       Tesla, or H100 devices have more than one DMA engine.
#
#    2. The transfer must be done on a separate, non-default CUDA stream. In PyTorch, CUDA streams can be handled using
#       :class:`~torch.cuda.Stream`.
#
#    3. The source data must be in pinned memory.
#
# We demonstrate this by running profiles on the following script.
#

import contextlib

from torch.cuda import Stream


s = Stream()

torch.manual_seed(42)
t1_cpu_pinned = torch.randn(1024**2 * 5, pin_memory=True)
t2_cpu_paged = torch.randn(1024**2 * 5, pin_memory=False)
t3_cuda = torch.randn(1024**2 * 5, device="cuda:0")

assert torch.cuda.is_available()
device = torch.device("cuda", torch.cuda.current_device())


# The function we want to profile
def inner(pinned: bool, streamed: bool):
    with torch.cuda.stream(s) if streamed else contextlib.nullcontext():
        if pinned:
            t1_cuda = t1_cpu_pinned.to(device, non_blocking=True)
        else:
            t2_cuda = t2_cpu_paged.to(device, non_blocking=True)
        # Record on the stream that issued the copy (``s`` if streamed, the default stream otherwise)
        t_star_cuda_h2d_event = torch.cuda.current_stream().record_event()
    # This operation can be executed during the CPU to GPU copy if and only if the tensor is pinned and the copy is
    # done in the other stream
    t3_cuda_mul = t3_cuda * t3_cuda * t3_cuda
    t3_cuda_h2d_event = torch.cuda.current_stream().record_event()
    t_star_cuda_h2d_event.synchronize()
    t3_cuda_h2d_event.synchronize()


# Our profiler: profiles the `inner` function and stores the results in a .json file
def benchmark_with_profiler(
    pinned,
    streamed,
) -> None:
    torch._C._profiler._set_cuda_sync_enabled_val(True)
    wait, warmup, active = 1, 1, 2
    num_steps = wait + warmup + active
    rank = 0
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        schedule=torch.profiler.schedule(
            wait=wait, warmup=warmup, active=active, repeat=1, skip_first=1
        ),
    ) as prof:
        for step_idx in range(1, num_steps + 1):
            inner(streamed=streamed, pinned=pinned)
            if rank is None or rank == 0:
                prof.step()
    prof.export_chrome_trace(f"trace_streamed{int(streamed)}_pinned{int(pinned)}.json")

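######################################################################
# Before looking at the traces, a back-of-the-envelope model may help interpret them. The sketch below is an
# illustration under simplifying assumptions, not a measurement and not a PyTorch API: the helper
# ``expected_runtime_ms`` is hypothetical, and the 2 ms copy and 3 ms kernel durations are made up. It encodes the
# claim from the note above: a host-to-device copy and an independent kernel only overlap when the source is pinned
# and the copy runs on a separate stream (a free DMA engine is assumed); otherwise their durations add up.
#

```python
def expected_runtime_ms(copy_ms: float, kernel_ms: float, pinned: bool, streamed: bool) -> float:
    """Idealized device time for one host-to-device copy plus one independent kernel.

    Simplifying assumptions (for illustration only): a free DMA engine is always
    available, launch overheads are ignored, and the extra staging copy needed
    for pageable memory is folded into ``copy_ms``.
    """
    if pinned and streamed:
        # Pinned source + non-default stream: the copy engine and the compute
        # units work concurrently, so the slower of the two dominates.
        return max(copy_ms, kernel_ms)
    # Any other combination: the copy and the kernel run one after the other.
    return copy_ms + kernel_ms


# Print the four scenarios profiled in this section, using made-up durations.
for pinned in (False, True):
    for streamed in (False, True):
        runtime = expected_runtime_ms(2.0, 3.0, pinned=pinned, streamed=streamed)
        print(f"pinned={pinned}, streamed={streamed}: ~{runtime:.1f} ms")
```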
######################################################################
# Loading these profile traces in Chrome (``chrome://tracing``) shows the following results. First, let's see
# what happens if the arithmetic operation on ``t3_cuda`` is executed after the pageable tensor is sent to GPU
# in the main stream:
#

benchmark_with_profiler(streamed=False, pinned=False)

######################################################################
# .. figure:: /_static/img/pinmem/trace_streamed0_pinned0.png
#    :alt:
#
# Using a pinned tensor doesn't change the trace much; both operations are still executed consecutively:

benchmark_with_profiler(streamed=False, pinned=True)

######################################################################
#
# .. figure:: /_static/img/pinmem/trace_streamed0_pinned1.png
#    :alt:
#
# Sending a pageable tensor to GPU on a separate stream is also a blocking operation:

benchmark_with_profiler(streamed=True, pinned=False)

######################################################################
#
# .. figure:: /_static/img/pinmem/trace_streamed1_pinned0.png
#    :alt:
#
# Only copies of pinned tensors to GPU on a separate stream overlap with another CUDA kernel executed on
# the main stream:

benchmark_with_profiler(streamed=True, pinned=True)

######################################################################
#
# .. figure:: /_static/img/pinmem/trace_streamed1_pinned1.png
#    :alt:
#
# A PyTorch perspective
# ---------------------
#
# .. _pinned_memory_pt_perspective:
#
# ``pin_memory()``
# ~~~~~~~~~~~~~~~~
#
# .. _pinned_memory_pinned:
#
# PyTorch offers the possibility to create and send tensors to page-locked memory through the
# :meth:`~torch.Tensor.pin_memory` method and constructor arguments.
# CPU tensors on a machine where CUDA is initialized can be cast to pinned memory through the :meth:`~torch.Tensor.pin_memory`
# method.
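#
# As a minimal sketch of the two routes just described, the snippet below pins an existing tensor with
# :meth:`~torch.Tensor.pin_memory`, allocates a new one with the ``pin_memory=True`` constructor argument, and checks
# both with :meth:`~torch.Tensor.is_pinned`. The helper ``pinning_demo`` is hypothetical, not a PyTorch API.
# Pinning requires an accelerator, so on a CPU-only machine the helper simply reports that nothing was pinned; that
# fallback is an assumption made so the sketch runs anywhere.
#

```python
import torch


def pinning_demo(numel: int = 1024) -> dict:
    """Hypothetical helper: report which tensors end up in page-locked memory."""
    pageable = torch.randn(numel)  # regular (pageable) CPU tensor
    if not torch.cuda.is_available():
        # pin_memory() needs an accelerator; on CPU-only machines nothing is pinned.
        return {"pageable": False, "pinned_copy": None, "pinned_new": None}
    # Route 1: copy an existing pageable tensor into page-locked memory.
    # This call blocks the host until the copy into pinned memory is done.
    pinned_copy = pageable.pin_memory()
    # Route 2: allocate the tensor directly in page-locked memory.
    pinned_new = torch.zeros(numel, pin_memory=True)
    return {
        # pin_memory() returns a new tensor, so the original stays pageable.
        "pageable": pageable.is_pinned(),
        "pinned_copy": pinned_copy.is_pinned(),
        "pinned_new": pinned_new.is_pinned(),
    }


print(pinning_demo())
```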
######################################################################
#
# Importantly, ``pin_memory`` is blocking on the main thread of the host: it will wait for the tensor to be copied to
# page-locked memory before executing the next operation.
# New tensors can be directly created in pinned memory with functions like :func:`~torch.zeros`, :func:`~torch.ones` and other
# constructors.
#
# Let us check the speed of pinning memory and sending tensors to CUDA:


import torch
import gc
from torch.utils.benchmark import Timer
import matplotlib.pyplot as plt


def timer(cmd):
    # Median runtime of ``cmd`` in milliseconds
    median = (
        Timer(cmd, globals=globals())
        .adaptive_autorange(min_run_time=1.0, max_run_time=20.0)
        .median
        * 1000
    )
    print(f"{cmd}: {median: 4.4f} ms")
    return median


# A tensor in pageable memory
pageable_tensor = torch.randn(1_000_000)

# A tensor in page-locked (pinned) memory
pinned_tensor = torch.randn(1_000_000, pin_memory=True)

# Runtimes:
pageable_to_device = timer("pageable_tensor.to('cuda:0')")
pinned_to_device = timer("pinned_tensor.to('cuda:0')")
pin_mem = timer("pageable_tensor.pin_memory()")
pin_mem_to_device = timer("pageable_tensor.pin_memory().to('cuda:0')")

# Ratios:
r1 = pinned_to_device / pageable_to_device
r2 = pin_mem_to_device / pageable_to_device

# Create a figure with the results
fig, ax = plt.subplots()

xlabels = [0, 1, 2]
bar_labels = [
    "pageable_tensor.to(device) (1x)",
    f"pinned_tensor.to(device) ({r1:4.2f}x)",
    f"pageable_tensor.pin_memory().to(device) ({r2:4.2f}x)"
    f"\npin_memory()={100*pin_mem/pin_mem_to_device:.2f}% of runtime.",
]
values = [pageable_to_device, pinned_to_device, pin_mem_to_device]
colors = ["tab:blue", "tab:red", "tab:orange"]
ax.bar(xlabels, values, label=bar_labels, color=colors)

ax.set_ylabel("Runtime (ms)")
ax.set_title("Device casting runtime (pin-memory)")
ax.set_xticks([])
ax.legend()

plt.show()

# Clear tensors
del pageable_tensor, pinned_tensor
_ = gc.collect()

######################################################################
#
# We can observe that casting a pinned-memory tensor to GPU is indeed much faster than a pageable tensor, because under
# the hood, a pageable tensor must be copied to pinned memory before being sent to GPU.
#
# However, contrary to a somewhat common belief, calling :meth:`~torch.Tensor.pin_memory()` on a pageable tensor before
# casting it to GPU should not bring any significant speed-up; on the contrary, this call is usually slower than just
# executing the transfer. This makes sense, since we're actually asking Python to execute an operation that CUDA will
# perform anyway before copying the data from host to device.
#
# .. note:: The PyTorch implementation of
#    `pin_memory <https://github.com/pytorch/pytorch/blob/5298acb5c76855bc5a99ae10016efc86b27949bd/aten/src/ATen/native/Memory.cpp#L58>`_,
#    which relies on creating a brand new storage in pinned memory through `cudaHostAlloc <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY.html#group__CUDART__MEMORY_1gb65da58f444e7230d3322b6126bb4902>`_,
#    could be, in rare cases, faster than transferring data in chunks as ``cudaMemcpy`` does.
#    Here too, the observation may vary depending on the available hardware, the size of the tensors being sent or
#    the amount of available RAM.
#
# ``non_blocking=True``
# ~~~~~~~~~~~~~~~~~~~~~
#
# .. _pinned_memory_non_blocking:
#
# As mentioned earlier, many PyTorch operations have the option of being executed asynchronously with respect to the host
# through the ``non_blocking`` argument.
#
# Here, to accurately account for the benefits of using ``non_blocking``, we will design a slightly more complex
# experiment, since we want to assess how fast it is to send multiple tensors to GPU with and without
# ``non_blocking``.
#


# A simple loop that copies all tensors to cuda
def copy_to_device(*tensors):
    result = []
    for tensor in tensors:
        result.append(tensor.to("cuda:0"))
    return result


# A loop that copies all tensors to cuda asynchronously
def copy_to_device_nonblocking(*tensors):
    result = []
    for tensor in tensors:
        result.append(tensor.to("cuda:0", non_blocking=True))
    # We need to synchronize
    torch.cuda.synchronize()
    return result


# Create a list of tensors
tensors = [torch.randn(1000) for _ in range(1000)]
to_device = timer("copy_to_device(*tensors)")
to_device_nonblocking = timer("copy_to_device_nonblocking(*tensors)")

# Ratio
r1 = to_device_nonblocking / to_device

# Plot the results
fig, ax = plt.subplots()

xlabels = [0, 1]
bar_labels = ["to(device) (1x)", f"to(device, non_blocking=True) ({r1:4.2f}x)"]
colors = ["tab:blue", "tab:red"]
values = [to_device, to_device_nonblocking]

ax.bar(xlabels, values, label=bar_labels, color=colors)

ax.set_ylabel("Runtime (ms)")
ax.set_title("Device casting runtime (non-blocking)")
ax.set_xticks([])
ax.legend()

plt.show()


######################################################################
# To get a better sense of what is happening here, let us profile these two functions:


from torch.profiler import profile, ProfilerActivity


def profile_mem(cmd):
    # Profile the CPU-side calls made by ``cmd`` and print the 10 most expensive operators
    with profile(activities=[ProfilerActivity.CPU]) as prof:
        exec(cmd)
    print(cmd)
    print(prof.key_averages().table(row_limit=10))


######################################################################
# Let's see the call stack with a regular ``to(device)`` first:
#

print("Call to `to(device)`")
profile_mem("copy_to_device(*tensors)")

######################################################################
# and now the ``non_blocking`` version:
#

print("Call to `to(device, non_blocking=True)`")
profile_mem("copy_to_device_nonblocking(*tensors)")


######################################################################
# The results are clearly better when using ``non_blocking=True``, as all transfers are initiated simultaneously
# on the host side and only one synchronization is done.
#
# The benefit will vary depending on the number and the size of the tensors as well as depending on the hardware being
# used.
#
# .. note:: Interestingly, the blocking ``to("cuda")`` actually performs the same asynchronous device casting operation
#    (``cudaMemcpyAsync``) as the one with ``non_blocking=True``, with a synchronization point after each copy.
#
# Synergies
# ~~~~~~~~~
#
# .. _pinned_memory_synergies:
#
# Now that we have made the point that data transfer of tensors already in pinned memory to GPU is faster than from
# pageable memory, and that we know that doing these transfers asynchronously is also faster than synchronously, we can
# benchmark combinations of these approaches.
# First, let's write a couple of new functions that will call ``pin_memory`` and ``to(device)`` on each tensor:
#


def pin_copy_to_device(*tensors):
    result = []
    for tensor in tensors:
        result.append(tensor.pin_memory().to("cuda:0"))
    return result


def pin_copy_to_device_nonblocking(*tensors):
    result = []
    for tensor in tensors:
        result.append(tensor.pin_memory().to("cuda:0", non_blocking=True))
    # We need to synchronize
    torch.cuda.synchronize()
    return result


######################################################################
# The benefits of using :meth:`~torch.Tensor.pin_memory` are more pronounced for
# somewhat large batches of large tensors:
#

tensors = [torch.randn(1_000_000) for _ in range(1000)]
page_copy = timer("copy_to_device(*tensors)")
page_copy_nb = timer("copy_to_device_nonblocking(*tensors)")

tensors_pinned = [torch.randn(1_000_000, pin_memory=True) for _ in range(1000)]
pinned_copy = timer("copy_to_device(*tensors_pinned)")
pinned_copy_nb = timer("copy_to_device_nonblocking(*tensors_pinned)")

pin_and_copy = timer("pin_copy_to_device(*tensors)")
pin_and_copy_nb = timer("pin_copy_to_device_nonblocking(*tensors)")

# Plot
strategies = ("pageable copy", "pinned copy", "pin and copy")
blocking = {
    "blocking": [page_copy, pinned_copy, pin_and_copy],
    "non-blocking": [page_copy_nb, pinned_copy_nb, pin_and_copy_nb],
}

x = torch.arange(3)
width = 0.25
multiplier = 0


fig, ax = plt.subplots(layout="constrained")

for attribute, runtimes in blocking.items():
    offset = width * multiplier
    rects = ax.bar(x + offset, runtimes, width, label=attribute)
    ax.bar_label(rects, padding=3, fmt="%.2f")
    multiplier += 1

# Add some text for labels, title and custom x-axis tick labels, etc.
ax.set_ylabel("Runtime (ms)")
ax.set_title("Runtime (pin-mem and non-blocking)")
ax.set_xticks([0, 1, 2])
ax.set_xticklabels(strategies)
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
ax.legend(loc="upper left", ncols=3)

plt.show()

del tensors, tensors_pinned
_ = gc.collect()


######################################################################
# Other copy directions (GPU -> CPU, CPU -> MPS)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# .. _pinned_memory_other_direction:
#
# Until now, we have operated under the assumption that asynchronous copies from the CPU to the GPU are safe.
# This is generally true because CUDA automatically handles synchronization to ensure that the data being accessed is
# valid at read time: operations issued to the same CUDA stream execute in order, so a kernel that reads the copied
# data cannot start before the copy has completed.
# However, this guarantee does not extend to transfers in the opposite direction, from GPU to CPU.
# Without explicit synchronization, these transfers offer no assurance that the copy will be complete at the time of
# data access. Consequently, the data on the host might be incomplete or incorrect, effectively rendering it garbage:
#


tensor = (
    torch.arange(1, 1_000_000, dtype=torch.double, device="cuda")
    .expand(100, 999999)
    .clone()
)
# Sanity check on the device: the mean of 1, ..., 999999 is exactly 500000
torch.testing.assert_close(
    tensor.mean(), torch.tensor(500_000, dtype=torch.double, device="cuda")
)
try:
    i = -1
    for i in range(100):
        cpu_tensor = tensor.to("cpu", non_blocking=True)
        # No synchronization: the host may read ``cpu_tensor`` before the copy has finished
        torch.testing.assert_close(
            cpu_tensor.mean(), torch.tensor(500_000, dtype=torch.double)
        )
    print("No test failed with non_blocking")
except AssertionError:
    print(f"{i}th test failed with non_blocking. Skipping remaining tests")
try:
    i = -1
    for i in range(100):
        cpu_tensor = tensor.to("cpu", non_blocking=True)
        # Wait for the device-to-host copy to complete before reading the data
        torch.cuda.synchronize()
        torch.testing.assert_close(
            cpu_tensor.mean(), torch.tensor(500_000, dtype=torch.double)
        )
    print("No test failed with synchronize")
except AssertionError:
    print(f"One test failed with synchronize: {i}th assertion!")


######################################################################
# The same considerations apply to copies from the CPU to non-CUDA devices, such as MPS.
# Generally, asynchronous copies to a device are safe without explicit synchronization only when the target is a
# CUDA-enabled device.
#
# In summary, copying data from CPU to GPU is safe when using ``non_blocking=True``, but for any other direction,
# ``non_blocking=True`` can still be used, provided that the user makes sure a device synchronization is executed
# before the data is accessed.
#
# Practical recommendations
# -------------------------
#
# .. _pinned_memory_recommendations:
#
# We can now wrap up some early recommendations based on our observations:
#
# In general, ``non_blocking=True`` will provide good throughput, regardless of whether the original tensor is or
# isn't in pinned memory.
# If the tensor is already in pinned memory, the transfer can be accelerated, but pinning it manually from the
# Python main thread is a blocking operation on the host, and hence will negate much of
# the benefit of using ``non_blocking=True`` (as CUDA does the ``pin_memory`` transfer anyway).
#
# One might now legitimately ask what use there is for the :meth:`~torch.Tensor.pin_memory` method.
# In the following section, we will explore further how this can be used to accelerate the data transfer even more.
#
# Additional considerations
# -------------------------
#
# .. _pinned_memory_considerations:
#
# PyTorch provides a :class:`~torch.utils.data.DataLoader` class whose constructor accepts a
# ``pin_memory`` argument.
# Considering our previous discussion on ``pin_memory``, you might wonder how the ``DataLoader`` manages to
# accelerate data transfers if memory pinning is inherently blocking.
#
# The key lies in the DataLoader's use of a separate thread to handle the transfer of data from pageable to pinned
# memory, thus preventing any blockage in the main thread.
#
# To illustrate this, we will use the TensorDict primitive from the homonymous library.
# When invoking :meth:`~tensordict.TensorDict.to`, the default behavior is to send tensors to the device asynchronously,
# followed by a single device synchronization afterwards (for example, ``torch.cuda.synchronize()`` on CUDA).
#
# Additionally, ``TensorDict.to()`` includes a ``non_blocking_pin`` option which initiates multiple threads to execute
# ``pin_memory()`` before proceeding with ``to(device)``.
# This approach can further accelerate data transfers, as demonstrated in the following example.
#
#

from tensordict import TensorDict
import torch
from torch.utils.benchmark import Timer
import matplotlib.pyplot as plt

# Create the dataset
td = TensorDict({str(i): torch.randn(1_000_000) for i in range(1000)})

# Runtimes
copy_blocking = timer("td.to('cuda:0', non_blocking=False)")
copy_non_blocking = timer("td.to('cuda:0')")
copy_pin_nb = timer("td.to('cuda:0', non_blocking_pin=True, num_threads=0)")
copy_pin_multithread_nb = timer("td.to('cuda:0', non_blocking_pin=True, num_threads=4)")

# Ratios
r1 = copy_non_blocking / copy_blocking
r2 = copy_pin_nb / copy_blocking
r3 = copy_pin_multithread_nb / copy_blocking

# Figure
fig, ax = plt.subplots()

xlabels = [0, 1, 2, 3]
bar_labels = [
    "Blocking copy (1x)",
    f"Non-blocking copy ({r1:4.2f}x)",
    f"Blocking pin, non-blocking copy ({r2:4.2f}x)",
    f"Non-blocking pin, non-blocking copy ({r3:4.2f}x)",
]
values = [copy_blocking, copy_non_blocking, copy_pin_nb, copy_pin_multithread_nb]
colors = ["tab:blue", "tab:red", "tab:orange", "tab:green"]

ax.bar(xlabels, values, label=bar_labels, color=colors)

ax.set_ylabel("Runtime (ms)")
ax.set_title("Device casting runtime")
ax.set_xticks([])
ax.legend()

plt.show()

######################################################################
# In this example, we are transferring many large tensors from the CPU to the GPU.
# This scenario is ideal for utilizing multithreaded ``pin_memory()``, which can significantly enhance performance.
# However, if the tensors are small, the overhead associated with multithreading may outweigh the benefits.
# Similarly, if there are only a few tensors, the advantages of pinning tensors on separate threads become limited.
#
# As an additional note, while it might seem advantageous to create permanent buffers in pinned memory to shuttle
# tensors from pageable memory before transferring them to the GPU, this strategy does not necessarily expedite
# computation.
# The inherent bottleneck caused by copying data into pinned memory remains a limiting factor.
#
# Moreover, transferring data that resides on disk (whether in shared memory or files) to the GPU typically requires an
# intermediate step of copying the data into pinned memory (located in RAM).
# Utilizing ``non_blocking`` for large data transfers in this context can significantly increase RAM consumption,
# potentially leading to adverse effects.
#
# In practice, there is no one-size-fits-all solution.
# The effectiveness of using multithreaded ``pin_memory`` combined with ``non_blocking`` transfers depends on a
# variety of factors, including the specific system, operating system, hardware, and the nature of the tasks
# being executed.
# Here is a list of factors to check when trying to speed up data transfers between CPU and GPU, or when comparing
# throughput across scenarios:
#
# - **Number of available cores**
#
#   How many CPU cores are available? Is the system shared with other users or processes that might compete for
#   resources?
#
# - **Core utilization**
#
#   Are the CPU cores heavily utilized by other processes? Does the application perform other CPU-intensive tasks
#   concurrently with data transfers?
#
# - **Memory utilization**
#
#   How much pageable and page-locked memory is currently being used? Is there sufficient free memory to allocate
#   additional pinned memory without affecting system performance? Remember that nothing comes for free: for instance,
#   ``pin_memory`` will consume RAM and may impact other tasks.
#
# - **CUDA Device Capabilities**
#
#   Does the GPU support multiple DMA engines for concurrent data transfers? What are the specific capabilities and
#   limitations of the CUDA device being used?
#
# - **Number of tensors to be sent**
#
#   How many tensors are transferred in a typical operation?
#
# - **Size of the tensors to be sent**
#
#   What is the size of the tensors being transferred? A few large tensors or many small tensors may not benefit from
#   the same transfer program.
#
# - **System Architecture**
#
#   How does the system's architecture influence data transfer speeds (for example, bus speeds, network latency)?
#
# Additionally, allocating a large number of tensors or sizable tensors in pinned memory can monopolize a substantial
# portion of RAM.
# This reduces the available memory for other critical operations, such as paging, which can negatively impact the
# overall performance of an algorithm.
#
# Conclusion
# ----------
#
# .. _pinned_memory_conclusion:
#
# Throughout this tutorial, we have explored several critical factors that influence transfer speeds and memory
# management when sending tensors from the host to the device. We've learned that using ``non_blocking=True`` generally
# accelerates data transfers, and that :meth:`~torch.Tensor.pin_memory` can also enhance performance if implemented
# correctly. However, these techniques require careful design and calibration to be effective.
#
# Remember that profiling your code and keeping an eye on the memory consumption are essential to optimize resource
# usage and achieve the best possible performance.
#
# Additional resources
# --------------------
#
# .. _pinned_memory_resources:
#
# If you are dealing with issues with memory copies when using CUDA devices or want to learn more about
# what was discussed in this tutorial, check the following references:
#
# - `CUDA toolkit memory management doc <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY.html>`_;
# - `CUDA pin-memory note <https://forums.developer.nvidia.com/t/pinned-memory/268474>`_;
# - `How to Optimize Data Transfers in CUDA C/C++ <https://developer.nvidia.com/blog/how-optimize-data-transfers-cuda-cc/>`_;
# - `tensordict doc <https://pytorch.org/tensordict/stable/index.html>`_ and `repo <https://github.com/pytorch/tensordict>`_.
#