# -*- coding: utf-8 -*-
"""
Recurrent DQN: Training recurrent policies
==========================================

**Author**: `Vincent Moens <https://github.com/vmoens>`_

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * How to incorporate an RNN in an actor in TorchRL
       * How to use that memory-based policy with a replay buffer and a loss module

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch v2.0.0
       * gym[mujoco]
       * tqdm
"""

#########################################################################
# Overview
# --------
#
# Memory-based policies are crucial not only when the observations are partially
# observable, but also when the time dimension must be taken into account to
# make informed decisions.
#
# Recurrent neural networks have long been a popular tool for memory-based
# policies. The idea is to keep a recurrent state in memory between two
# consecutive steps, and use this as an input to the policy along with the
# current observation.
#
# This tutorial shows how to incorporate an RNN in a policy using TorchRL.
#
# Key learnings:
#
# - Incorporating an RNN in an actor in TorchRL;
# - Using that memory-based policy with a replay buffer and a loss module.
#
# The core idea of using RNNs in TorchRL is to use TensorDict as a data carrier
# for the hidden states from one step to another. We'll build a policy that
# reads the previous recurrent state from the current TensorDict, and writes the
# current recurrent states into the TensorDict of the next state:
#
# .. figure:: /_static/img/rollout_recurrent.png
#    :alt: Data collection with a recurrent policy
#
# As this figure shows, our environment populates the TensorDict with zeroed recurrent
# states, which are read by the policy together with the observation to produce an
# action and the recurrent states that will be used at the next step.
# When the :func:`~torchrl.envs.utils.step_mdp` function is called, the recurrent states
# from the next state are brought to the current TensorDict. Let's see how this
# is implemented in practice.

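######################################################################
# To make the figure above concrete, here is a minimal sketch (our own
# illustration, not part of the original recipe) of what
# :func:`~torchrl.envs.utils.step_mdp` does with a recurrent entry: whatever is
# stored under ``("next", "recurrent_state")`` is promoted to the root
# ``"recurrent_state"`` of the following step, ready to be read by the policy.
# The key name and shapes below are placeholders; the actual keys used later in
# this tutorial will be ``recurrent_state_h`` and ``recurrent_state_c``.
#
# .. code-block:: python
#
#    import torch
#    from tensordict import TensorDict
#    from torchrl.envs.utils import step_mdp
#
#    td = TensorDict(
#        {
#            "observation": torch.zeros(3),
#            "recurrent_state": torch.zeros(1, 128),
#            "next": {
#                "observation": torch.ones(3),
#                "recurrent_state": torch.ones(1, 128),
#                "reward": torch.ones(1),
#                "done": torch.zeros(1, dtype=torch.bool),
#            },
#        },
#        [],
#    )
#    td_next = step_mdp(td)  # the content of "next" becomes the new root
#    assert (td_next["recurrent_state"] == 1).all()
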
######################################################################
# If you are running this in Google Colab, make sure you install the following dependencies:
#
# .. code-block:: bash
#
#    !pip3 install torchrl
#    !pip3 install gym[mujoco]
#    !pip3 install tqdm
#
# Setup
# -----
#

# sphinx_gallery_start_ignore
import warnings

warnings.filterwarnings("ignore")
from torch import multiprocessing

# TorchRL prefers the spawn start method, which restricts the creation of
# ``~torchrl.envs.ParallelEnv`` to the ``if __name__ == "__main__"`` block. For
# ease of reading, the code here switches to fork, which is also the default
# start method in Google's Colaboratory.
try:
    multiprocessing.set_start_method("fork")
except RuntimeError:
    pass

# sphinx_gallery_end_ignore

import torch
import tqdm
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.envs import (
    Compose,
    ExplorationType,
    GrayScale,
    InitTracker,
    ObservationNorm,
    Resize,
    RewardScaling,
    set_exploration_type,
    StepCounter,
    ToTensorImage,
    TransformedEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import ConvNet, EGreedyModule, LSTMModule, MLP, QValueModule
from torchrl.objectives import DQNLoss, SoftUpdate

is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)

######################################################################
# Environment
# -----------
#
# As usual, the first step is to build our environment: it helps us
# define the problem and build the policy network accordingly. For this tutorial,
# we'll be running a single pixel-based instance of the CartPole gym
# environment with some custom transforms: converting to grayscale, resizing to
# 84x84, scaling down the rewards and normalizing the observations.
#
# .. note::
#   The :class:`~torchrl.envs.transforms.StepCounter` transform is accessory: since the CartPole
#   task's goal is to make trajectories as long as possible, counting the steps
#   can help us track the performance of our policy.
#
# Two transforms are important for the purpose of this tutorial:
#
# - :class:`~torchrl.envs.transforms.InitTracker` will stamp the
#   calls to :meth:`~torchrl.envs.EnvBase.reset` by adding an ``"is_init"``
#   boolean mask in the TensorDict that will track which steps require a reset
#   of the RNN hidden states.
# - The :class:`~torchrl.envs.transforms.TensorDictPrimer` transform is a bit more
#   technical. It is not required to use RNN policies. However, it
#   instructs the environment (and subsequently the collector) that some extra
#   keys are to be expected. Once added, a call to `env.reset()` will populate
#   the entries indicated in the primer with zeroed tensors. Knowing that
#   these tensors are expected by the policy, the collector will pass them on
#   during collection. Eventually, we'll be storing our hidden states in the
#   replay buffer, which will help us bootstrap the computation of the
#   RNN operations in the loss module (which would otherwise be initiated
#   with 0s). In summary: not including this transform will not hugely impact
#   the training of our policy, but it will make the recurrent keys disappear
#   from the collected data and the replay buffer, which will in turn lead to
#   a slightly less optimal training.
#   Fortunately, the :class:`~torchrl.modules.LSTMModule` we propose is
#   equipped with a helper method to build just that transform for us, so
#   we can wait until we build it!
#

env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True, device=device),
    Compose(
        ToTensorImage(),
        GrayScale(),
        Resize(84, 84),
        StepCounter(),
        InitTracker(),
        RewardScaling(loc=0.0, scale=0.1),
        ObservationNorm(standard_normal=True, in_keys=["pixels"]),
    ),
)

######################################################################
# As always, we need to initialize our normalization constants manually:
#
env.transform[-1].init_stats(1000, reduce_dim=[0, 1, 2], cat_dim=0, keep_dims=[0])
td = env.reset()

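######################################################################
# As a quick sanity check (our addition, not part of the original recipe), we
# can peek at the reset TensorDict: the :class:`~torchrl.envs.transforms.InitTracker`
# transform has stamped it with an ``"is_init"`` flag, and the pixel observation
# is now a single, normalized 84x84 grayscale channel.
#
print("is_init at reset:", td["is_init"])
print("pixels shape:", td["pixels"].shape)
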
######################################################################
# Policy
# ------
#
# Our policy will have 3 components: a :class:`~torchrl.modules.ConvNet`
# backbone, an :class:`~torchrl.modules.LSTMModule` memory layer and a shallow
# :class:`~torchrl.modules.MLP` block that will map the LSTM output onto the
# action values.
#
# Convolutional network
# ~~~~~~~~~~~~~~~~~~~~~
#
# We build a convolutional network followed by a :class:`torch.nn.AdaptiveAvgPool2d`
# that will squash the output into a vector of size 64. The :class:`~torchrl.modules.ConvNet`
# can assist us with this:
#

feature = Mod(
    ConvNet(
        num_cells=[32, 32, 64],
        squeeze_output=True,
        aggregator_class=nn.AdaptiveAvgPool2d,
        aggregator_kwargs={"output_size": (1, 1)},
        device=device,
    ),
    in_keys=["pixels"],
    out_keys=["embed"],
)
######################################################################
# We execute the first module on a batch of data to gather the size of the
# output vector:
#
n_cells = feature(env.reset())["embed"].shape[-1]

######################################################################
# LSTM Module
# ~~~~~~~~~~~
#
# TorchRL provides a specialized :class:`~torchrl.modules.LSTMModule` class
# to incorporate LSTMs in your code-base. It is a :class:`~tensordict.nn.TensorDictModuleBase`
# subclass: as such, it has a set of ``in_keys`` and ``out_keys`` that indicate
# what values should be expected to be read and written/updated during the
# execution of the module. The class comes with customizable predefined
# values for these attributes to facilitate its construction.
#
# .. note::
#   *Usage limitations*: The class supports almost all LSTM features such as
#   dropout or multi-layered LSTMs.
#   However, to respect TorchRL's conventions, this LSTM must have the ``batch_first``
#   attribute set to ``True``, which is **not** the default in PyTorch. Our
#   :class:`~torchrl.modules.LSTMModule` changes this default
#   behavior, so we're good with a native call.
#
#   Also, the LSTM cannot have a ``bidirectional`` attribute set to ``True`` as
#   this wouldn't be usable in online settings. In this case, the default value
#   is the correct one.
#

lstm = LSTMModule(
    input_size=n_cells,
    hidden_size=128,
    device=device,
    in_key="embed",
    out_key="embed",
)

######################################################################
# Let us look at the LSTM Module class, specifically its ``in_keys`` and ``out_keys``:
print("in_keys", lstm.in_keys)
print("out_keys", lstm.out_keys)

######################################################################
# We can see that these values contain the key we indicated as the in_key (and out_key)
# as well as recurrent key names. The out_keys are preceded by a "next" prefix
# that indicates that they will need to be written in the "next" TensorDict.
# We use this convention (which can be overridden by passing the in_keys/out_keys
# arguments) to make sure that a call to :func:`~torchrl.envs.utils.step_mdp` will
# move the recurrent state to the root TensorDict, making it available to the
# RNN during the following call (see figure in the intro).
#
# As mentioned earlier, we have one more optional transform to add to our
# environment to make sure that the recurrent states are passed to the buffer.
# The :meth:`~torchrl.modules.LSTMModule.make_tensordict_primer` method does
# exactly that:
#
env.append_transform(lstm.make_tensordict_primer())

######################################################################
# and that's it! We can print the environment to check that everything looks good now
# that we have added the primer:
print(env)

######################################################################
# MLP
# ~~~
#
# We use a single-layer MLP to represent the action values we'll be using for
# our policy.
#
mlp = MLP(
    out_features=2,
    num_cells=[
        64,
    ],
    device=device,
)
######################################################################
# and fill the bias with zeros:

mlp[-1].bias.data.fill_(0.0)
mlp = Mod(mlp, in_keys=["embed"], out_keys=["action_value"])

######################################################################
# Using the Q-Values to select an action
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The last part of our policy is the Q-Value Module.
# The Q-Value module :class:`~torchrl.modules.tensordict_module.QValueModule`
# will read the ``"action_value"`` key that is produced by our MLP and,
# from it, gather the action that has the maximum value.
# The only thing we need to do is to specify the action space, which can be done
# either by passing a string or an action-spec. This allows us to use
# Categorical (sometimes called "sparse") encoding or the one-hot version of it.
#
qval = QValueModule(spec=env.action_spec)

######################################################################
# .. note::
#   TorchRL also provides a wrapper class :class:`torchrl.modules.QValueActor` that
#   wraps a module in a Sequential together with a :class:`~torchrl.modules.tensordict_module.QValueModule`
#   like we are doing explicitly here. There is little advantage to doing this
#   and the process is less transparent, but the end results will be similar to
#   what we do here.
#
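######################################################################
# As a hedged illustration of the note above (our addition, not used anywhere
# else in this tutorial), the same value network could be wrapped with
# :class:`torchrl.modules.QValueActor`, which appends the
# :class:`~torchrl.modules.tensordict_module.QValueModule` for us. The name
# ``alt_policy`` is purely illustrative; below we keep building the policy
# explicitly instead.
#
# .. code-block:: python
#
#    from torchrl.modules import QValueActor
#
#    # roughly equivalent to Seq(feature, lstm, mlp, qval) built below,
#    # assuming the wrapped module writes the default "action_value" key
#    alt_policy = QValueActor(Seq(feature, lstm, mlp), spec=env.action_spec)
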
######################################################################
# We can now put things together in a :class:`~tensordict.nn.TensorDictSequential`:
#
stoch_policy = Seq(feature, lstm, mlp, qval)

######################################################################
# DQN being a deterministic algorithm, exploration is a crucial part of it.
# We'll be using an :math:`\epsilon`-greedy policy with an epsilon of 0.2 decaying
# progressively to 0.
# This decay is achieved via a call to :meth:`~torchrl.modules.EGreedyModule.step`
# (see training loop below).
#
exploration_module = EGreedyModule(
    annealing_num_steps=1_000_000, spec=env.action_spec, eps_init=0.2
)
stoch_policy = Seq(
    stoch_policy,
    exploration_module,
)

######################################################################
# Using the model for the loss
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The model as we've built it is well equipped to be used in sequential settings.
# However, the class :class:`torch.nn.LSTM` can use a cuDNN-optimized backend
# to run the RNN sequence faster on GPU devices. We would not want to miss
# such an opportunity to speed up our training loop!
# To use it, we just need to tell the LSTM module to run in "recurrent mode"
# when used by the loss.
# As we'll usually want to have two copies of the LSTM module, we do this by
# calling a :meth:`~torchrl.modules.LSTMModule.set_recurrent_mode` method that
# will return a new instance of the LSTM (with shared weights) that will
# assume that the input data is sequential in nature.
#
policy = Seq(feature, lstm.set_recurrent_mode(True), mlp, qval)

######################################################################
# Because we still have a couple of uninitialized parameters, we should
# initialize them before creating an optimizer and such.
#
policy(env.reset())

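######################################################################
# A quick sanity check (our addition, not part of the original recipe): the
# recurrent-mode module returned by ``set_recurrent_mode`` shares its weights
# with the LSTM used for collection, so every parameter of the training policy
# should also be found in the collection policy:
#
assert {p.data_ptr() for p in policy.parameters()} <= {
    p.data_ptr() for p in stoch_policy.parameters()
}
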
######################################################################
# DQN Loss
# --------
#
# Our DQN loss requires us to pass the policy and, again, the action-space.
# While this may seem redundant, it is important as we want to make sure that
# the :class:`~torchrl.objectives.DQNLoss` and the :class:`~torchrl.modules.tensordict_module.QValueModule`
# classes are compatible, but aren't strongly dependent on each other.
#
# To use Double DQN, we pass a ``delay_value`` argument that will
# create a non-differentiable copy of the network parameters to be used
# as a target network.
loss_fn = DQNLoss(policy, action_space=env.action_spec, delay_value=True)

######################################################################
# Since we are using a double DQN, we need to update the target parameters.
# We'll use a :class:`~torchrl.objectives.SoftUpdate` instance to carry out
# this work: with ``eps=0.95``, each update moves the target parameters towards
# the online ones by a factor of ``1 - eps = 0.05``.
#
updater = SoftUpdate(loss_fn, eps=0.95)

optim = torch.optim.Adam(policy.parameters(), lr=3e-4)

######################################################################
# Collector and replay buffer
# ---------------------------
#
# We build the simplest data collector there is. We'll try to train our algorithm
# with a million frames, extending the buffer with 50 frames at a time. The buffer
# will be designed to store 20 thousand trajectories of 50 steps each.
# At each optimization step (16 per data collection), we'll sample 4 items
# from our buffer, for a total of 200 transitions.
# We'll use a :class:`~torchrl.data.replay_buffers.LazyMemmapStorage` storage to keep the data
# on disk.
#
# .. note::
#   For the sake of efficiency, we're only running a handful of iterations
#   here. In a real setting, the total number of frames should be set to 1M.
#
collector = SyncDataCollector(
    env, stoch_policy, frames_per_batch=50, total_frames=200, device=device
)
rb = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(20_000), batch_size=4, prefetch=10
)

######################################################################
# Training loop
# -------------
#
# To keep track of the progress, we will run a test rollout of the policy in the
# environment after every data collection (of 50 frames), and plot the results
# after training.
#

utd = 16
pbar = tqdm.tqdm(total=1_000_000)
longest = 0

traj_lens = []
for i, data in enumerate(collector):
    if i == 0:
        print(
            "Let us print the first batch of data.\nPay attention to the key names "
            "which will reflect what can be found in this data structure, in particular: "
            "the output of the QValueModule (action_value, action and chosen_action_value), "
            "the 'is_init' key that will tell us if a step is initial or not, and the "
            "recurrent_state keys.\n",
            data,
        )
    pbar.update(data.numel())
    # it is important to pass data that is not flattened
    rb.extend(data.unsqueeze(0).to_tensordict().cpu())
    for _ in range(utd):
        s = rb.sample().to(device, non_blocking=True)
        loss_vals = loss_fn(s)
        loss_vals["loss"].backward()
        optim.step()
        optim.zero_grad()
    longest = max(longest, data["step_count"].max().item())
    pbar.set_description(
        f"steps: {longest}, loss_val: {loss_vals['loss'].item(): 4.4f}, action_spread: {data['action'].sum(0)}"
    )
    exploration_module.step(data.numel())
    updater.step()

    with set_exploration_type(ExplorationType.MODE), torch.no_grad():
        rollout = env.rollout(10000, stoch_policy)
        traj_lens.append(rollout.get(("next", "step_count")).max().item())

######################################################################
# Let's plot our results:
#
if traj_lens:
    from matplotlib import pyplot as plt

    plt.plot(traj_lens)
    plt.xlabel("Test collection")
    plt.title("Test trajectory lengths")

######################################################################
# Conclusion
# ----------
#
# We have seen how an RNN can be incorporated in a policy in TorchRL.
# You should now be able to:
#
# - Create an LSTM module that acts as a :class:`~tensordict.nn.TensorDictModule`
# - Indicate to the LSTM module that a reset is needed via an :class:`~torchrl.envs.transforms.InitTracker`
#   transform
# - Incorporate this module in a policy and in a loss module
# - Make sure that the collector is made aware of the recurrent state entries
#   such that they can be stored in the replay buffer along with the rest of
#   the data
#
# Further Reading
# ---------------
#
# - The TorchRL documentation can be found `here <https://pytorch.org/rl/>`_.