# -*- coding: utf-8 -*-
"""
Data Loading Optimization in PyTorch
==============================================

**Authors**: `Divyansh Khanna <https://github.com/divyanshk>`_, `Ramanish Singh <https://github.com/ramanishsingh>`_

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * How to optimize DataLoader configuration for maximum throughput
       * Best practices for ``batch_size``, ``num_workers``, and ``pin_memory``
       * Advanced techniques for overlapping data transfers with GPU compute
       * Configuring shared memory strategies and handling ``/dev/shm`` issues

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch v2.0+
       * Basic understanding of PyTorch DataLoader
       * (Optional) A CUDA-capable GPU for GPU-specific optimizations

Introduction
------------

Data loading is often a critical bottleneck in deep learning pipelines. While
GPUs can process batches extremely quickly, inefficient data loading can leave
expensive hardware idle, waiting for the next batch of data. This tutorial
covers best practices and some techniques for optimizing your data loading
configuration to maximize training throughput.

We'll explore the key parameters of PyTorch's DataLoader and provide practical
guidance on tuning them for your specific workload.
Rather than showing each optimization in isolation, we'll build up from a
baseline training loop and progressively apply optimizations, measuring the
cumulative speedup at each step.
"""

import time

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set a fixed seed for reproducibility
torch.manual_seed(42)

######################################################################
# Creating a Sample Dataset
# -------------------------
#
# First, let's create a simple dataset that simulates expensive
# transformations. This will help us demonstrate the impact of
# various DataLoader configurations.


class SyntheticDataset(Dataset):
    """A synthetic dataset that simulates expensive data transformations."""

    def __init__(self, size=10000, feature_dim=224, transform_delay=0.001):
        self.size = size
        self.feature_dim = feature_dim
        self.transform_delay = transform_delay

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Generate data lazily to avoid pre-allocating large tensors
        data = torch.randn(3, self.feature_dim, self.feature_dim)
        label = torch.randint(0, 10, (1,)).item()
        if self.transform_delay > 0:
            time.sleep(self.transform_delay)
        return data, label


class SyntheticDatasetBatched(Dataset):
    """Same as SyntheticDataset but with __getitems__ for batched fetching."""

    def __init__(self, size=10000, feature_dim=224, transform_delay=0.001):
        self.size = size
        self.feature_dim = feature_dim
        self.transform_delay = transform_delay

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        data = torch.randn(3, self.feature_dim, self.feature_dim)
        label = torch.randint(0, 10, (1,)).item()
        if self.transform_delay > 0:
            time.sleep(self.transform_delay)
        return data, label

    def __getitems__(self, indices):
        """Fetch multiple items at once — enables vectorized
        generation.

        Instead of N individual __getitem__ calls (each with its own
        overhead), this generates the entire batch in one shot using
        vectorized tensor operations.
        """
        n = len(indices)
        # Vectorized generation: one call instead of N individual ones
        data = torch.randn(n, 3, self.feature_dim, self.feature_dim)
        labels = torch.randint(0, 10, (n,))
        # Simulate batch-level I/O: one sleep for the whole batch,
        # not one per sample (e.g., one DB query for N rows)
        if self.transform_delay > 0:
            time.sleep(self.transform_delay)
        return [(data[i], labels[i].item()) for i in range(n)]


######################################################################
# Shared Training Infrastructure
# ------------------------------
#
# To measure the real-world impact of each optimization, we define a
# reusable training loop that accepts a DataLoader and returns timing
# and loss. This avoids duplicating the training loop for every
# benchmark.
#
# We use a **small dataset (512 samples)** with a **high per-sample
# transform delay (5 ms)** to ensure the pipeline remains data-bound
# throughout. The small dataset means short epochs (16 batches of 32), so
# we run many epochs, which makes the benefit of ``persistent_workers``
# visible across epoch boundaries.

benchmark_dataset = SyntheticDataset(size=512, feature_dim=224, transform_delay=0.005)


class SmallTransformerModel(nn.Module):
    """A small conv stem followed by a Transformer encoder and a linear classifier."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=7, stride=4, padding=3),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((7, 7)),
        )
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=64, nhead=4, dim_feedforward=128, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)
        self.classifier = nn.Linear(64, 10)

    def forward(self, x):
        x = self.features(x)  # (B, 64, 7, 7)
        B, C, H, W = x.shape
        x = x.view(B, C, H * W).permute(0, 2, 1)  # (B, 49, 64)
        x = self.transformer(x)  # (B, 49, 64)
        x = x.mean(dim=1)  # (B, 64)
        return self.classifier(x)


def create_model():
    """Create a conv+transformer model for benchmarking."""
    return SmallTransformerModel().to(device)


def train_and_benchmark(loader, max_batches=160, epochs=10, prefetch_device=None):
    """Train a model over multiple epochs and return elapsed time and average loss.

    Running multiple epochs (10) with a small dataset ensures many epoch
    boundaries, making persistent_workers' startup savings visible.

    Args:
        loader: A DataLoader to iterate over.
        max_batches: Maximum total number of batches to process across all epochs.
        epochs: Number of epochs to iterate (re-iterates the loader each epoch).
        prefetch_device: If set, wraps the loader in a DataPrefetcher each epoch
            for overlapping H2D transfers. Data arrives already on device.

    Returns:
        Tuple of (elapsed_seconds, average_loss).
    """
    model = create_model()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()

    start_time = time.perf_counter()
    total_loss = 0.0
    num_batches = 0

    for epoch in range(epochs):
        if prefetch_device is not None:
            data_iter = DataPrefetcher(loader, prefetch_device)
        else:
            data_iter = loader

        for data, labels in data_iter:
            if prefetch_device is None:
                data = data.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)

            output = model(data)
            loss = criterion(output, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

            if num_batches >= max_batches:
                break
        if num_batches >= max_batches:
            break

    if torch.cuda.is_available():
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start_time

    return elapsed, total_loss / num_batches


######################################################################
# Baseline Training Loop
# ----------------------
#
# Our starting point: a simple DataLoader with no multiprocessing,
# no pinned memory, and default settings. This establishes the
# performance floor we'll improve upon.

baseline_loader = DataLoader(
    benchmark_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
    pin_memory=False,
)

print("\n=== Progressive Optimization Results ===")
print("\nBaseline (num_workers=0, pin_memory=False):")
baseline_time, baseline_loss = train_and_benchmark(baseline_loader)
print(f" Time: {baseline_time:.4f}s | Loss: {baseline_loss:.4f}")
prev_time = baseline_time

######################################################################
# Batch Size Optimization
# -----------------------
#
# The ``batch_size`` parameter controls how many samples are processed
# together. Choosing the right batch size involves balancing several factors:
#
# **Memory Considerations:**
#
# - Larger batch sizes require more GPU memory for storing inputs,
#   activations, and activation gradients
# - Out-of-memory (OOM) errors are common with large batch sizes
# - Moderate batch sizes (32-128) often provide the best balance
#
# **Training Dynamics:**
#
# - Batch size changes affect the effective learning rate, typically requiring tuning
# - Larger batches provide more stable gradient estimates but may
#   generalize differently
#
# .. note::
#    When changing batch size, remember to tune your optimizer parameters,
#    especially the learning rate schedule, unless you're only running inference.
#
# Since batch size is model-dependent (not a "just add it" optimization),
# we benchmark it in isolation rather than folding it into the progressive
# optimization chain.

# Example: Testing different batch sizes.
# 1280 samples give 10 full batches even at the largest batch size (128),
# so every configuration below really processes 10 batches.
batch_dataset = SyntheticDataset(size=1280, transform_delay=0)


def benchmark_batch_size(batch_size, num_batches=10):
    """Benchmark data loading with a specific batch size."""
    loader = DataLoader(batch_dataset, batch_size=batch_size, shuffle=True)
    start = time.perf_counter()
    for i, (data, labels) in enumerate(loader):
        if i >= num_batches:
            break
        data = data.to(device, non_blocking=True)
        _ = data.sum()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    return elapsed


# Benchmark different batch sizes; samples/s makes the sizes comparable
print("\nBatch size comparison (isolated benchmark):")
for bs in [16, 32, 64, 128]:
    elapsed = benchmark_batch_size(bs)
    print(f" Batch size {bs:3d}: {elapsed:.4f}s for 10 batches ({10 * bs / elapsed:.0f} samples/s)")

######################################################################
# Number of Workers (``num_workers``)
# -----------------------------------
#
# The ``num_workers`` parameter controls how many subprocesses are used
# for data loading.
# This is crucial for parallelizing expensive data
# transformations.
#
# **How it works:**
#
# - Each worker loads whole batches ahead of time; ``prefetch_factor`` sets
#   how many batches per worker can be in flight at once
# - Workers prepare batches in parallel and transfer them to the main process
# - If ``in_order=True`` (default), batches are returned in order
#
# **When to increase ``num_workers``:**
#
# - When transforms are computationally expensive (augmentations, decoding)
# - When data is loaded from slow storage (network drives, HDD)
# - When you observe GPU idle time due to data loading
#
# **When ``num_workers=0`` might be faster:**
#
# - When transforms are cheap (simple tensor operations)
# - When data is already in memory
# - When the overhead of inter-process communication (IPC) exceeds the
#   parallelization benefits
#
# .. note::
#    Finding the optimal ``num_workers`` requires tuning: increase workers
#    until throughput plateaus. Too many workers waste CPU
#    memory (each worker holds its own copy of the dataset object and
#    prefetched batches) and can cause ``/dev/shm`` exhaustion.
#    A good starting point is 2-4 workers per GPU; profile with different
#    values to find the sweet spot for your workload.
#
# Let's add ``num_workers=4`` and ``prefetch_factor=2`` to our training
# loop and measure the improvement:

workers_loader = DataLoader(
    benchmark_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    prefetch_factor=2,
    pin_memory=False,
)

print("\n+ num_workers=4, prefetch_factor=2:")
workers_time, workers_loss = train_and_benchmark(workers_loader)
print(f" Time: {workers_time:.4f}s | Loss: {workers_loss:.4f}")
print(
    f" Speedup vs baseline: {baseline_time / workers_time:.2f}x | vs previous: {prev_time / workers_time:.2f}x"
)
prev_time = workers_time

######################################################################
# Understanding ``pin_memory``
# ----------------------------
#
# The ``pin_memory`` parameter enables faster CPU-to-GPU data transfers
# by using page-locked (pinned) memory.
#
# **How pinned memory works:**
#
# - Pinned memory cannot be swapped to disk by the OS
# - This enables faster DMA (Direct Memory Access) transfers to GPU
# - The CPU-to-GPU transfer can happen asynchronously
#
# **Best practices:**
#
# 1. Use ``pin_memory=True`` in the DataLoader (recommended approach)
# 2. Combine with ``non_blocking=True`` when moving data to GPU
# 3. Avoid manually calling ``tensor.pin_memory()`` followed by
#    ``.to(device, non_blocking=True)``: this is slower because
#    ``pin_memory()`` is blocking
#
# **The safe pattern:**
#
# .. code-block:: python
#
#    # Recommended: Let DataLoader handle pinning
#    loader = DataLoader(dataset, pin_memory=True)
#    for data, labels in loader:
#        data = data.to(device, non_blocking=True)
#        labels = labels.to(device, non_blocking=True)
#
# .. seealso::
#    For more details, see the
#    `pin_memory tutorial <https://docs.pytorch.org/tutorials/intermediate/pinmem_nonblock.html>`_
#
# Let's add ``pin_memory=True`` to our configuration:

pinmem_loader = DataLoader(
    benchmark_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    prefetch_factor=2,
    pin_memory=torch.cuda.is_available(),
)

if torch.cuda.is_available():
    print("\n+ pin_memory=True:")
    pinmem_time, pinmem_loss = train_and_benchmark(pinmem_loader)
    print(f" Time: {pinmem_time:.4f}s | Loss: {pinmem_loss:.4f}")
    print(
        f" Speedup vs baseline: {baseline_time / pinmem_time:.2f}x | vs previous: {prev_time / pinmem_time:.2f}x"
    )
    print(
        " (pin_memory benefit is modest here because CPU transform time dominates H2D transfer)"
    )
    prev_time = pinmem_time
else:
    print("\n+ pin_memory: skipped (CUDA not available)")
    pinmem_time = workers_time

######################################################################
# Persistent Workers
# ------------------
#
# By default, worker processes are shut down and restarted between
# epochs.
# This incurs startup overhead (spawning or forking processes, importing
# modules, and re-creating each worker's copy of the dataset) on every
# epoch boundary.
#
# Setting ``persistent_workers=True`` keeps the workers alive across
# epochs, eliminating this repeated startup cost.
#
# **When it helps most:**
#
# - Training for many epochs on smaller datasets
# - When per-worker setup is expensive (e.g., a large dataset object that
#   must be copied into every worker, or a ``worker_init_fn`` that opens
#   files or connections)
# - When combined with high ``num_workers``
#
# Let's compare with and without persistent workers over multiple epochs:

non_persistent_loader = DataLoader(
    benchmark_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    prefetch_factor=2,
    pin_memory=torch.cuda.is_available(),
    persistent_workers=False,
)

persistent_loader = DataLoader(
    benchmark_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    prefetch_factor=2,
    pin_memory=torch.cuda.is_available(),
    persistent_workers=True,
)

print("\n+ persistent_workers=True (10 epochs):")
non_persistent_time, _ = train_and_benchmark(non_persistent_loader)
persistent_time, persistent_loss = train_and_benchmark(persistent_loader)
print(f" Without persistent_workers: {non_persistent_time:.4f}s")
print(f" With persistent_workers: {persistent_time:.4f}s")
print(
    f" Speedup vs baseline: {baseline_time / persistent_time:.2f}x | vs previous: {prev_time / persistent_time:.2f}x"
)
prev_time = persistent_time

######################################################################
# Overlapping H2D Transfer with GPU Compute
# ---------------------------------------------------
#
# For maximum throughput, you can overlap Host-to-Device (H2D) data
# transfers with GPU computation, so the GPU does not sit idle while the
# next batch is being copied.
#
# The idea is to prefetch the next batch to the GPU on a separate CUDA
# stream while the current batch is being processed.
#
# .. note::
#    The DataPrefetcher shows its greatest benefit when H2D transfer
#    time is a significant fraction of each training step, so hiding it
#    behind GPU compute pays off. If data loading is already fast, the
#    stream synchronization overhead may exceed the benefit.


class DataPrefetcher:
    """Prefetches the next batch to the GPU while the current batch is processed."""

    def __init__(self, loader, device):
        self.loader = iter(loader)
        self.device = device
        self.stream = torch.cuda.Stream() if torch.cuda.is_available() else None
        self.next_data = None
        self.next_labels = None
        self.preload()

    def preload(self):
        try:
            self.next_data, self.next_labels = next(self.loader)
        except StopIteration:
            self.next_data = None
            self.next_labels = None
            return

        if self.stream is not None:
            # Issue the H2D copies on the side stream so they can overlap
            # with compute running on the current stream
            with torch.cuda.stream(self.stream):
                self.next_data = self.next_data.to(self.device, non_blocking=True)
                self.next_labels = self.next_labels.to(self.device, non_blocking=True)
        else:
            # No CUDA stream available: fall back to a synchronous copy
            self.next_data = self.next_data.to(self.device)
            self.next_labels = self.next_labels.to(self.device)

    def __iter__(self):
        return self

    def __next__(self):
        if self.stream is not None:
            # Make the current stream wait until the prefetch copies finish
            torch.cuda.current_stream().wait_stream(self.stream)

        data = self.next_data
        labels = self.next_labels

        if data is None:
            raise StopIteration

        # The tensors were allocated on the side stream; record that they are
        # now used on the current stream so the caching allocator does not
        # reuse their memory before this step's kernels finish
        if self.stream is not None:
            data.record_stream(torch.cuda.current_stream())
            labels.record_stream(torch.cuda.current_stream())

        self.preload()
        return data, labels


# Integrate prefetcher into the training loop.
if torch.cuda.is_available():
    print("\n+ DataPrefetcher (overlapping H2D transfer):")
    prefetch_time, prefetch_loss = train_and_benchmark(
        persistent_loader, prefetch_device=device
    )
    print(f" Time: {prefetch_time:.4f}s | Loss: {prefetch_loss:.4f}")
    print(
        f" Speedup vs baseline: {baseline_time / prefetch_time:.2f}x | vs previous: {prev_time / prefetch_time:.2f}x"
    )
    prev_time = prefetch_time
else:
    print("\n+ DataPrefetcher: skipped (CUDA not available)")
    prefetch_time = persistent_time

######################################################################
# Dataset-Level Optimization: ``__getitems__``
# --------------------------------------------
#
# Beyond tuning DataLoader parameters, you can optimize the dataset
# itself. PyTorch's DataLoader supports a batched fetching protocol via
# ``__getitems__``: if your dataset defines this method, the fetcher
# calls it once with a list of indices instead of calling ``__getitem__``
# repeatedly for each sample.
#
# **How it works:**
#
# - The default fetcher does: ``[dataset[idx] for idx in batch_indices]``
# - With ``__getitems__``: ``dataset.__getitems__(batch_indices)``
#
# **When this helps:**
#
# - When per-sample overhead is significant (e.g., opening connections,
#   parsing headers, acquiring locks)
# - When data can be fetched in bulk more efficiently (e.g., one SQL query
#   for N rows instead of N queries, or vectorized tensor generation)
# - When the transform has a fixed setup cost that can be amortized
#   across the batch
#
# **Expected signature:**
#
# .. code-block:: python
#
#    def __getitems__(self, indices: list[int]) -> list:
#        # Fetch all items at once and return them as a list with one
#        # sample per index; the list is then passed to collate_fn
#        ...
#
# Our ``SyntheticDatasetBatched`` implements ``__getitems__`` to generate
# the entire batch in one vectorized call (with a single amortized delay)
# rather than N individual calls each with their own delay.
# Let's add this to our cumulative configuration:

benchmark_dataset_batched = SyntheticDatasetBatched(
    size=512, feature_dim=224, transform_delay=0.005
)

batched_loader = DataLoader(
    benchmark_dataset_batched,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    prefetch_factor=2,
    pin_memory=torch.cuda.is_available(),
    persistent_workers=True,
)

print("\n+ __getitems__ (batched dataset fetching):")
# Keep the DataPrefetcher from the previous step when CUDA is available,
# so this measurement stays cumulative
batched_time, batched_loss = train_and_benchmark(
    batched_loader, prefetch_device=device if torch.cuda.is_available() else None
)
print(f" Time: {batched_time:.4f}s | Loss: {batched_loss:.4f}")
print(
    f" Speedup vs baseline: {baseline_time / batched_time:.2f}x | vs previous: {prev_time / batched_time:.2f}x"
)
prev_time = batched_time

######################################################################
# ``in_order`` parameter
# --------------------------
#
# By default (``in_order=True``), the DataLoader returns batches in
# the order produced by the sampler. This requires caching batches
# that arrive out of order from workers. (``in_order`` is available in
# recent PyTorch releases.)
#
# **When to consider ``in_order=False``:**
#
# - When you don't need deterministic ordering (e.g., not checkpointing)
# - When you observe periodic slow steps because the loader is holding
#   finished batches while it waits for one slow worker
# - When maximizing throughput is more important than reproducibility
#
# .. note::
#    ``in_order=False`` might not increase average throughput, but it
#    can reduce variance and eliminate occasional slow batches caused
#    by head-of-line blocking when one worker is slower than others.

######################################################################
# Snapshot Frequency (``snapshot_every_n_steps``)
# -----------------------------------------------
#
# When using torchdata's ``StatefulDataLoader`` (for checkpointing), the
# ``snapshot_every_n_steps`` parameter controls how often the
# DataLoader state is saved.
#
# **Trade-offs:**
#
# - **Higher frequency (smaller n):** More overhead, but less progress
#   lost on job failure
# - **Lower frequency (larger n):** Less overhead, but more replayed
#   samples on recovery
#
# Choose based on your fault tolerance requirements and the cost of
# reprocessing data.

######################################################################
# Shared Memory and ``set_sharing_strategy``
# ------------------------------------------
#
# When using multiprocessing with ``num_workers > 0``, PyTorch needs to
# transfer tensors between worker processes and the main process. The
# sharing strategy determines how this is done.
#
# **Available Strategies:**
#
# PyTorch provides two sharing strategies via
# ``torch.multiprocessing.set_sharing_strategy()``:
#
# 1. **file_descriptor** (the default, except on macOS, where it is not
#    supported)
#
#    - Shares memory by passing file descriptors between processes
#    - Limited by the system's open file descriptor limit (``ulimit -n``)
#
# 2. **file_system**
#
#    - Identifies shared memory regions by file name (in ``/dev/shm`` on Linux)
#    - Not limited by file descriptor count
#    - Better for large numbers of tensors
#    - More prone to leaking shared memory if processes exit uncleanly

######################################################################
# **How to Change the Strategy:**
#
# .. code-block:: python
#
#    import torch.multiprocessing as mp
#
#    # Switch to file_system strategy
#    # Must be called before creating any DataLoader workers
#    mp.set_sharing_strategy('file_system')
#
# **Choosing the Right Strategy:**
#
# +---------------------------+---------------------------+---------------------------+
# | Scenario                  | Recommended Strategy      | Reason                    |
# +===========================+===========================+===========================+
# | Few tensors in flight     | file_descriptor (default) | fd usage stays low        |
# +---------------------------+---------------------------+---------------------------+
# | Many tensors in flight    | file_system               | Avoids fd limits          |
# +---------------------------+---------------------------+---------------------------+
# | High num_workers          | file_system               | Avoids fd exhaustion      |
# +---------------------------+---------------------------+---------------------------+
#
# .. warning::
#    ``set_sharing_strategy()`` must be called **before** creating any
#    DataLoader with ``num_workers > 0``. Changing it afterward has no
#    effect on existing workers.

######################################################################
# Handling Insufficient Shared Memory (``/dev/shm``)
# --------------------------------------------------
#
# When using ``num_workers > 0``, PyTorch uses shared memory (``/dev/shm``)
# to efficiently pass data between worker processes and the main process.
# If you encounter errors like:
#
# .. code-block:: text
#
#    RuntimeError: unable to open shared memory object </torch_xxx>
#    ERROR: Unexpected bus error encountered in worker
#
# This typically means you've exhausted the shared memory allocation.
#
# **Solutions:**
#
# **1. Increase /dev/shm size (if you can)**, for example by passing
# ``--shm-size`` or ``--ipc=host`` to ``docker run``; Docker's default
# ``/dev/shm`` is only 64 MB.
#
# **2. Reduce memory pressure from DataLoader:**
#
# .. code-block:: python
#
#    # Reduce number of workers
#    DataLoader(dataset, num_workers=2)  # Instead of 8+
#
#    # Reduce prefetch factor
#    DataLoader(dataset, num_workers=4, prefetch_factor=1)  # Instead of 2
#
#    # Use smaller batch sizes
#    DataLoader(dataset, batch_size=16)  # Smaller batches = less shm
#
# **3. Switch sharing strategy** (helps when the error mentions
# ``Too many open files``):
#
# .. code-block:: python
#
#    import torch.multiprocessing as mp
#    mp.set_sharing_strategy('file_system')
#
# **4. Clean up leaked shared memory:**
#
# .. code-block:: bash
#
#    # List shared memory segments
#    ls -la /dev/shm/
#
#    # Remove orphaned PyTorch segments (only when no PyTorch processes are running!)
#    rm /dev/shm/torch_*
#
# .. note::
#    Shared memory leaks can occur if worker processes crash without
#    proper cleanup.
#

######################################################################
# Final Summary
# -------------
#
# Here's the cumulative effect of each optimization we applied to
# our training loop. Each row includes all optimizations from previous
# rows:
#
# .. rst-class:: summary-table
#
# .. list-table::
#    :header-rows: 1
#    :widths: 55 20 20
#
#    * - Configuration
#      - vs Baseline
#      - vs Previous
#    * - Baseline (num_workers=0, no pinning)
#      - 1.00x
#      - —
#    * - \+ num_workers=4, prefetch_factor=2
#      - ~2.7x
#      - ~2.7x
#    * - \+ pin_memory=True
#      - ~2.8x
#      - ~1.0x
#    * - \+ persistent_workers=True
#      - ~3.7x
#      - ~1.3x
#    * - \+ DataPrefetcher (H2D overlap)
#      - ~3.6x
#      - ~1.0x
#    * - \+ __getitems__ (batched fetching)
#      - ~10x
#      - ~2.9x
#
# .. note::
#    These results are based on our benchmark dataset.
#    Actual speedups will vary depending on your specific
#    workload, hardware, dataset size, and transform complexity.

######################################################################
# Summary and Best Practices
# --------------------------
#
# 1. **Start with moderate batch sizes** (32-128) and scale up if memory
#    allows.
#
# 2. **Use ``num_workers > 0``** when transforms are expensive. Start with
#    2-4 workers and increase based on memory capacity. Higher is not always better.
#
# 3. **Enable ``pin_memory=True``** when using an accelerator.
#
# 4. **Use ``persistent_workers=True``** to avoid worker restart overhead
#    between epochs.
#
# 5. **Profile your pipeline** to identify CPU bottlenecks during
#    dataset access, transformations, etc.
#
# 6. **Implement data prefetching** for GPU workloads to overlap data
#    transfer with computation.
#
# 7. **Use ``file_system`` sharing strategy** when hitting file descriptor limits.
#

######################################################################
# Conclusion
# ----------
#
# In this tutorial, we learned how to progressively optimize a PyTorch
# data loading pipeline — from a naive single-process baseline to a
# fully optimized configuration using multiprocessing workers, pinned
# memory, persistent workers, CUDA stream-based prefetching, and batched
# dataset fetching with ``__getitems__``. Each optimization targets a
# different bottleneck, and together they can yield an order-of-magnitude
# improvement in throughput. Treat these as general best practices: the
# actual gains depend on your specific workload.

######################################################################
# Additional Resources
# --------------------
#
# - `PyTorch DataLoader documentation <https://pytorch.org/docs/stable/data.html>`_
# - `Pin Memory and Non-blocking Transfer Tutorial <https://docs.pytorch.org/tutorials/intermediate/pinmem_nonblock.html>`_
# - `PyTorch Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_
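######################################################################
# The batch size section above notes that changing ``batch_size`` usually
# means retuning the learning rate. A common starting point for SGD is the
# linear scaling heuristic: scale the learning rate in proportion to the
# batch size. The sketch below assumes that heuristic and uses a
# hypothetical helper name, ``scale_lr``; treat its output as an initial
# guess to tune from, not as a value that guarantees equivalent training.

```python
def scale_lr(base_lr, base_batch_size, new_batch_size):
    """Linear scaling heuristic: the learning rate grows with the batch size.

    Hypothetical helper for illustration only. The result is a starting
    point for tuning; large scale-ups often also need a warmup period.
    """
    if base_batch_size <= 0 or new_batch_size <= 0:
        raise ValueError("batch sizes must be positive")
    return base_lr * new_batch_size / base_batch_size


# The benchmark in this tutorial trains with SGD(lr=0.01) at batch_size=32
for bs in [16, 32, 64, 128]:
    print(f"batch_size={bs:3d} -> suggested starting lr={scale_lr(0.01, 32, bs):.4f}")
```

# Linear scaling is only one heuristic; adaptive optimizers and very large
# batches may behave differently, so validate the result on your own model.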
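######################################################################
# The ``num_workers`` note recommends increasing workers until throughput
# plateaus. A back-of-the-envelope model shows why a plateau exists: with
# ``W`` workers that each need ``load_ms`` to build a batch, the loader can
# supply at most ``W * 1000 / load_ms`` batches per second, while the
# training step can consume at most ``1000 / step_ms``. The sketch assumes
# perfect parallelism and ignores IPC, CPU contention, and memory limits,
# so real curves flatten earlier; the function names and the 20 ms step
# time are hypothetical.

```python
import math


def batches_per_second(num_workers, load_ms, step_ms):
    """Idealized steady-state throughput of a DataLoader-fed training loop.

    Hypothetical model, not a measurement: every batch takes ``load_ms`` to
    load and ``step_ms`` to train on, workers run fully in parallel, and
    IPC / pinning costs are zero.
    """
    if num_workers == 0:
        # Loading happens in the main process, so it serializes with training
        return 1000 / (load_ms + step_ms)
    supply = num_workers * 1000 / load_ms  # batches/s the workers can produce
    demand = 1000 / step_ms  # batches/s the training step can consume
    return min(supply, demand)


def workers_to_saturate(load_ms, step_ms):
    """Smallest worker count at which loading stops being the bottleneck."""
    return math.ceil(load_ms / step_ms)


# 32 samples x 5 ms sleep gives at least 160 ms per batch, as in the
# benchmark dataset; the 20 ms training step is an assumed value
for w in [0, 2, 4, 8, 16]:
    print(f"num_workers={w:2d}: {batches_per_second(w, 160, 20):5.1f} batches/s")
print("workers needed to saturate:", workers_to_saturate(160, 20))
```

# Past the saturation point, extra workers only add memory and ``/dev/shm``
# pressure, which is why measured throughput levels off.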
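######################################################################
# The "How it works" bullets in the ``__getitems__`` section describe how
# the fetcher chooses between ``__getitems__`` and per-sample
# ``__getitem__``. The pure-Python sketch below mimics that dispatch so you
# can compare call counts without running a DataLoader. ``fetch_batch``,
# ``CountingDataset``, and ``CountingDatasetBatched`` are hypothetical
# names; PyTorch's real fetcher also handles non-batched loading and then
# passes the resulting list to ``collate_fn``.

```python
class CountingDataset:
    """Toy map-style dataset that counts how it is accessed (hypothetical)."""

    def __init__(self, size):
        self.size = size
        self.getitem_calls = 0

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        self.getitem_calls += 1
        return idx * 10


class CountingDatasetBatched(CountingDataset):
    """Same dataset, plus a batched ``__getitems__`` entry point."""

    def __init__(self, size):
        super().__init__(size)
        self.getitems_calls = 0

    def __getitems__(self, indices):
        self.getitems_calls += 1
        # One bulk "fetch" for the whole batch, returning one sample per index
        return [idx * 10 for idx in indices]


def fetch_batch(dataset, batch_indices):
    """Sketch of the batched-fetch dispatch: prefer __getitems__ if present."""
    getitems = getattr(dataset, "__getitems__", None)
    if getitems is not None:
        return getitems(batch_indices)
    return [dataset[idx] for idx in batch_indices]


plain, batched = CountingDataset(100), CountingDatasetBatched(100)
indices = list(range(32))
# Both paths must produce the same samples
assert fetch_batch(plain, indices) == fetch_batch(batched, indices)
print("plain:   __getitem__ calls =", plain.getitem_calls)
print("batched: __getitems__ calls =", batched.getitems_calls,
      "| __getitem__ calls =", batched.getitem_calls)
```

# One call per batch instead of one per sample is what lets a real dataset
# amortize fixed costs such as a database round trip.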
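######################################################################
# The ``in_order`` section mentions head-of-line blocking. The toy model
# below makes it concrete: given the time at which each batch finishes
# loading, it computes when the consumer can receive each batch with and
# without enforcing order. It is a hypothetical simulation (``delivery_times``
# is an illustrative name) that ignores training time and prefetch limits,
# not a model of DataLoader internals.

```python
def delivery_times(ready_times, in_order=True):
    """Return (batch_index, delivery_time) pairs in delivery order.

    ``ready_times[i]`` is when batch ``i`` finishes loading in some worker.
    With in_order=True, batch i cannot be delivered before batches 0..i-1,
    so early arrivals wait in a reorder buffer. With in_order=False,
    batches are delivered as soon as they are ready.
    """
    if not in_order:
        order = sorted(range(len(ready_times)), key=lambda i: ready_times[i])
        return [(i, ready_times[i]) for i in order]

    delivered = []
    latest = float("-inf")
    for i, t in enumerate(ready_times):
        # Batch i is held until it and every earlier batch are ready
        latest = max(latest, t)
        delivered.append((i, latest))
    return delivered


# Batch 0 lands on a slow worker (50 ms); batches 1-3 are ready at 10 ms
ready = [50, 10, 10, 10]
print("in_order=True: ", delivery_times(ready, in_order=True))
print("in_order=False:", delivery_times(ready, in_order=False))
```

# The last batch arrives at 50 ms either way, so total throughput is the
# same, but with in_order=False three batches reach the training loop 40 ms
# earlier, which matches the note that the main gain is lower variance.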
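######################################################################
# The ``/dev/shm`` section suggests reducing workers, prefetching, or batch
# size. A quick estimate shows why each helps: the DataLoader keeps up to
# ``num_workers * prefetch_factor`` batches in flight, and batches built
# by workers reach the main process through shared memory. The sketch
# below counts only the sample tensors (labels, allocator overhead, and
# other objects are ignored), so treat it as a rough lower bound;
# ``inflight_shm_bytes`` is a hypothetical helper.

```python
import math


def inflight_shm_bytes(batch_size, sample_shape, bytes_per_element,
                       num_workers, prefetch_factor):
    """Rough lower bound on shared memory held by prefetched batches.

    Assumes every in-flight batch is a single dense tensor of shape
    (batch_size, *sample_shape); hypothetical helper for illustration.
    """
    batch_bytes = batch_size * math.prod(sample_shape) * bytes_per_element
    return batch_bytes * num_workers * prefetch_factor


MiB = 1024 ** 2
# Tutorial settings: 3x224x224 float32 samples (4 bytes per element)
for workers, prefetch, bs in [(4, 2, 32), (8, 2, 32), (4, 1, 32), (4, 2, 16)]:
    est = inflight_shm_bytes(bs, (3, 224, 224), 4, workers, prefetch)
    print(f"workers={workers} prefetch={prefetch} batch={bs}: ~{est / MiB:.0f} MiB")
```

# With the tutorial's settings (4 workers, ``prefetch_factor=2``,
# ``batch_size=32``) the estimate is about 147 MiB, already more than
# Docker's default 64 MB ``/dev/shm``; this is one reason containerized
# training often needs ``--shm-size`` or ``--ipc=host``.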