# -*- coding: utf-8 -*-
"""
TorchRL objectives: Coding a DDPG loss
======================================
**Author**: `Vincent Moens <https://github.com/vmoens>`_

"""

##############################################################################
# Overview
# --------
#
# TorchRL separates the training of RL algorithms in various pieces that will be
# assembled in your training script: the environment, the data collection and
# storage, the model and finally the loss function.
#
# TorchRL losses (or "objectives") are stateful objects that contain the
# trainable parameters (policy and value models).
# This tutorial will guide you through the steps to code a loss from the ground up
# using TorchRL.
#
# To this aim, we will be focusing on DDPG, which is a relatively straightforward
# algorithm to code.
# `Deep Deterministic Policy Gradient <https://arxiv.org/abs/1509.02971>`_ (DDPG)
# is a simple continuous control algorithm. It consists of learning a
# parametric value function for an action-observation pair, and
# then learning a policy that outputs actions that maximize this value
# function given a certain observation.
#
# What you will learn:
#
# - how to write a loss module and customize its value estimator;
# - how to build an environment in TorchRL, including transforms
#   (for example, data normalization) and parallel execution;
# - how to design a policy and value network;
# - how to collect data from your environment efficiently and store it
#   in a replay buffer;
# - how to store trajectories (and not transitions) in your replay buffer;
# - how to evaluate your model.
#
# Prerequisites
# ~~~~~~~~~~~~~
#
# This tutorial assumes that you have completed the
# `PPO tutorial <reinforcement_ppo.html>`_ which gives
# an overview of the TorchRL components and dependencies, such as
# :class:`tensordict.TensorDict` and :class:`tensordict.nn.TensorDictModules`,
# although it should be
# sufficiently transparent to be understood without a deep understanding of
# these classes.
#
# .. note::
#   We do not aim at giving a SOTA implementation of the algorithm, but rather
#   to provide a high-level illustration of TorchRL's loss implementations
#   and the library features that are to be used in the context of
#   this algorithm.
#
# Imports and setup
# -----------------
#
# .. code-block:: bash
#
#    %%bash
#    pip3 install torchrl mujoco glfw

# sphinx_gallery_start_ignore
import warnings

warnings.filterwarnings("ignore")
from torch import multiprocessing

# TorchRL prefers the spawn start method, which restricts the creation of
# ``~torchrl.envs.ParallelEnv`` to the ``__main__`` guard. For ease of reading,
# the code below switches to fork, which is also the default start method in
# Google's Colaboratory.
try:
    multiprocessing.set_start_method("fork")
except RuntimeError:
    pass

# sphinx_gallery_end_ignore


import torch
import tqdm


###############################################################################
# We will execute the policy on CUDA if available
is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)
collector_device = torch.device("cpu")  # Change the device to ``cuda`` to use CUDA

###############################################################################
# TorchRL :class:`~torchrl.objectives.LossModule`
# -----------------------------------------------
#
# TorchRL provides a series of losses to use in your training scripts.
# The aim is to have losses that are easily reusable/swappable and that have
# a simple signature.
#
# The main characteristics of TorchRL losses are:
#
# - They are stateful objects: they contain a copy of the trainable parameters
#   such that ``loss_module.parameters()`` gives whatever is needed to train the
#   algorithm.
# - They follow the ``TensorDict`` convention: the :meth:`torch.nn.Module.forward`
#   method will receive a TensorDict as input that contains all the necessary
#   information to return a loss value.
#
#       >>> data = replay_buffer.sample()
#       >>> loss_dict = loss_module(data)
#
# - They output a :class:`tensordict.TensorDict` instance with the loss values
#   written under a ``"loss_<smth>"`` key, where ``smth`` is a string describing the
#   loss. Additional keys in the ``TensorDict`` may be useful metrics to log during
#   training time.
#
# .. note::
#   The reason we return independent losses is to let the user use a different
#   optimizer for different sets of parameters for instance. Summing the losses
#   can be simply done via
#
#       >>> loss_val = sum(loss for key, loss in loss_dict.items() if key.startswith("loss_"))
#
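#   As a hedged illustration of why the terms are kept separate (``actor_params``
#   and ``value_params`` below are hypothetical parameter iterables, not objects
#   defined in this tutorial), each term can be driven by its own optimizer, which
#   is essentially what the training loop at the end of this tutorial does:
#
#   .. code-block:: python
#
#      optimizer_actor = torch.optim.Adam(actor_params, lr=1e-4)
#      optimizer_value = torch.optim.Adam(value_params, lr=1e-3)
#
#      loss_dict = loss_module(replay_buffer.sample())
#      loss_dict["loss_actor"].backward()
#      optimizer_actor.step()
#      optimizer_actor.zero_grad()
#      loss_dict["loss_value"].backward()
#      optimizer_value.step()
#      optimizer_value.zero_grad()
#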
# The ``__init__`` method
# ~~~~~~~~~~~~~~~~~~~~~~~
#
# The parent class of all losses is :class:`~torchrl.objectives.LossModule`.
# Like many other components of the library, its :meth:`~torchrl.objectives.LossModule.forward` method expects
# as input a :class:`tensordict.TensorDict` instance sampled from an experience
# replay buffer, or any similar data structure. Using this format makes it
# possible to re-use the module across
# modalities, or in complex settings where the model needs to read multiple
# entries for instance. In other words, it allows us to code a loss module that
# is oblivious to the data type that is being given to it and that focuses on
# running the elementary steps of the loss function and only those.
#
# To keep the tutorial as didactic as we can, we'll be displaying each method
# of the class independently and we'll be populating the class at a later
# stage.
#
# Let us start with the :meth:`~torchrl.objectives.LossModule.__init__`
# method. DDPG aims at solving a control task with a simple strategy:
# training a policy to output actions that maximize the value predicted by
# a value network. Hence, our loss module needs to receive two networks in its
# constructor: an actor and a value network. We expect both of these to be
# TensorDict-compatible objects, such as
# :class:`tensordict.nn.TensorDictModule`.
# Our loss function will need to compute a target value and fit the value
# network to this, and generate an action and fit the policy such that its
# value estimate is maximized.
#
# The crucial step of the :meth:`LossModule.__init__` method is the call to
# :meth:`~torchrl.objectives.LossModule.convert_to_functional`. This method will extract
# the parameters from the module and convert it to a functional module.
# Strictly speaking, this is not necessary and one may perfectly code all
# the losses without it. However, we encourage its usage for the following
# reason.
#
# The reason TorchRL does this is that RL algorithms often execute the same
# model with different sets of parameters, called "trainable" and "target"
# parameters.
# The "trainable" parameters are those that the optimizer needs to fit. The
# "target" parameters are usually a copy of the former, with some time lag
# (absolute or diluted through a moving average).
# These target parameters are used to compute the value associated with the
# next observation. One of the advantages of using a set of target parameters
# for the value model that do not match exactly the current configuration is
# that they provide a pessimistic bound on the value function being computed.
# Pay attention to the ``create_target_params`` keyword argument below: this
# argument tells the :meth:`~torchrl.objectives.LossModule.convert_to_functional`
# method to create a set of target parameters in the loss module to be used
# for target value computation. If this is set to ``False`` (see the actor network
# for instance) the ``target_actor_network_params`` attribute will still be
# accessible, but it will just return a **detached** version of the
# actor parameters.
#
# Later, we will see how the target parameters should be updated in TorchRL.
#

from tensordict.nn import TensorDictModule, TensorDictSequential


def _init(
    self,
    actor_network: TensorDictModule,
    value_network: TensorDictModule,
) -> None:
    super(type(self), self).__init__()

    self.convert_to_functional(
        actor_network,
        "actor_network",
        create_target_params=True,
    )
    self.convert_to_functional(
        value_network,
        "value_network",
        create_target_params=True,
        compare_against=list(actor_network.parameters()),
    )

    self.actor_in_keys = actor_network.in_keys

    # Since the value we'll be using is based on the actor and value network,
    # we put them together in a single actor-critic container.
    actor_critic = ActorCriticWrapper(actor_network, value_network)
    self.actor_critic = actor_critic
    self.loss_function = "l2"


###############################################################################
# The value estimator loss method
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In many RL algorithms, the value network (or Q-value network) is trained based
# on an empirical value estimate. This can be bootstrapped (TD(0), low
# variance, high bias), meaning
# that the target value is obtained using the next reward and nothing else, or
# a Monte-Carlo estimate can be obtained (TD(1)) in which case the whole
# sequence of upcoming rewards will be used (high variance, low bias). An
# intermediate estimator (TD(:math:`\lambda`)) can also be used to trade off
# bias and variance.
# TorchRL makes it easy to use one or the other estimator via the
# :class:`~torchrl.objectives.utils.ValueEstimators` Enum class, which contains
# pointers to all the value estimators implemented. Let us define the default
# value function here. We will take the simplest version (TD(0)), and show later
# on how this can be changed.
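#
# For reference (this is the standard TD(0) construction, not additional TorchRL
# machinery), the bootstrapped target that the value loss will regress onto in
# DDPG reads
#
# .. math::
#
#    y_t = r_{t+1} + \gamma \, (1 - d_{t+1}) \, Q_{\bar{\theta}}\big(s_{t+1}, \pi_{\bar{\phi}}(s_{t+1})\big),
#
# where :math:`\bar{\theta}` and :math:`\bar{\phi}` denote the target value and
# target policy parameters, and :math:`d_{t+1}` flags the end of a trajectory.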

from torchrl.objectives.utils import ValueEstimators

default_value_estimator = ValueEstimators.TD0

###############################################################################
# We also need to give some instructions to DDPG on how to build the value
# estimator, depending on the user query. Depending on the estimator provided,
# we will build the corresponding module to be used at train time:

from torchrl.objectives.utils import default_value_kwargs
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator


def make_value_estimator(self, value_type: ValueEstimators, **hyperparams):
    hp = dict(default_value_kwargs(value_type))
    if hasattr(self, "gamma"):
        hp["gamma"] = self.gamma
    hp.update(hyperparams)
    value_key = "state_action_value"
    if value_type == ValueEstimators.TD1:
        self._value_estimator = TD1Estimator(value_network=self.actor_critic, **hp)
    elif value_type == ValueEstimators.TD0:
        self._value_estimator = TD0Estimator(value_network=self.actor_critic, **hp)
    elif value_type == ValueEstimators.GAE:
        raise NotImplementedError(
            f"Value type {value_type} is not implemented for loss {type(self)}."
        )
    elif value_type == ValueEstimators.TDLambda:
        self._value_estimator = TDLambdaEstimator(value_network=self.actor_critic, **hp)
    else:
        raise NotImplementedError(f"Unknown value type {value_type}")
    self._value_estimator.set_keys(value=value_key)


###############################################################################
# The ``make_value_estimator`` method can but does not need to be called: if
# not, the :class:`~torchrl.objectives.LossModule` will query this method with
# its default estimator.
#
# The actor loss method
# ~~~~~~~~~~~~~~~~~~~~~
#
# The central piece of an RL algorithm is the training loss for the actor.
# In the case of DDPG, this function is quite simple: we just need to compute
# the value associated with an action computed using the policy and optimize
# the actor weights to maximize this value.
#
# When computing this value, we must make sure to take the value parameters out
# of the graph, otherwise the actor and value loss will be mixed up.
# For this, the :func:`~torchrl.objectives.utils.hold_out_params` function
# can be used.


def _loss_actor(
    self,
    tensordict,
) -> torch.Tensor:
    td_copy = tensordict.select(*self.actor_in_keys)
    # Get an action from the actor network: since we made it functional, we need to pass the params
    with self.actor_network_params.to_module(self.actor_network):
        td_copy = self.actor_network(td_copy)
    # get the value associated with that action
    with self.value_network_params.detach().to_module(self.value_network):
        td_copy = self.value_network(td_copy)
    return -td_copy.get("state_action_value")
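

###############################################################################
# .. note::
#   Written out (this is standard DDPG, nothing TorchRL-specific), the actor
#   objective is
#
#   .. math::
#
#      \max_{\phi} \; \mathbb{E}\big[ Q_{\theta}\big(s, \pi_{\phi}(s)\big) \big],
#
#   which is why the method above returns the *negative* state-action value:
#   minimizing that quantity with a gradient-based optimizer maximizes the
#   predicted value.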


###############################################################################
# The value loss method
# ~~~~~~~~~~~~~~~~~~~~~
#
# We now need to optimize our value network parameters.
# To do this, we will rely on the value estimator of our class:
#

from torchrl.objectives.utils import distance_loss


def _loss_value(
    self,
    tensordict,
):
    td_copy = tensordict.clone()

    # V(s, a)
    with self.value_network_params.to_module(self.value_network):
        self.value_network(td_copy)
    pred_val = td_copy.get("state_action_value").squeeze(-1)

    # we manually reconstruct the parameters of the actor-critic, where the first
    # set of parameters belongs to the actor and the second to the value function.
    target_params = TensorDict(
        {
            "module": {
                "0": self.target_actor_network_params,
                "1": self.target_value_network_params,
            }
        },
        batch_size=self.target_actor_network_params.batch_size,
        device=self.target_actor_network_params.device,
    )
    with target_params.to_module(self.actor_critic):
        target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)

    # Computes the value loss: L2, L1 or smooth L1 depending on `self.loss_function`
    loss_value = distance_loss(pred_val, target_value, loss_function=self.loss_function)
    td_error = (pred_val - target_value).pow(2)

    return loss_value, td_error, pred_val, target_value


###############################################################################
# Putting things together in a forward call
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The only missing piece is the forward method, which will glue together the
# value and actor loss, collect the cost values and write them in a ``TensorDict``
# delivered to the user.

from tensordict import TensorDict, TensorDictBase


def _forward(self, input_tensordict: TensorDictBase) -> TensorDict:
    loss_value, td_error, pred_val, target_value = self.loss_value(
        input_tensordict,
    )
    td_error = td_error.detach()
    td_error = td_error.unsqueeze(input_tensordict.ndimension())
    if input_tensordict.device is not None:
        td_error = td_error.to(input_tensordict.device)
    input_tensordict.set(
        "td_error",
        td_error,
        inplace=True,
    )
    loss_actor = self.loss_actor(input_tensordict)
    return TensorDict(
        source={
            "loss_actor": loss_actor.mean(),
            "loss_value": loss_value.mean(),
            "pred_value": pred_val.mean().detach(),
            "target_value": target_value.mean().detach(),
            "pred_value_max": pred_val.max().detach(),
            "target_value_max": target_value.max().detach(),
        },
        batch_size=[],
    )


from torchrl.objectives import LossModule


class DDPGLoss(LossModule):
    default_value_estimator = default_value_estimator
    make_value_estimator = make_value_estimator

    __init__ = _init
    forward = _forward
    loss_value = _loss_value
    loss_actor = _loss_actor
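

###############################################################################
# As a brief, hedged usage sketch (``replay_buffer`` is a stand-in here; the
# actual buffer is only built later in this tutorial), the assembled loss is
# called like any other ``TensorDict`` module:
#
# .. code-block:: python
#
#    loss_module = DDPGLoss(actor, qnet)
#    loss_td = loss_module(replay_buffer.sample())
#    loss_td["loss_actor"], loss_td["loss_value"]  # independent loss terms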


###############################################################################
# Now that we have our loss, we can use it to train a policy to solve a
# control task.
#
# Environment
# -----------
#
# In most algorithms, the first thing that needs to be taken care of is the
# construction of the environment as it conditions the remainder of the
# training script.
#
# For this example, we will be using the ``"cheetah"`` task. The goal is to make
# a half-cheetah run as fast as possible.
#
# In TorchRL, one can create such a task by relying on ``dm_control`` or ``gym``:
#
# .. code-block:: python
#
#    env = GymEnv("HalfCheetah-v4")
#
# or
#
# .. code-block:: python
#
#    env = DMControlEnv("cheetah", "run")
#
# By default, these environments disable rendering. Training from states is
# usually easier than training from images. To keep things simple, we focus
# on learning from states only. To pass the pixels to the ``tensordicts`` that
# are collected by :func:`env.step()`, simply pass the ``from_pixels=True``
# argument to the constructor:
#
# .. code-block:: python
#
#    env = GymEnv("HalfCheetah-v4", from_pixels=True, pixels_only=True)
#
# We write a :func:`make_env` helper function that will create an environment
# with either one of the two backends considered above (``dm-control`` or ``gym``).
#

from torchrl.envs.libs.dm_control import DMControlEnv
from torchrl.envs.libs.gym import GymEnv

env_library = None
env_name = None


def make_env(from_pixels=False):
    """Create a base ``env``."""
    global env_library
    global env_name

    if backend == "dm_control":
        env_name = "cheetah"
        env_task = "run"
        env_args = (env_name, env_task)
        env_library = DMControlEnv
    elif backend == "gym":
        env_name = "HalfCheetah-v4"
        env_args = (env_name,)
        env_library = GymEnv
    else:
        raise NotImplementedError

    env_kwargs = {
        "device": device,
        "from_pixels": from_pixels,
        "pixels_only": from_pixels,
        "frame_skip": 2,
    }
    env = env_library(*env_args, **env_kwargs)
    return env


###############################################################################
# Transforms
# ~~~~~~~~~~
#
# Now that we have a base environment, we may want to modify its representation
# to make it more policy-friendly. In TorchRL, transforms are appended to the
# base environment in a specialized :class:`torchrl.envs.TransformedEnv` class.
#
# - It is common in DDPG to rescale the reward using some heuristic value. We
#   will multiply the reward by 5 in this example.
#
# - If we are using :mod:`dm_control`, it is also important to build an interface
#   between the simulator which works with double precision numbers, and our
#   script which presumably uses single precision ones. This transformation goes
#   both ways: when calling :func:`env.step`, our actions will need to be
#   represented in double precision, and the output will need to be transformed
#   to single precision.
#   The :class:`~torchrl.envs.DoubleToFloat` transform does exactly this: the
#   ``in_keys`` list refers to the keys that will need to be transformed from
#   double to float, while the ``in_keys_inv`` refers to those that need to
#   be transformed to double before being passed to the environment.
#
# - We concatenate the state keys together using the :class:`~torchrl.envs.CatTensors`
#   transform.
#
# - Finally, we also leave the possibility of normalizing the states: we will
#   take care of computing the normalizing constants later on.
#
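# As a hedged illustration of the ``in_keys``/``in_keys_inv`` mechanics described
# above (the helper below simply relies on the transform's default key selection),
# one could be explicit about which entries change precision:
#
# .. code-block:: python
#
#    env.append_transform(
#        DoubleToFloat(in_keys=["observation_vector"], in_keys_inv=["action"])
#    )
#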

from torchrl.envs import (
    CatTensors,
    DoubleToFloat,
    EnvCreator,
    InitTracker,
    ObservationNorm,
    ParallelEnv,
    RewardScaling,
    StepCounter,
    TransformedEnv,
)


def make_transformed_env(
    env,
):
    """Apply transforms to the ``env`` (such as reward scaling and state normalization)."""

    env = TransformedEnv(env)

    # we append transforms one by one, although we might as well create the
    # transformed environment using the `env = TransformedEnv(base_env, transforms)`
    # syntax.
    env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling))

    # We concatenate all states into a single "observation_vector"
    # even if there is a single tensor, it'll be renamed to "observation_vector".
    # This facilitates the downstream operations as we know the name of the
    # output tensor.
    # In some environments (not half-cheetah), there may be more than one
    # observation vector: in this case this code snippet will concatenate them
    # all.
    selected_keys = list(env.observation_spec.keys())
    out_key = "observation_vector"
    env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key))

    # we normalize the states, but for now let's just instantiate a stateless
    # version of the transform
    env.append_transform(ObservationNorm(in_keys=[out_key], standard_normal=True))

    env.append_transform(DoubleToFloat())

    env.append_transform(StepCounter(max_frames_per_traj))

    # We need a marker for the start of trajectories for our Ornstein-Uhlenbeck (OU)
    # exploration:
    env.append_transform(InitTracker())

    return env


###############################################################################
# Parallel execution
# ~~~~~~~~~~~~~~~~~~
#
# The following helper function allows us to run environments in parallel.
# Running environments in parallel can significantly speed up the collection
# throughput. When using transformed environments, we need to choose whether we
# want to execute the transform individually for each environment, or
# centralize the data and transform it in batch. Both approaches are easy to
# code:
#
# .. code-block:: python
#
#    env = ParallelEnv(
#        lambda: TransformedEnv(GymEnv("HalfCheetah-v4"), transforms),
#        num_workers=4
#    )
#    env = TransformedEnv(
#        ParallelEnv(lambda: GymEnv("HalfCheetah-v4"), num_workers=4),
#        transforms
#    )
#
# To leverage the vectorization capabilities of PyTorch, we adopt the latter
# approach and transform the batched output of the parallel environment:
#


def parallel_env_constructor(
    env_per_collector,
    transform_state_dict,
):
    if env_per_collector == 1:

        def make_t_env():
            env = make_transformed_env(make_env())
            env.transform[2].init_stats(3)
            env.transform[2].loc.copy_(transform_state_dict["loc"])
            env.transform[2].scale.copy_(transform_state_dict["scale"])
            return env

        env_creator = EnvCreator(make_t_env)
        return env_creator

    parallel_env = ParallelEnv(
        num_workers=env_per_collector,
        create_env_fn=EnvCreator(lambda: make_env()),
        create_env_kwargs=None,
        pin_memory=False,
    )
    env = make_transformed_env(parallel_env)
    # we call `init_stats` for a limited number of steps, just to instantiate
    # the lazy buffers.
    env.transform[2].init_stats(3, cat_dim=1, reduce_dim=[0, 1])
    env.transform[2].load_state_dict(transform_state_dict)
    return env


# The backend can be ``gym`` or ``dm_control``
backend = "gym"

###############################################################################
# .. note::
#
#   ``frame_skip`` batches multiple steps together with a single action.
#   If > 1, the other frame counts (for example, frames_per_batch, total_frames)
#   need to be adjusted to have a consistent total number of frames collected
#   across experiments. This is important as raising the frame-skip but keeping the
#   total number of frames unchanged may seem like cheating: all things compared,
#   a dataset of 10M elements collected with a frame-skip of 2 and another with
#   a frame-skip of 1 actually have a ratio of interactions with the environment
#   of 2:1! In a nutshell, one should be cautious about the frame-count of a
#   training script when dealing with frame skipping as this may lead to
#   biased comparisons between training strategies.
#
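# As a purely hypothetical bookkeeping sketch (the numbers below are not the ones
# used in this script), the frame counts would be scaled together with the
# frame-skip to keep the number of environment interactions comparable:
#
# .. code-block:: python
#
#    frame_skip = 2
#    total_frames = 1_000_000 // frame_skip
#    frames_per_batch = 1_000 // frame_skip
#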
# Scaling the reward helps us control the signal magnitude for more
# efficient learning.
reward_scaling = 5.0

###############################################################################
# We also define when a trajectory will be truncated. A thousand steps (500 if
# frame-skip = 2) is a good number to use for the cheetah task:

max_frames_per_traj = 500

###############################################################################
# Normalization of the observations
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# To compute the normalizing statistics, we run an arbitrary number of random
# steps in the environment and compute the mean and standard deviation of the
# collected observations. The :func:`ObservationNorm.init_stats()` method can
# be used for this purpose. To get the summary statistics, we create a dummy
# environment, run it for a given number of steps, and compute the summary
# statistics of the collected data.
#


def get_env_stats():
    """Gets the stats of an environment."""
    proof_env = make_transformed_env(make_env())
    t = proof_env.transform[2]
    t.init_stats(init_env_steps)
    transform_state_dict = t.state_dict()
    proof_env.close()
    return transform_state_dict


###############################################################################
# Normalization stats
# ~~~~~~~~~~~~~~~~~~~
# Number of random steps used for stats computation with ``ObservationNorm``

init_env_steps = 5000

transform_state_dict = get_env_stats()

###############################################################################
# Number of environments in each data collector
env_per_collector = 4

###############################################################################
# We pass the stats computed earlier to normalize the output of our
# environment:

parallel_env = parallel_env_constructor(
    env_per_collector=env_per_collector,
    transform_state_dict=transform_state_dict,
)


from torchrl.data import CompositeSpec

###############################################################################
# Building the model
# ------------------
#
# We now turn to the setup of the model. As we have seen, DDPG requires a
# value network, trained to estimate the value of a state-action pair, and a
# parametric actor that learns how to select actions that maximize this value.
#
# Recall that building a TorchRL module requires two steps:
#
# - writing the :class:`torch.nn.Module` that will be used as network,
# - wrapping the network in a :class:`tensordict.nn.TensorDictModule` where the
#   data flow is handled by specifying the input and output keys.
#
# In more complex scenarios, :class:`tensordict.nn.TensorDictSequential` can
# also be used.
#
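# As a hedged, minimal illustration of these two steps (hypothetical network and
# output dimension, unrelated to the DDPG networks actually built below):
#
# .. code-block:: python
#
#    net = torch.nn.LazyLinear(out_features=6)
#    module = TensorDictModule(
#        net, in_keys=["observation_vector"], out_keys=["param"]
#    )
#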
#
# The Q-Value network is wrapped in a :class:`~torchrl.modules.ValueOperator`
# that automatically sets the ``out_keys`` to ``"state_action_value"`` for q-value
# networks and ``"state_value"`` for other value networks.
#
# TorchRL provides a built-in version of the DDPG networks as presented in the
# original paper. These can be found under :class:`~torchrl.modules.DdpgMlpActor`
# and :class:`~torchrl.modules.DdpgMlpQNet`.
#
# Since we use lazy modules, it is necessary to materialize the lazy modules
# before being able to move the policy from device to device and perform other
# operations. Hence, it is good practice to run the modules with a small
# sample of data. For this purpose, we generate fake data from the
# environment specs.
#

from torchrl.modules import (
    ActorCriticWrapper,
    DdpgMlpActor,
    DdpgMlpQNet,
    OrnsteinUhlenbeckProcessModule,
    ProbabilisticActor,
    TanhDelta,
    ValueOperator,
)


def make_ddpg_actor(
    transform_state_dict,
    device="cpu",
):
    proof_environment = make_transformed_env(make_env())
    proof_environment.transform[2].init_stats(3)
    proof_environment.transform[2].load_state_dict(transform_state_dict)

    out_features = proof_environment.action_spec.shape[-1]

    actor_net = DdpgMlpActor(
        action_dim=out_features,
    )

    in_keys = ["observation_vector"]
    out_keys = ["param"]

    actor = TensorDictModule(
        actor_net,
        in_keys=in_keys,
        out_keys=out_keys,
    )

    actor = ProbabilisticActor(
        actor,
        distribution_class=TanhDelta,
        in_keys=["param"],
        spec=CompositeSpec(action=proof_environment.action_spec),
    ).to(device)

    q_net = DdpgMlpQNet()

    in_keys = in_keys + ["action"]
    qnet = ValueOperator(
        in_keys=in_keys,
        module=q_net,
    ).to(device)

    # initialize lazy modules
    qnet(actor(proof_environment.reset().to(device)))
    return actor, qnet


actor, qnet = make_ddpg_actor(
    transform_state_dict=transform_state_dict,
    device=device,
)

###############################################################################
# Exploration
# ~~~~~~~~~~~
#
# The policy is passed into a :class:`~torchrl.modules.OrnsteinUhlenbeckProcessModule`
# exploration module, as suggested in the original paper.
# Let's define the number of frames before OU noise reaches its minimum value
annealing_frames = 1_000_000

actor_model_explore = TensorDictSequential(
    actor,
    OrnsteinUhlenbeckProcessModule(
        spec=actor.spec.clone(),
        annealing_num_steps=annealing_frames,
    ).to(device),
)
if device == torch.device("cpu"):
    actor_model_explore.share_memory()


###############################################################################
# Data collector
# --------------
#
# TorchRL provides specialized classes to help you collect data by executing
# the policy in the environment. These "data collectors" iteratively compute
# the action to be executed at a given time, then execute a step in the
# environment and reset it when required.
# Data collectors are designed to help developers have a tight control
# on the number of frames per batch of data, on the (a)sync nature of this
# collection and on the resources allocated to the data collection (for example
# GPU, number of workers, and so on).
#
# Here we will use
# :class:`~torchrl.collectors.SyncDataCollector`, a simple, single-process
# data collector. TorchRL offers other collectors, such as
# :class:`~torchrl.collectors.MultiaSyncDataCollector`, which executes the
# rollouts in an asynchronous manner (for example, data will be collected while
# the policy is being optimized, thereby decoupling the training and
# data collection).
#
# The parameters to specify are:
#
# - an environment factory or an environment,
# - the policy,
# - the total number of frames before the collector is considered empty,
# - the maximum number of frames per trajectory (useful for non-terminating
#   environments, like ``dm_control`` ones).
#
# .. note::
#
#   The ``max_frames_per_traj`` passed to the collector will have the effect
#   of registering a new :class:`~torchrl.envs.StepCounter` transform
#   with the environment used for inference. We can achieve the same result
#   manually, as we do in this script.
#
# One should also pass:
#
# - the number of frames in each batch collected,
# - the number of random steps executed independently from the policy,
# - the devices used for policy execution,
# - the devices used to store data before the data is passed to the main
#   process.
#
# The total frames we will use during training should be around 1M.
total_frames = 10_000  # 1_000_000

###############################################################################
# The number of frames returned by the collector at each iteration of the outer
# loop is equal to the length of each sub-trajectory times the number of
# environments run in parallel in each collector.
#
# In other words, we expect batches from the collector to have a shape
# ``[env_per_collector, traj_len]`` where
# ``traj_len=frames_per_batch/env_per_collector``:
#
traj_len = 200
frames_per_batch = env_per_collector * traj_len
init_random_frames = 5000
num_collectors = 2

from torchrl.collectors import SyncDataCollector
from torchrl.envs import ExplorationType

collector = SyncDataCollector(
    parallel_env,
    policy=actor_model_explore,
    total_frames=total_frames,
    frames_per_batch=frames_per_batch,
    init_random_frames=init_random_frames,
    reset_at_each_iter=False,
    split_trajs=False,
    device=collector_device,
    exploration_type=ExplorationType.RANDOM,
)

###############################################################################
# Evaluator: building your recorder object
# ----------------------------------------
#
# As the training data is obtained using some exploration strategy, the true
# performance of our algorithm needs to be assessed in deterministic mode. We
# do this using a dedicated class, ``Recorder``, which executes the policy in
# the environment at a given frequency and returns some statistics obtained
# from these simulations.
#
# The following helper function builds this object:
from torchrl.trainers import Recorder


def make_recorder(actor_model_explore, transform_state_dict, record_interval):
    base_env = make_env()
    environment = make_transformed_env(base_env)
    environment.transform[2].init_stats(
        3
    )  # must be instantiated to load the state dict
    environment.transform[2].load_state_dict(transform_state_dict)

    recorder_obj = Recorder(
        record_frames=1000,
        policy_exploration=actor_model_explore,
        environment=environment,
        exploration_type=ExplorationType.MEAN,
        record_interval=record_interval,
    )
    return recorder_obj


###############################################################################
# We will be recording the performance every 10 batches collected
record_interval = 10

recorder = make_recorder(
    actor_model_explore, transform_state_dict, record_interval=record_interval
)

from torchrl.data.replay_buffers import (
    LazyMemmapStorage,
    PrioritizedSampler,
    RandomSampler,
    TensorDictReplayBuffer,
)

###############################################################################
# Replay buffer
# -------------
#
# Replay buffers come in two flavors: prioritized (where some error signal
# is used to give a higher likelihood of sampling to some items than others)
# and regular, circular experience replay.
#
# TorchRL replay buffers are composable: one can pick up the storage, sampling
# and writing strategies. It is also possible to
# store tensors on physical memory using a memory-mapped array. The following
# function takes care of creating the replay buffer with the desired
# hyperparameters:
#

from torchrl.envs import RandomCropTensorDict


def make_replay_buffer(buffer_size, batch_size, random_crop_len, prefetch=3, prb=False):
    if prb:
        sampler = PrioritizedSampler(
            max_capacity=buffer_size,
            alpha=0.7,
            beta=0.5,
        )
    else:
        sampler = RandomSampler()
    replay_buffer = TensorDictReplayBuffer(
        storage=LazyMemmapStorage(
            buffer_size,
            scratch_dir=buffer_scratch_dir,
        ),
        batch_size=batch_size,
        sampler=sampler,
        pin_memory=False,
        prefetch=prefetch,
        transform=RandomCropTensorDict(random_crop_len, sample_dim=1),
    )
    return replay_buffer


###############################################################################
# We'll store the replay buffer in a temporary directory on disk

import tempfile

tmpdir = tempfile.TemporaryDirectory()
buffer_scratch_dir = tmpdir.name

###############################################################################
# Replay buffer storage and batch size
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The TorchRL replay buffer counts the number of elements along the first dimension.
# Since we'll be feeding trajectories to our buffer, we need to adapt the buffer
# size by dividing it by the length of the sub-trajectories yielded by our
# data collector.
# Regarding the batch-size, our sampling strategy will consist of sampling
# trajectories of length ``traj_len=200`` before selecting sub-trajectories
# of length ``random_crop_len=25`` on which the loss will be computed.
# This strategy balances the choice of storing whole trajectories of a certain
# length with the need for providing samples with a sufficient heterogeneity
# to our loss. The following figure shows the dataflow from a collector
# that gets 8 frames in each batch with 2 environments run in parallel,
# feeds them to a replay buffer that contains 1000 trajectories and
# samples sub-trajectories of 2 time steps each.
#
# .. figure:: /_static/img/replaybuffer_traj.png
#    :alt: Storing trajectories in the replay buffer
#
# Let's start with the number of frames stored in the buffer


def ceil_div(x, y):
    return -x // (-y)


buffer_size = 1_000_000
buffer_size = ceil_div(buffer_size, traj_len)

###############################################################################
# The prioritized replay buffer is disabled by default
prb = False

###############################################################################
# We also need to define how many updates we'll be doing per batch of data
# collected. This is known as the update-to-data or ``UTD`` ratio:
update_to_data = 64

###############################################################################
# We'll be feeding the loss with trajectories of length 25:
random_crop_len = 25

###############################################################################
# In the original paper, the authors perform one update with a batch of 64
# elements for each frame collected. Here, we reproduce the same ratio
# but while realizing several updates at each batch collection. We
# adapt our batch-size to achieve the same update-per-frame ratio:

batch_size = ceil_div(64 * frames_per_batch, update_to_data * random_crop_len)
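
###############################################################################
# As a quick sanity check on that ratio (assuming the default values above are
# left unchanged): each collection yields ``frames_per_batch = 4 * 200 = 800``
# frames, while the optimization phase consumes
# ``update_to_data * batch_size * random_crop_len = 64 * 32 * 25 = 51,200``
# sampled frames, that is, 64 sampled frames per collected frame, as in the
# original paper.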

replay_buffer = make_replay_buffer(
    buffer_size=buffer_size,
    batch_size=batch_size,
    random_crop_len=random_crop_len,
    prefetch=3,
    prb=prb,
)

###############################################################################
# Loss module construction
# ------------------------
#
# We build our loss module with the actor and ``qnet`` we've just created.
# Because we have target parameters to update, we *must* create a target network
# updater.
#

gamma = 0.99
lmbda = 0.9
tau = 0.001  # Decay factor for the target network

loss_module = DDPGLoss(actor, qnet)

###############################################################################
# let's use the TD(lambda) estimator!
loss_module.make_value_estimator(ValueEstimators.TDLambda, gamma=gamma, lmbda=lmbda)

###############################################################################
# .. note::
#   Off-policy learning usually dictates a TD(0) estimator. Here, we use a TD(:math:`\lambda`)
#   estimator, which will introduce some bias as the trajectory that follows
#   a certain state has been collected with an outdated policy.
#   This trick, like the multi-step trick that can be used during data collection,
#   is an alternative version of the "hacks" that we usually find to work well in
#   practice despite the fact that they introduce some bias in the return
#   estimates.
#
# Target network updater
# ~~~~~~~~~~~~~~~~~~~~~~
#
# Target networks are a crucial part of off-policy RL algorithms.
# Updating the target network parameters is made easy thanks to the
# :class:`~torchrl.objectives.HardUpdate` and :class:`~torchrl.objectives.SoftUpdate`
# classes. They're built with the loss module as argument, and the update is
# achieved via a call to `updater.step()` at the appropriate location in the
# training loop.
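#
# For reference, the soft ("Polyak") update performed by
# :class:`~torchrl.objectives.SoftUpdate` can be written as
#
# .. math::
#
#    \theta_{\text{target}} \leftarrow \epsilon \, \theta_{\text{target}} + (1 - \epsilon) \, \theta,
#
# so the ``eps=1 - tau`` argument below corresponds to a decay of
# :math:`1 - \tau` on the target parameters.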

from torchrl.objectives.utils import SoftUpdate

target_net_updater = SoftUpdate(loss_module, eps=1 - tau)

###############################################################################
# Optimizer
# ~~~~~~~~~
#
# Finally, we will use the Adam optimizer for the policy and value network:

from torch import optim

optimizer_actor = optim.Adam(
    loss_module.actor_network_params.values(True, True), lr=1e-4, weight_decay=0.0
)
optimizer_value = optim.Adam(
    loss_module.value_network_params.values(True, True), lr=1e-3, weight_decay=1e-2
)
total_collection_steps = total_frames // frames_per_batch

###############################################################################
# Time to train the policy
# ------------------------
#
# The training loop is pretty straightforward now that we have built all the
# modules we need.
#

rewards = []
rewards_eval = []

# Main loop

collected_frames = 0
pbar = tqdm.tqdm(total=total_frames)
r0 = None
for i, tensordict in enumerate(collector):

    # update weights of the inference policy
    collector.update_policy_weights_()

    if r0 is None:
        r0 = tensordict["next", "reward"].mean().item()
    pbar.update(tensordict.numel())

    # extend the replay buffer with the new data
    current_frames = tensordict.numel()
    collected_frames += current_frames
    replay_buffer.extend(tensordict.cpu())

    # optimization steps
    if collected_frames >= init_random_frames:
        for _ in range(update_to_data):
            # sample from replay buffer
            sampled_tensordict = replay_buffer.sample().to(device)

            # Compute loss
            loss_dict = loss_module(sampled_tensordict)

            # optimize
            loss_dict["loss_actor"].backward()
            gn1 = torch.nn.utils.clip_grad_norm_(
                loss_module.actor_network_params.values(True, True), 10.0
            )
            optimizer_actor.step()
            optimizer_actor.zero_grad()

            loss_dict["loss_value"].backward()
            gn2 = torch.nn.utils.clip_grad_norm_(
                loss_module.value_network_params.values(True, True), 10.0
            )
            optimizer_value.step()
            optimizer_value.zero_grad()

            gn = (gn1**2 + gn2**2) ** 0.5

            # update priority
            if prb:
                replay_buffer.update_tensordict_priority(sampled_tensordict)
            # update target network
            target_net_updater.step()

    rewards.append(
        (
            i,
            tensordict["next", "reward"].mean().item(),
        )
    )
    td_record = recorder(None)
    if td_record is not None:
        rewards_eval.append((i, td_record["r_evaluation"].item()))
    if len(rewards_eval) and collected_frames >= init_random_frames:
        target_value = loss_dict["target_value"].item()
        loss_value = loss_dict["loss_value"].item()
        loss_actor = loss_dict["loss_actor"].item()
        rn = sampled_tensordict["next", "reward"].mean().item()
        rs = sampled_tensordict["next", "reward"].std().item()
        pbar.set_description(
            f"reward: {rewards[-1][1]: 4.2f} (r0 = {r0: 4.2f}), "
            f"reward eval: reward: {rewards_eval[-1][1]: 4.2f}, "
            f"reward normalized={rn :4.2f}/{rs :4.2f}, "
            f"grad norm={gn: 4.2f}, "
            f"loss_value={loss_value: 4.2f}, "
            f"loss_actor={loss_actor: 4.2f}, "
            f"target value: {target_value: 4.2f}"
        )

    # update the exploration strategy
    actor_model_explore[1].step(current_frames)

collector.shutdown()
del collector

###############################################################################
# Experiment results
# ------------------
#
# We make a simple plot of the average rewards during training. We can observe
# that our policy learned quite well to solve the task.
#
# .. note::
#   As already mentioned above, to get a more reasonable performance,
#   use a greater value for ``total_frames``, for example, 1M.

from matplotlib import pyplot as plt

plt.figure()
plt.plot(*zip(*rewards), label="training")
plt.plot(*zip(*rewards_eval), label="eval")
plt.legend()
plt.xlabel("iter")
plt.ylabel("reward")
plt.tight_layout()

###############################################################################
# Conclusion
# ----------
#
# In this tutorial, we have learned how to code a loss module in TorchRL given
# the concrete example of DDPG.
#
# The key takeaways are:
#
# - How to use the :class:`~torchrl.objectives.LossModule` class to code up a new
#   loss component;
# - How to use (or not) a target network, and how to update its parameters;
# - How to create an optimizer associated with a loss module.
#
# Next Steps
# ----------
#
# To iterate further on this loss module we might consider:
#
# - Using `@dispatch` (see `[Feature] Distpatch IQL loss module <https://github.com/pytorch/rl/pull/1230>`_.)
# - Allowing flexible TensorDict keys.
#