Advanced Model Training with Fully Sharded Data Parallel (FSDP)

**Author**: Hamid Shojanazeri, Less Wright, Rohan Varma, Yanli Zhao

This tutorial introduces more advanced features of Fully Sharded Data Parallel
(FSDP) as part of the PyTorch 1.12 release. To get familiar with FSDP, please
refer to the FSDP getting started tutorial.

In this tutorial, we fine-tune a HuggingFace (HF) T5 model with FSDP for text
summarization as a working example. 

The example uses the WikiHow dataset. For simplicity, we showcase training on a
single node, a P4dn instance with 8 A100 GPUs. We will soon publish a blog post
on large-scale FSDP training on a multi-node cluster; please stay tuned for that
on the PyTorch Medium channel.

FSDP is a production-ready package with a focus on ease of use, performance, and
long-term support. One of the main benefits of FSDP is a reduced memory
footprint on each GPU. This enables training larger models with lower total
memory than DDP, and FSDP overlaps computation with communication to train
models efficiently.
The reduced memory pressure can be used either to train larger models or to
increase the batch size, potentially improving overall training throughput. You
can read more about PyTorch FSDP here.

FSDP Features in This Tutorial
* Transformer Auto Wrap Policy
* Mixed Precision
* Initializing FSDP Model on Device
* Sharding Strategy
* Backward Prefetch
* Model Checkpoint Saving via Streaming to CPU

Recap on How FSDP Works

At a high level, FSDP works as follows:

*In constructor*

* Shard model parameters and each rank only keeps its own shard

*In forward pass*

* Run `all_gather` to collect all shards from all ranks to recover the full
  parameter for this FSDP unit
* Run forward computation
* Discard non-owned parameter shards it has just collected to free memory

*In backward pass*

* Run `all_gather` to collect all shards from all ranks to recover the full
  parameter in this FSDP unit
* Run backward computation
* Discard non-owned parameters to free memory
* Run `reduce_scatter` to sync gradients
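
The steps above can be illustrated with a small single-process simulation. This
is only a sketch of the data flow, not FSDP's implementation: the helper names
(`shard`, `all_gather`, `reduce_scatter`) are hypothetical stand-ins for the
collectives, ranks are simulated with plain Python lists, and a real FSDP unit
operates on flattened tensors over NCCL.

.. code-block:: python

    # Toy, single-process simulation of the FSDP data flow for one FSDP unit.
    # All names here are illustrative; real FSDP uses NCCL collectives on tensors.

    def shard(full_params, world_size):
        """Split a flat parameter list into equal contiguous shards, one per rank."""
        size = len(full_params) // world_size
        return [full_params[r * size:(r + 1) * size] for r in range(world_size)]

    def all_gather(shards):
        """Every rank receives the concatenation of all shards."""
        full = [p for s in shards for p in s]
        return [list(full) for _ in shards]

    def reduce_scatter(per_rank_grads):
        """Sum gradients element-wise across ranks; rank r keeps only shard r."""
        world_size = len(per_rank_grads)
        summed = [sum(vals) for vals in zip(*per_rank_grads)]
        return shard(summed, world_size)

    world_size = 2
    params = [1.0, 2.0, 3.0, 4.0]

    # Constructor: each rank keeps only its own shard.
    shards = shard(params, world_size)

    # Forward/backward: gather the full parameters, compute, then drop them again.
    gathered = all_gather(shards)
    # Pretend each rank computed a gradient of (rank + 1) for every parameter.
    local_grads = [[float(r + 1)] * len(gathered[r]) for r in range(world_size)]
    del gathered  # non-owned shards are discarded to free memory

    # Backward: reduce_scatter leaves each rank with its summed gradient shard.
    grad_shards = reduce_scatter(local_grads)
    print(shards)
    print(grad_shards)

After the last step each simulated rank holds only the parameter shard and the
gradient shard it owns, which is the state the optimizer step works on.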

Fine-tuning HF T5
HF T5 pre-trained models are available in five different sizes, ranging from
small with 60 million parameters to XXL with 11 billion parameters. In this
tutorial, we demonstrate fine-tuning a T5 3B model with FSDP for text
summarization using the WikiHow dataset. The main focus of this tutorial is to
highlight FSDP features that are helpful for training large-scale models above
3B parameters. We also cover features specific to Transformer-based models. The
code for this tutorial is available in PyTorch examples.


1.1 Install PyTorch Nightlies

We will install PyTorch nightlies, as some features, such as activation
checkpointing, are available in nightlies and will be added in the next PyTorch
release after 1.12.

.. code-block:: bash 

    pip3 install --pre torch torchvision torchaudio -f

1.2 Dataset Setup

Please create a `data` folder, download the WikiHow dataset files
`wikihowAll.csv` and `wikihowSep.csv`, and place them in the `data` folder. We
will use the `wikihow` dataset class from the `summarization_dataset` module.

Next, we add the following code snippets to a Python script.

.. note::
   The full source code for this tutorial is available in PyTorch examples.

1.3  Import necessary packages:

.. code-block:: python

    import os
    import argparse
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    from transformers import AutoTokenizer, GPT2TokenizerFast
    from transformers import T5Tokenizer, T5ForConditionalGeneration
    import functools
    from torch.optim.lr_scheduler import StepLR
    import torch.distributed as dist
    import torch.multiprocessing as mp
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.utils.data.distributed import DistributedSampler
    from transformers.models.t5.modeling_t5 import T5Block

    from torch.distributed.fsdp import (
        FullyShardedDataParallel as FSDP,
        MixedPrecision,
        BackwardPrefetch,
        ShardingStrategy,
        FullStateDictConfig,
        StateDictType,
    )
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
    from functools import partial
    from torch.utils.data import DataLoader
    from pathlib import Path
    from summarization_dataset import *
    from typing import Type
    import time
    import tqdm
    from datetime import datetime

    # Needed by the bf16 check and the dataset loading in fsdp_main below.
    import torch.cuda.nccl as nccl
    from datasets import load_dataset
    # distutils is deprecated and was removed in Python 3.12.
    from distutils.version import LooseVersion

1.4 Distributed training setup. 
Here we use two helper functions to initialize the processes for distributed
training and to clean up after training completes. In this tutorial, we use
torch elastic via `torchrun`, which sets the worker `RANK` and `WORLD_SIZE`
automatically.

.. code-block:: python

    def setup():
        # initialize the process group
        dist.init_process_group("nccl")

    def cleanup():
        dist.destroy_process_group()
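
For reference, `torchrun` exports the rank information as environment
variables. A minimal sketch of reading them is shown here; the helper name
`read_dist_env` and the single-process fallback defaults are assumptions made
for illustration and are not part of the tutorial's script.

.. code-block:: python

    import os

    def read_dist_env(environ=os.environ):
        """Return (rank, local_rank, world_size) from torchrun-style variables.

        Falls back to a single-process layout when the variables are absent
        (an assumption for local debugging; torchrun sets them for each worker).
        """
        rank = int(environ.get("RANK", 0))
        local_rank = int(environ.get("LOCAL_RANK", 0))
        world_size = int(environ.get("WORLD_SIZE", 1))
        return rank, local_rank, world_size

    print(read_dist_env({"RANK": "3", "LOCAL_RANK": "3", "WORLD_SIZE": "8"}))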

2.1  Set up the HuggingFace T5 model:

.. code-block:: python

    def setup_model(model_name):
        model = T5ForConditionalGeneration.from_pretrained(model_name)
        tokenizer =  T5Tokenizer.from_pretrained(model_name)
        return model, tokenizer

We also add a couple of helper functions here for the run date and for
formatting memory metrics:

.. code-block:: python

    g_gigabyte = 1024**3

    def get_date_of_run():
        """create date and time for file save uniqueness
        example: 2022-05-07-08:31:12_PM
        """
        date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p")
        print(f"--> current date and time of run = {date_of_run}")
        return date_of_run

    def format_metrics_to_gb(item):
        """quick function to format numbers to gigabyte and round to 4 digit precision"""
        metric_num = item / g_gigabyte
        metric_num = round(metric_num, ndigits=4)
        return metric_num

2.2 Define a train function:

.. code-block:: python

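    def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
        model.train()
        local_rank = int(os.environ['LOCAL_RANK'])
        # fsdp_loss[0] accumulates the loss, fsdp_loss[1] the number of batches
        fsdp_loss = torch.zeros(2).to(local_rank)

        if sampler:
            sampler.set_epoch(epoch)  # reshuffle differently in each epoch
        if rank==0:
            inner_pbar = tqdm.tqdm(
                range(len(train_loader)), colour="blue", desc="r0 Training Epoch"
            )
        for batch in train_loader:
            for key in batch.keys():
                batch[key] = batch[key].to(local_rank)
            optimizer.zero_grad()
            output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"] )
            loss = output["loss"]
            loss.backward()
            optimizer.step()
            fsdp_loss[0] += loss.item()
            # count batches (len(batch) would count the keys of the batch dict)
            fsdp_loss[1] += 1
            if rank==0:
                inner_pbar.update(1)

        # sum loss and batch count over all ranks, then report the mean batch loss
        dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
        train_accuracy = fsdp_loss[0] / fsdp_loss[1]

        if rank == 0:
            inner_pbar.close()
            print(
                    f"Train Epoch: \t{epoch}, Loss: \t{train_accuracy:.4f}"
                )
        return train_accuracy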

2.3 Define a validation function:

.. code-block:: python

    def validation(model, rank, world_size, val_loader):
        model.eval()
        local_rank = int(os.environ['LOCAL_RANK'])
        # fsdp_loss[0] accumulates the loss, fsdp_loss[1] the number of batches
        fsdp_loss = torch.zeros(2).to(local_rank)
        if rank == 0:
            inner_pbar = tqdm.tqdm(
                range(len(val_loader)), colour="green", desc="Validation Epoch"
            )
        with torch.no_grad():
            for batch in val_loader:
                for key in batch.keys():
                    batch[key] = batch[key].to(local_rank)
                output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"])
                fsdp_loss[0] += output["loss"].item()  # sum up batch loss
                # count batches (len(batch) would count the keys of the batch dict)
                fsdp_loss[1] += 1

                if rank==0:
                    inner_pbar.update(1)

        dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
        val_loss = fsdp_loss[0] / fsdp_loss[1]
        if rank == 0:
            inner_pbar.close()
            print(f"Validation Loss: {val_loss:.4f}")
        return val_loss

2.4 Define a distributed train function that wraps the model in FSDP:

.. code-block:: python

    def fsdp_main(args):

        model, tokenizer = setup_model("t5-base")

        local_rank = int(os.environ['LOCAL_RANK'])
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])

        dataset = load_dataset('wikihow', 'all', data_dir='data/')
        print("Size of train dataset: ", dataset['train'].shape)
        print("Size of Validation dataset: ", dataset['validation'].shape)

        #wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False)
        train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False) 
        val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False)

        sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
        sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size)


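        setup()  # initialize the process group before building loaders and FSDP

        train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
        test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
        cuda_kwargs = {'num_workers': 2,
                        'pin_memory': True,
                        'shuffle': False}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

        train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
        val_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)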

        t5_auto_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={
                T5Block,
            },
        )
        sharding_strategy: ShardingStrategy = ShardingStrategy.SHARD_GRAD_OP #for Zero2 and FULL_SHARD for Zero3

        #init_start_event = torch.cuda.Event(enable_timing=True)
        #init_end_event = torch.cuda.Event(enable_timing=True)


        bf16_ready = (
            torch.version.cuda
            and torch.cuda.is_bf16_supported()
            and LooseVersion(torch.version.cuda) >= "11.0"
            and dist.is_nccl_available()
            and nccl.version() >= (2, 10)
        )

        if bf16_ready:
            mp_policy = bfSixteen
        else:
            mp_policy = None # defaults to fp32

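        torch.cuda.set_device(local_rank)

        # model is on CPU before input to FSDP
        model = FSDP(model,
            auto_wrap_policy=t5_auto_wrap_policy,
            mixed_precision=mp_policy,
            #sharding_strategy=sharding_strategy,
            device_id=torch.cuda.current_device())

        optimizer = optim.AdamW(model.parameters(), lr=args.lr)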

        scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
        best_val_loss = float("inf")
        curr_val_loss = float("inf")
        file_save_name = "T5-model-"

        if rank == 0:
            time_of_run = get_date_of_run()
            dur = []
            train_acc_tracking = []
            val_acc_tracking = []
            training_start_time = time.time()

        if rank == 0 and args.track_memory:
            mem_alloc_tracker = []
            mem_reserved_tracker = []

        for epoch in range(1, args.epochs + 1):
            t0 = time.time()
            train_accuracy = train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
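            if args.run_validation:
                curr_val_loss = validation(model, rank, world_size, val_loader)
            scheduler.step()

            if rank == 0:

                print(f"--> epoch {epoch} completed...entering save and stats zone")

                dur.append(time.time() - t0)
                train_acc_tracking.append(train_accuracy.item())

                if args.run_validation:
                    val_acc_tracking.append(curr_val_loss.item())

                if args.track_memory:
                    mem_alloc_tracker.append(
                        format_metrics_to_gb(torch.cuda.memory_allocated())
                    )
                    mem_reserved_tracker.append(
                        format_metrics_to_gb(torch.cuda.memory_reserved())
                    )
                print(f"completed save and stats zone...")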

            if args.save_model and curr_val_loss < best_val_loss:

                # save
                if rank == 0:
                    print(f"--> entering save model state")

                save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
                with FSDP.state_dict_type(
                    model, StateDictType.FULL_STATE_DICT, save_policy
                ):
                    cpu_state = model.state_dict()
                #print(f"saving process: rank {rank}  done w state_dict")

                if rank == 0:
                    print(f"--> saving model ...")
                    currEpoch = (
                        "-" + str(epoch) + "-" + str(round(curr_val_loss.item(), 4)) + ".pt"
                    )
                    print(f"--> attempting to save model prefix {currEpoch}")
                    save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
                    print(f"--> saving as model name {save_name}")

                    torch.save(cpu_state, save_name)

            if curr_val_loss < best_val_loss:

                best_val_loss = curr_val_loss
                if rank==0:
                    print(f"-->>>> New Val Loss Record: {best_val_loss}")

        # wait for all ranks to finish, then tear down the process group
        dist.barrier()
        cleanup()


2.5 Parse the arguments and set the main function:

.. code-block:: python

    if __name__ == '__main__':
        # Training settings
        parser = argparse.ArgumentParser(description='PyTorch T5 FSDP Example')
        parser.add_argument('--batch-size', type=int, default=4, metavar='N',
                            help='input batch size for training (default: 4)')
        parser.add_argument('--test-batch-size', type=int, default=4, metavar='N',
                            help='input batch size for testing (default: 4)')
        parser.add_argument('--epochs', type=int, default=2, metavar='N',
                            help='number of epochs to train (default: 2)')
        parser.add_argument('--lr', type=float, default=.002, metavar='LR',
                            help='learning rate (default: .002)')
        parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                            help='Learning rate step gamma (default: 0.7)')
        parser.add_argument('--no-cuda', action='store_true', default=False,
                            help='disables CUDA training')
        parser.add_argument('--seed', type=int, default=1, metavar='S',
                            help='random seed (default: 1)')
        # The three flags below are on by default; passing a flag turns it off.
        parser.add_argument('--track_memory', action='store_false', default=True,
                            help='track the gpu memory (on by default; pass to disable)')
        parser.add_argument('--run_validation', action='store_false', default=True,
                            help='running the validation (on by default; pass to disable)')
        parser.add_argument('--save-model', action='store_false', default=True,
                            help='For Saving the current Model (on by default; pass to disable)')
        args = parser.parse_args()

        torch.manual_seed(args.seed)

        fsdp_main(args)



To run the training using torchrun:

.. code-block:: bash 

    torchrun --nnodes 1 --nproc_per_node 4

.. _transformer_wrapping_policy:

Transformer Wrapping Policy

As discussed in the previous tutorial, `auto_wrap_policy` is one of the FSDP
features that make it easy to automatically shard a given model and put the
model, optimizer and gradient shards into distinct FSDP units.

For some architectures, such as Transformer encoder-decoders, some parts of the
model, such as the embedding table, are shared between the encoder and the
decoder. In this case, we need to place the embedding table in the outer FSDP
unit so that it can be accessed from both the encoder and the decoder. In
addition, by registering the layer class for a transformer, the sharding plan
can be made much more communication efficient. PyTorch 1.12 added this support
to FSDP, so we now have a wrapping policy for transformers.

It can be created as follows, where `T5Block` represents the T5 transformer
layer class (holding MHSA and FFN).

.. code-block:: python

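    t5_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={T5Block},
    )

    model = FSDP(model, auto_wrap_policy=t5_auto_wrap_policy)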

To see the wrapped model, you can easily print the model and visually inspect
the sharding and FSDP units as well.
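
To build intuition for what the policy does, the following toy sketch walks a
simplified module tree and reports which modules would become their own FSDP
unit. It is not FSDP's implementation: `ToyModule`, `ToyBlock`, and
`plan_units` are hypothetical names, and real wrapping operates on `nn.Module`
objects.

.. code-block:: python

    # Toy illustration of class-based auto wrapping (not FSDP's actual code).

    class ToyModule:
        def __init__(self, name, children=()):
            self.name = name
            self.children = list(children)

    class ToyBlock(ToyModule):
        """Stands in for a transformer layer class such as T5Block."""

    def plan_units(root, layer_cls):
        """Return (names wrapped as their own unit, names left to the outer unit)."""
        wrapped, outer = [], []

        def visit(module):
            if isinstance(module, layer_cls):
                wrapped.append(module.name)  # the whole layer becomes one unit
                return
            outer.append(module.name)
            for child in module.children:
                visit(child)

        visit(root)
        return wrapped, outer

    shared = ToyModule("shared_embedding")
    encoder = ToyModule("encoder", [ToyBlock("enc.block.0"), ToyBlock("enc.block.1")])
    decoder = ToyModule("decoder", [ToyBlock("dec.block.0")])
    model = ToyModule("t5", [shared, encoder, decoder])

    wrapped, outer = plan_units(model, ToyBlock)
    print(wrapped)
    print(outer)

Because the shared embedding is not an instance of the registered layer class,
it stays with the outer unit in this sketch, which mirrors the placement
described above.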

Mixed Precision
FSDP supports flexible mixed precision training, allowing arbitrary reduced
precision types (such as fp16 or bfloat16). Currently, BFloat16 is only
available on Ampere GPUs, so you need to confirm native support before you use
it. On V100s, for example, BFloat16 can still be run, but because it runs
non-natively, it can result in significant slowdowns.

To check if BFloat16 is natively supported, you can use the following:

.. code-block:: python

    bf16_ready = (
        torch.version.cuda
        and torch.cuda.is_bf16_supported()
        and LooseVersion(torch.version.cuda) >= "11.0"
        and dist.is_nccl_available()
        and nccl.version() >= (2, 10)
    )
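
`LooseVersion` comes from `distutils`, which is deprecated and was removed in
Python 3.12. As a hedged alternative, the version comparison alone can be
sketched with built-in Python only. `cuda_version_at_least` is a hypothetical
helper name, and the sketch assumes the version string has the `major.minor`
form reported by `torch.version.cuda`.

.. code-block:: python

    def cuda_version_at_least(version, minimum=(11, 0)):
        """Compare a 'major.minor' CUDA version string against a minimum tuple.

        Returns False when no CUDA version is available, for example on a
        CPU-only build where the version is None.
        """
        if not version:
            return False
        parts = tuple(int(p) for p in version.split(".")[:2])
        return parts >= minimum

    print(cuda_version_at_least("11.7"))
    print(cuda_version_at_least("10.2"))
    print(cuda_version_at_least(None))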

One of the advantages of mixed precision in FSDP is granular control over the
precision of parameters, gradients, and buffers, as follows:
.. code-block:: python

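    fpSixteen = MixedPrecision(
        param_dtype=torch.float16,
        # Gradient communication precision.
        reduce_dtype=torch.float16,
        # Buffer precision.
        buffer_dtype=torch.float16)

    bfSixteen = MixedPrecision(
        param_dtype=torch.bfloat16,
        # Gradient communication precision.
        reduce_dtype=torch.bfloat16,
        # Buffer precision.
        buffer_dtype=torch.bfloat16)

    fp32_policy = MixedPrecision(
        param_dtype=torch.float32,
        # Gradient communication precision.
        reduce_dtype=torch.float32,
        # Buffer precision.
        buffer_dtype=torch.float32)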

Note that if a certain type (parameter, reduce, buffer) is not specified, it
will not be cast at all.

This flexibility gives users fine-grained control, such as performing only
gradient communication in reduced precision while keeping all parameter and
buffer computation in full precision. This is potentially useful in cases where
intra-node communication is the main bottleneck and parameters / buffers must
be in full precision to avoid accuracy issues. This can be done with the
following policy:

.. code-block:: python

    grad_bf16 = MixedPrecision(reduce_dtype=torch.bfloat16)

In 2.4 we just add the relevant mixed precision policy to the FSDP wrapper:

.. code-block:: python

    model = FSDP(model, auto_wrap_policy=t5_auto_wrap_policy,
           mixed_precision=bfSixteen)

In our experiments, we have observed up to a 4x speedup from using BFloat16 for
training, and a memory reduction of approximately 30% in some experiments,
which can be used to increase the batch size.

Initializing FSDP Model on Device
In 1.12, FSDP supports a `device_id` argument that initializes the input CPU
module on the device given by `device_id`. This is useful when the entire model
does not fit on a single GPU, but fits in a host's CPU memory. When `device_id`
is specified, FSDP will move the model to the specified device on a per-FSDP
unit basis, avoiding GPU OOM issues while initializing several times faster than
CPU-based initialization:

.. code-block:: python


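    model = FSDP(model, auto_wrap_policy=t5_auto_wrap_policy,
           mixed_precision=bfSixteen, device_id=torch.cuda.current_device())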

Sharding Strategy
By default, the FSDP sharding strategy fully shards the model parameters,
gradients, and optimizer states across all ranks (also termed Zero3 sharding).
If you are interested in the Zero2 sharding strategy, where only optimizer
states and gradients are sharded, FSDP supports it by passing
`ShardingStrategy.SHARD_GRAD_OP` instead of `ShardingStrategy.FULL_SHARD` to
the FSDP initialization, as follows:

.. code-block:: python


    model = FSDP(model,
           sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)  # ZERO2

This reduces the communication overhead in FSDP: in this case, it holds the full
parameters after the forward pass and through the backward pass.

This saves an `all_gather` during the backward pass, so there is less
communication at the cost of a higher memory footprint. Note that the full model
parameters are freed at the end of the backward pass, and `all_gather` will
happen on the next forward pass.
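
The trade-off can be made concrete with a back-of-the-envelope estimate. The
sketch below rests on simplifying assumptions that are not from this tutorial:
fp32 parameters and gradients (4 bytes each), an Adam-style optimizer with two
fp32 states per parameter, and no accounting for activations, buffers, or the
temporarily gathered FSDP unit under full sharding. The function name
`per_gpu_state_gb` and the string labels are hypothetical.

.. code-block:: python

    def per_gpu_state_gb(num_params, world_size, strategy):
        """Rough per-GPU memory (GB) for parameters, gradients, optimizer state."""
        param_bytes = 4 * num_params
        grad_bytes = 4 * num_params
        optim_bytes = 8 * num_params  # two fp32 states per parameter

        if strategy == "FULL_SHARD":       # Zero3: everything is sharded
            total = (param_bytes + grad_bytes + optim_bytes) / world_size
        elif strategy == "SHARD_GRAD_OP":  # Zero2: full params held in fwd/bwd
            total = param_bytes + (grad_bytes + optim_bytes) / world_size
        elif strategy == "NO_SHARD":       # DDP-like: everything is replicated
            total = param_bytes + grad_bytes + optim_bytes
        else:
            raise ValueError(f"unknown strategy: {strategy}")
        return total / 1e9

    for name in ("FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD"):
        print(name, per_gpu_state_gb(3_000_000_000, 8, name))

Under these assumptions the Zero2-style estimate sits between the fully sharded
and the replicated cases, which is the memory cost paid for the saved
`all_gather`.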

Backward Prefetch
The backward prefetch setting controls when the next FSDP unit's parameters are
requested. With `BACKWARD_PRE`, the next FSDP unit's params can be requested,
and arrive, before the computation of the current unit starts. This overlaps the
`all_gather` communication with gradient computation, which can increase
training speed in exchange for slightly higher memory consumption. It can be
used in the FSDP wrapper in 2.4 as follows:

.. code-block:: python


     model = FSDP(model,
            backward_prefetch = BackwardPrefetch.BACKWARD_PRE)

`backward_prefetch` has two modes, `BACKWARD_PRE` and `BACKWARD_POST`.
`BACKWARD_POST` means that the next FSDP unit's params will not be requested
until the current FSDP unit processing is complete, thus minimizing memory
overhead.  In some cases, using `BACKWARD_PRE` can increase model training speed
by 2-10%, with even higher speed improvements noted for larger models.
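
Why overlap helps can be seen with a toy timeline model. This is an idealized
sketch, not a model of FSDP's scheduler: it assumes each unit needs one
`all_gather` (taking `comm` time units) followed by its gradient computation
(`compute` time units), and it compares a fully serialized schedule with one
where the next unit's gather runs concurrently with the current unit's
computation. The function names and timings are hypothetical.

.. code-block:: python

    def serialized_time(comm, compute):
        """Each unit waits for its own gather, then computes; nothing overlaps."""
        return sum(comm) + sum(compute)

    def prefetched_time(comm, compute):
        """The gather for unit i+1 is issued when unit i starts computing."""
        total = comm[0]  # the first gather cannot be hidden
        for i, work in enumerate(compute):
            next_comm = comm[i + 1] if i + 1 < len(comm) else 0.0
            total += max(work, next_comm)  # compute and gather run concurrently
        return total

    comm = [2.0, 2.0, 2.0]
    compute = [3.0, 3.0, 3.0]
    print(serialized_time(comm, compute))
    print(prefetched_time(comm, compute))

The price of the overlap is that the parameters of two units are resident at
the same time, which corresponds to the higher memory consumption mentioned
above.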

Model Checkpoint Saving, by streaming to the Rank0 CPU
To save model checkpoints using `FULL_STATE_DICT` saving, which saves the model
in the same fashion as a local model, PyTorch 1.12 offers a few utilities to
support the saving of larger models.

First, a `FullStateDictConfig` can be specified, allowing the state_dict to be
populated on rank 0 only and offloaded to the CPU.

When using this configuration, FSDP will `all_gather` model parameters,
offloading them to the CPU one by one, only on rank 0. When the state_dict is
finally saved, it will only be populated on rank 0 and contain CPU tensors. This
avoids potential OOM for models that are larger than a single GPU's memory and
allows users to checkpoint models whose size is up to roughly the available CPU
RAM on the user's machine.

This feature can be run as follows:

.. code-block:: python

    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(
        model, StateDictType.FULL_STATE_DICT, save_policy
    ):
        cpu_state = model.state_dict()
    if rank == 0:
        save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
        torch.save(cpu_state, save_name)
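
The effect of `rank0_only=True` together with `offload_to_cpu=True` can be
pictured with a small simulation. It is a sketch under stated assumptions, not
FSDP's implementation: ranks are plain Python lists, `gather_full_state` is a
hypothetical helper, and the "CPU" is simply an ordinary dictionary.

.. code-block:: python

    def gather_full_state(sharded_units, rank):
        """Simulate building a full state dict one FSDP unit at a time.

        sharded_units maps a unit name to a list of per-rank shards.
        Only rank 0 keeps the gathered values; other ranks return an empty dict.
        """
        cpu_state = {}
        for name, shards in sharded_units.items():
            # all_gather for a single unit
            full = [value for s in shards for value in s]
            if rank == 0:
                cpu_state[name] = full  # "offload" the gathered unit to the host
            # The gathered copy is replaced on the next iteration, so at most
            # one unit is materialized at a time.
        return cpu_state

    units = {
        "block.0": [[1, 2], [3, 4]],
        "block.1": [[5, 6], [7, 8]],
    }
    print(gather_full_state(units, rank=0))
    print(gather_full_state(units, rank=1))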


In this tutorial, we have introduced many new features for FSDP available in
PyTorch 1.12 and used HF T5 as the running example. Using the proper wrapping
policy, especially for transformer models, along with mixed precision and
backward prefetch should speed up your training runs. Also, features such as
initializing the model on device and checkpoint saving via streaming to CPU
should help to avoid OOM errors when dealing with large models.

We are actively working to add new features to FSDP for the next release. If
you have feedback, feature requests, questions or are encountering issues
using FSDP, please feel free to contact us by opening an issue in the
PyTorch GitHub repository.