GitHub Repository: pytorch/tutorials
Path: blob/main/recipes_source/distributed_optim_torchscript.rst
Distributed Optimizer with TorchScript support
==============================================================

.. note:: Distributed Optimizer with TorchScript support is introduced in PyTorch 1.8
    as a beta feature. This API is subject to change.

In this recipe, you will learn:

- The high-level idea of distributed optimizer with TorchScript support and what this feature brings
- How to write a customized distributed optimizer that enables TorchScript support


Requirements
------------

- PyTorch 1.8+
- `Getting Started With Distributed RPC Framework <https://pytorch.org/tutorials/intermediate/rpc_tutorial.html>`_


What is Distributed Optimizer?
------------------------------------

`DistributedOptimizer <https://pytorch.org/docs/master/rpc.html#module-torch.distributed.optim>`_ takes a list of remote
parameters (``RRef``) and runs the optimizer locally on the workers where the parameters live. It is commonly used together
with Distributed RPC/Autograd for model parallel training. It can use any local optimizer algorithm (either
a predefined one provided in ``torch.optim`` or a custom one) to apply the gradients on each worker.


What is Distributed Optimizer with TorchScript support?
-------------------------------------------------------

Distributed Optimizer is widely used in distributed model parallel training. In some
common use cases, training needs to run in a multithreaded rather than multiprocess manner
because of performance concerns and resource utilization (or at least partially multithreaded,
e.g. a Parameter Server that hosts part of the model and parameters and uses a new thread to update the
parameters for each request). PyTorch itself does not support multithreaded training natively because
it is constrained by Python's Global Interpreter Lock (GIL), but it can leverage
`TorchScript <https://pytorch.org/docs/stable/jit.html>`_ to run the model outside the GIL
in a multithreaded way.
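
As a rough illustration of the idea, the sketch below calls a scripted function from several Python threads. The function ``scaled_sum``, the inputs, and the thread pool size are hypothetical and chosen only for this example. It shows that a TorchScript function can be invoked concurrently; it does not measure any speedup.

```python
from concurrent.futures import ThreadPoolExecutor

import torch


# Hypothetical toy function, compiled to TorchScript so that its body
# is executed by the TorchScript interpreter rather than by Python.
@torch.jit.script
def scaled_sum(x: torch.Tensor, scale: float) -> torch.Tensor:
    return (x * scale).sum()


# Four small inputs filled with 0.0, 1.0, 2.0 and 3.0 respectively.
inputs = [torch.full((4,), float(i)) for i in range(4)]

# Call the scripted function from several Python threads at once.
with ThreadPoolExecutor(max_workers=4) as pool:
    results = list(pool.map(lambda x: scaled_sum(x, 2.0), inputs))

values = [r.item() for r in results]
print(values)
```

Whether this is faster than a single thread depends on the workload, so treat the snippet as a demonstration of the calling pattern only.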

For critical model training workloads, improving the training performance is an
important topic. Researchers often want to implement different optimization strategies
on the graph representation (e.g. operator fusion) or implement custom operator kernels
in order to speed up training.

Distributed Optimizer with TorchScript support helps remove the GIL bottleneck and thus improves
PyTorch's training performance in a multithreaded environment. It also unlocks the potential
to further enhance performance with the advanced compiler technologies that TorchScript
offers (e.g. CPU/GPU fusion).


How to write a customized distributed optimizer with TorchScript support?
-------------------------------------------------------------------------

The code below shows how to write a customized distributed optimizer given an existing local
optimizer implementation, which unlocks the TorchScript benefits including GIL removal and
performance improvement opportunities.

Suppose that you already have a local optimizer that is currently used during training.
Here we use `quasi-hyperbolic momentum (QHM) <https://github.com/facebookresearch/qhoptim/blob/e81dea3f2765780cf4fbb90b87b22ba7604b8625/qhoptim/pyt/qhm.py#L12>`_
as an example to show how to enable TorchScript support. Note that the same approach applies
to any custom optimizer that inherits from ``torch.optim.Optimizer``.

First, we need to separate the computation and the state management in the optimizer implementation,
so that we can extract the computation part and make it a free function, which is
TorchScript friendly. This has two benefits:

1. The computation logic becomes easier to inspect. It lets us quickly turn the parameter
   update/computation part into TorchScript and use the TorchScript IR for further
   optimizations (operator fusion, etc.).
2. Distributed Optimizer uses a different underlying mechanism to get gradients and update
   parameters (gradients are stored separately instead of directly populating the ``param.grad``
   field during backward). Separating the computation makes it possible for the distributed
   optimizer to run the update in multithreaded mode, as it eliminates the possible race
   condition on ``param.grad``.
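
To make the first benefit concrete, here is a minimal sketch of a free update function that is compiled with ``torch.jit.script`` and compared with its eager counterpart. The function ``sgd_momentum_update`` and the helper ``run`` are hypothetical and use a simpler rule than QHM (plain SGD with momentum), purely for illustration.

```python
from typing import List

import torch
from torch import Tensor


# Hypothetical free function: it holds no optimizer state of its own,
# everything it needs is passed in explicitly.
def sgd_momentum_update(params: List[Tensor],
                        grads: List[Tensor],
                        bufs: List[Tensor],
                        lr: float,
                        momentum: float):
    for p, g, buf in zip(params, grads, bufs):
        buf.mul_(momentum).add_(g)   # buf = momentum * buf + g
        p.add_(buf, alpha=-lr)       # p = p - lr * buf


# Because it is a plain typed function, it can be compiled directly.
scripted_update = torch.jit.script(sgd_momentum_update)


def run(update_fn):
    # Fresh tensors per call so both variants start from the same values.
    p = torch.tensor([1.0, 2.0])
    g = torch.tensor([0.5, -0.5])
    buf = torch.zeros(2)
    update_fn([p], [g], [buf], 0.1, 0.9)
    return p


eager_result = run(sgd_momentum_update)
scripted_result = run(scripted_update)

# The compiled IR is available for inspection or further optimization.
graph_text = str(scripted_update.graph)
```

Both variants update the parameter in place and should agree; ``scripted_update.graph`` exposes the TorchScript IR mentioned above.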


::

    import torch
    from torch import Tensor
    from typing import List


    def qhm_update(params: List[Tensor],
                   dp_list: List[Tensor],
                   momentum_buffer_list: List[Tensor],
                   lr: float,
                   nu: float,
                   weight_decay: float,
                   weight_decay_type: str,
                   momentum: float):

        for p, d_p, momentum_buffer in zip(params, dp_list, momentum_buffer_list):
            if weight_decay != 0:
                if weight_decay_type == "grad":
                    # fold the weight decay into the gradient: d_p += weight_decay * p
                    d_p.add_(p, alpha=weight_decay)
                elif weight_decay_type == "direct":
                    # shrink the parameter directly
                    p.mul_(1.0 - lr * weight_decay)
                else:
                    raise ValueError("Invalid weight decay type provided")

            # momentum_buffer = momentum * momentum_buffer + (1 - momentum) * d_p
            momentum_buffer.mul_(momentum).add_(d_p, alpha=1.0 - momentum)

            # step along a nu-weighted mix of the momentum buffer and the raw gradient
            p.data.add_(momentum_buffer, alpha=-lr * nu)
            p.data.add_(d_p, alpha=-lr * (1.0 - nu))
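
In equation form, the update above can be read as follows. The symbols are shorthand introduced here and do not come from the code: :math:`\theta` stands for ``p``, :math:`d_t` for ``d_p`` (after any ``"grad"``-type weight decay has been added), :math:`g` for ``momentum_buffer``, :math:`\alpha` for ``lr``, :math:`\beta` for ``momentum``, and :math:`\nu` for ``nu``.

.. math::

    g_{t+1} &= \beta \, g_t + (1 - \beta) \, d_t \\
    \theta_{t+1} &= \theta_t - \alpha \left[ \nu \, g_{t+1} + (1 - \nu) \, d_t \right]

Setting :math:`\nu = 0` reduces this to plain SGD, while :math:`\nu = 1` gives SGD with (dampened) momentum. With ``"direct"`` weight decay, :math:`\theta_t` is first scaled by :math:`1 - \alpha \lambda`, where :math:`\lambda` stands for ``weight_decay``.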



Next, we define a distributed functional optimizer with TorchScript compatibility that manages
the optimizer state and calls into the TorchScript-compatible update function we defined above.
Note that a few conventions differ from normal custom optimizers:

1. We don't inherit from ``torch.optim.Optimizer``, as TorchScript does not support polymorphism.
2. ``step`` takes a list of gradients instead of a loss closure.

::

    import torch
    from torch import Tensor
    from typing import List, Optional, Dict

    # define this as a TorchScript class
    @torch.jit.script
    class FunctionalQHM(object):
        def __init__(self,
                    params: List[Tensor],
                    lr: float,
                    momentum: float,
                    nu: float,
                    weight_decay: float = 0.0,
                    weight_decay_type: str = "grad"):
            if lr < 0.0:
                raise ValueError("Invalid learning rate: {}".format(lr))
            if momentum < 0.0:
                raise ValueError("Invalid momentum value: {}".format(momentum))
            if weight_decay < 0.0:
                raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
            if weight_decay_type not in ("grad", "direct"):
                raise ValueError("Invalid weight_decay_type value: {}".format(weight_decay_type))

            self.defaults = {
                "lr": lr,
                "momentum": momentum,
                "nu": nu,
                "weight_decay": weight_decay,
            }
            self.weight_decay_type = weight_decay_type

            # NOTE: we only have one param_group here and don't allow user to add additional
            # param group as it's not a common use case.
            self.param_group = {"params": params}

            self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})

        def step(self, gradients: List[Optional[Tensor]]):
            params = self.param_group['params']
            params_with_grad = []
            grads = []
            momentum_buffer_list: List[Tensor] = []

            if len(params) != len(gradients):
                raise ValueError(
                    "the number of gradients passed in does not equal the number of parameters! "
                    + f"Params length: {len(params)}. "
                    + f"Gradients length: {len(gradients)}"
                )

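            for param, gradient in zip(self.param_group['params'], gradients):
                if gradient is not None:
                    params_with_grad.append(param)
                    grads.append(gradient)
                    # lazily create the momentum buffer the first time a parameter is seen,
                    # so that it persists across steps instead of being reset on every call
                    if param not in self.state:
                        self.state[param] = {}
                        state = self.state[param]
                        state['momentum_buffer'] = torch.zeros_like(param, memory_format=torch.preserve_format)

                    state = self.state[param]
                    momentum_buffer_list.append(state['momentum_buffer'])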

            # calls into the update function we just defined
            with torch.no_grad():
                qhm_update(params_with_grad,
                        grads,
                        momentum_buffer_list,
                        self.defaults['lr'],
                        self.defaults['nu'],
                        self.defaults['weight_decay'],
                        self.weight_decay_type,
                        self.defaults['momentum'])
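
The two conventions above (no ``torch.optim.Optimizer`` base class, and ``step`` taking explicit gradients) can be tried out in isolation. The sketch below is a deliberately small stand-in, not the QHM optimizer: the class ``FunctionalMomentumSGD`` and its plain SGD-with-momentum rule are hypothetical and exist only to show the calling pattern and that per-parameter state persists between steps.

```python
from typing import Dict, List, Optional

import torch
from torch import Tensor


# Hypothetical minimal functional optimizer, written as a TorchScript class.
@torch.jit.script
class FunctionalMomentumSGD(object):
    def __init__(self, params: List[Tensor], lr: float, momentum: float):
        self.lr = lr
        self.momentum = momentum
        self.params = params
        # per-parameter state, keyed by the parameter tensor itself
        self.state = torch.jit.annotate(Dict[Tensor, Dict[str, Tensor]], {})

    def step(self, gradients: List[Optional[Tensor]]):
        params = self.params
        if len(params) != len(gradients):
            raise ValueError("expected one gradient entry per parameter")

        # no_grad guards the in-place updates in case a parameter requires grad
        with torch.no_grad():
            for param, grad in zip(params, gradients):
                if grad is not None:
                    # create the buffer once, then reuse it on later steps
                    if param not in self.state:
                        self.state[param] = {"momentum_buffer": torch.zeros_like(param)}
                    buf = self.state[param]["momentum_buffer"]
                    buf.mul_(self.momentum).add_(grad)   # buf = momentum * buf + grad
                    param.add_(buf, alpha=-self.lr)      # param = param - lr * buf


w = torch.tensor([1.0])
frozen = torch.tensor([5.0])
opt = FunctionalMomentumSGD([w, frozen], 0.1, 0.9)

# Gradients are passed explicitly; None means "skip this parameter".
opt.step([torch.tensor([1.0]), None])
opt.step([torch.tensor([1.0]), None])
```

After the first step the buffer for ``w`` is ``1.0`` and ``w`` becomes ``0.9``. After the second step the buffer is ``1.9`` and ``w`` becomes ``0.71``, which would not happen if the buffer were recreated on every call. ``frozen`` is left untouched because its gradient entry is ``None``.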



Finally, we register our newly defined distributed functional optimizer in the ``functional_optim_map``.
This lets ``DistributedOptimizer`` pick up our custom implementation instead of the
predefined default ones.

::

    from torch.distributed.optim import DistributedOptimizer
    # QHM is the local optimizer from the qhoptim package linked above
    from qhoptim.pyt import QHM

    DistributedOptimizer.functional_optim_map[QHM] = FunctionalQHM

Now you can use the ``QHM`` optimizer as usual in distributed training by passing it to
`DistributedOptimizer <https://pytorch.org/docs/master/rpc.html#module-torch.distributed.optim>`_:


::

    ...
    remote_params_list = [...]
    dist_optim = DistributedOptimizer(
        QHM, remote_params_list, *args, **kwargs
    )

``DistributedOptimizer`` automatically transforms the ``QHM`` optimizer into ``FunctionalQHM`` under the hood
and enables TorchScript support. This unlocks the performance boost from multithreaded training
and leaves room for further improvements (e.g. TorchScript fusion).

Note that the majority of PyTorch built-in optimizers already use this methodology to speed up distributed
training. If you see a warning that an optimizer has not been converted yet, you can write your own conversion
by following this recipe.