GitHub Repository: pytorch/tutorials
Path: blob/main/intermediate_source/dqn_with_rnn_tutorial.py
# -*- coding: utf-8 -*-

"""
Recurrent DQN: Training recurrent policies
==========================================

**Author**: `Vincent Moens <https://github.com/vmoens>`_

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * How to incorporate an RNN in an actor in TorchRL
       * How to use that memory-based policy with a replay buffer and a loss module

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch v2.0.0
       * gym[mujoco]
       * tqdm
"""

#########################################################################
# Overview
# --------
#
# Memory-based policies are crucial not only when the observations are partially
# observable but also when the time dimension must be taken into account to
# make informed decisions.
#
# Recurrent neural networks have long been a popular tool for memory-based
# policies. The idea is to keep a recurrent state in memory between two
# consecutive steps, and use this as an input to the policy along with the
# current observation.
#
# This tutorial shows how to incorporate an RNN in a policy using TorchRL.
#
# Key learnings:
#
# - Incorporating an RNN in an actor in TorchRL;
# - Using that memory-based policy with a replay buffer and a loss module.
#
# The core idea of using RNNs in TorchRL is to use TensorDict as a data carrier
# for the hidden states from one step to another. We'll build a policy that
# reads the previous recurrent state from the current TensorDict, and writes the
# current recurrent states in the TensorDict of the next state:
#
# .. figure:: /_static/img/rollout_recurrent.png
#    :alt: Data collection with a recurrent policy
#
# As this figure shows, our environment populates the TensorDict with zeroed recurrent
# states which are read by the policy together with the observation to produce an
# action, and recurrent states that will be used for the next step.
# When the :func:`~torchrl.envs.utils.step_mdp` function is called, the recurrent states
# from the next state are brought to the current TensorDict. Let's see how this
# is implemented in practice.

######################################################################
# If you are running this in Google Colab, make sure you install the following dependencies:
#
# .. code-block:: bash
#
#    !pip3 install torchrl
#    !pip3 install gym[mujoco]
#    !pip3 install tqdm
#
# Setup
# -----
#

# sphinx_gallery_start_ignore
import warnings

warnings.filterwarnings("ignore")
from torch import multiprocessing

# TorchRL prefers the spawn method, which restricts creation of ``~torchrl.envs.ParallelEnv``
# to the `__main__` method call. For ease of reading, the code below switches to fork,
# which is also the default start method in Google's Colaboratory.
try:
    multiprocessing.set_start_method("fork")
except RuntimeError:
    pass

# sphinx_gallery_end_ignore

import torch
import tqdm
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.envs import (
    Compose,
    ExplorationType,
    GrayScale,
    InitTracker,
    ObservationNorm,
    Resize,
    RewardScaling,
    set_exploration_type,
    StepCounter,
    ToTensorImage,
    TransformedEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import ConvNet, EGreedyModule, LSTMModule, MLP, QValueModule
from torchrl.objectives import DQNLoss, SoftUpdate

is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)
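
######################################################################
# A quick check, not part of the original tutorial: print the selected device so we
# know whether this run uses CUDA or the CPU.
print("Using device:", device)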

######################################################################
# Environment
# -----------
#
# As usual, the first step is to build our environment: it helps us
# define the problem and build the policy network accordingly. For this tutorial,
# we'll be running a single pixel-based instance of the CartPole gym
# environment with some custom transforms: converting to grayscale, resizing to
# 84x84, scaling down the rewards and normalizing the observations.
#
# .. note::
#   The :class:`~torchrl.envs.transforms.StepCounter` transform is optional. Since the CartPole
#   task goal is to make trajectories as long as possible, counting the steps
#   can help us track the performance of our policy.
#
# Two transforms are important for the purpose of this tutorial:
#
# - :class:`~torchrl.envs.transforms.InitTracker` will stamp the
#   calls to :meth:`~torchrl.envs.EnvBase.reset` by adding a ``"is_init"``
#   boolean mask in the TensorDict that will track which steps require a reset
#   of the RNN hidden states.
# - The :class:`~torchrl.envs.transforms.TensorDictPrimer` transform is a bit more
#   technical. It is not required to use RNN policies. However, it
#   instructs the environment (and subsequently the collector) that some extra
#   keys are to be expected. Once added, a call to `env.reset()` will populate
#   the entries indicated in the primer with zeroed tensors. Knowing that
#   these tensors are expected by the policy, the collector will pass them on
#   during collection. Eventually, we'll be storing our hidden states in the
#   replay buffer, which will help us bootstrap the computation of the
#   RNN operations in the loss module (which would otherwise be initiated
#   with 0s). In summary: not including this transform will not hugely impact
#   the training of our policy, but it will make the recurrent keys disappear
#   from the collected data and the replay buffer, which will in turn lead to
#   slightly less optimal training.
#   Fortunately, the :class:`~torchrl.modules.LSTMModule` we propose is
#   equipped with a helper method to build just that transform for us, so
#   we can wait until we build it!
#

env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True, device=device),
    Compose(
        ToTensorImage(),
        GrayScale(),
        Resize(84, 84),
        StepCounter(),
        InitTracker(),
        RewardScaling(loc=0.0, scale=0.1),
        ObservationNorm(standard_normal=True, in_keys=["pixels"]),
    ),
)

######################################################################
# As always, we need to manually initialize our normalization constants:
#
env.transform[-1].init_stats(1000, reduce_dim=[0, 1, 2], cat_dim=0, keep_dims=[0])
td = env.reset()

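
######################################################################
# A quick look, not part of the original tutorial, at what a reset returns after the
# transforms above: a normalized ``"pixels"`` entry along with, among others, the
# step counter and the ``"is_init"`` flag.
print(td)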

######################################################################
# Policy
# ------
#
# Our policy will have 3 components: a :class:`~torchrl.modules.ConvNet`
# backbone, an :class:`~torchrl.modules.LSTMModule` memory layer and a shallow
# :class:`~torchrl.modules.MLP` block that will map the LSTM output onto the
# action values.
#
# Convolutional network
# ~~~~~~~~~~~~~~~~~~~~~
#
# We build a convolutional network flanked with a :class:`torch.nn.AdaptiveAvgPool2d`
# that will squash the output into a vector of size 64. The :class:`~torchrl.modules.ConvNet`
# can assist us with this:
#

feature = Mod(
    ConvNet(
        num_cells=[32, 32, 64],
        squeeze_output=True,
        aggregator_class=nn.AdaptiveAvgPool2d,
        aggregator_kwargs={"output_size": (1, 1)},
        device=device,
    ),
    in_keys=["pixels"],
    out_keys=["embed"],
)
######################################################################
# We execute the first module on a batch of data to gather the size of the
# output vector:
#
n_cells = feature(env.reset())["embed"].shape[-1]

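
######################################################################
# As a quick sanity check, not part of the original tutorial: given the
# ``num_cells`` used above, the inferred embedding size should be 64.
print("embedding size:", n_cells)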

######################################################################
# LSTM Module
# ~~~~~~~~~~~
#
# TorchRL provides a specialized :class:`~torchrl.modules.LSTMModule` class
# to incorporate LSTMs in your code-base. It is a :class:`~tensordict.nn.TensorDictModuleBase`
# subclass: as such, it has a set of ``in_keys`` and ``out_keys`` that indicate
# what values should be expected to be read and written/updated during the
# execution of the module. The class comes with customizable predefined
# values for these attributes to facilitate its construction.
#
# .. note::
#   *Usage limitations*: The class supports almost all LSTM features such as
#   dropout or multi-layered LSTMs.
#   However, to respect TorchRL's conventions, this LSTM must have the ``batch_first``
#   attribute set to ``True``, which is **not** the default in PyTorch. Our
#   :class:`~torchrl.modules.LSTMModule` changes this default behavior, so we're
#   good with a native call.
#
#   Also, the LSTM cannot have a ``bidirectional`` attribute set to ``True`` as
#   this wouldn't be usable in online settings. In this case, the default value
#   is the correct one.
#

lstm = LSTMModule(
    input_size=n_cells,
    hidden_size=128,
    device=device,
    in_key="embed",
    out_key="embed",
)

######################################################################
# Let us look at the LSTM Module class, specifically its in and out_keys:
print("in_keys", lstm.in_keys)
print("out_keys", lstm.out_keys)

######################################################################
# We can see that these values contain the key we indicated as the in_key (and out_key)
# as well as recurrent key names. The out_keys are preceded by a "next" prefix
# that indicates that they will need to be written in the "next" TensorDict.
# We use this convention (which can be overridden by passing the in_keys/out_keys
# arguments) to make sure that a call to :func:`~torchrl.envs.utils.step_mdp` will
# move the recurrent state to the root TensorDict, making it available to the
# RNN during the following call (see figure in the intro). A minimal sketch of this
# mechanism is shown below.
#
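
######################################################################
# The snippet below is not part of the original tutorial: it is a minimal sketch of
# that convention on a hand-built TensorDict (the recurrent key names are the
# ``LSTMModule`` defaults printed above, the shapes are arbitrary). After
# ``step_mdp``, the states that were written under ``("next", ...)`` sit at the
# root, ready to be consumed at the next step.
from tensordict import TensorDict
from torchrl.envs.utils import step_mdp

fake_step = TensorDict(
    {
        "observation": torch.zeros(3),
        "recurrent_state_h": torch.zeros(1, 128),
        "recurrent_state_c": torch.zeros(1, 128),
        "action": torch.zeros(2),
        "next": TensorDict(
            {
                "observation": torch.ones(3),
                "recurrent_state_h": torch.ones(1, 128),
                "recurrent_state_c": torch.ones(1, 128),
                "reward": torch.ones(1),
                "done": torch.zeros(1, dtype=torch.bool),
            },
            [],
        ),
    },
    [],
)
# the recurrent states of the "next" step are now at the root (all ones here)
print(step_mdp(fake_step)["recurrent_state_h"].mean())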

######################################################################
# As mentioned earlier, we have one more optional transform to add to our
# environment to make sure that the recurrent states are passed to the buffer.
# The :meth:`~torchrl.modules.LSTMModule.make_tensordict_primer` method does
# exactly that:
#
env.append_transform(lstm.make_tensordict_primer())

######################################################################
# and that's it! We can print the environment to check that everything looks good now
# that we have added the primer:
print(env)
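
######################################################################
# As an extra check that is not part of the original tutorial, we can confirm that
# a reset now produces the (zeroed) recurrent entries announced by the primer:
td_reset = env.reset()
print(sorted(td_reset.keys()))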

######################################################################
# MLP
# ~~~
#
# We use a single-layer MLP to represent the action values we'll be using for
# our policy.
#
mlp = MLP(
    out_features=2,
    num_cells=[
        64,
    ],
    device=device,
)
######################################################################
# and fill the bias with zeros:
#
mlp[-1].bias.data.fill_(0.0)
mlp = Mod(mlp, in_keys=["embed"], out_keys=["action_value"])

######################################################################
# Using the Q-Values to select an action
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The last part of our policy is the Q-Value Module.
# The Q-Value module :class:`~torchrl.modules.tensordict_module.QValueModule`
# will read the ``"action_value"`` key that is produced by our MLP and, from it,
# gather the action that has the maximum value.
# The only thing we need to do is to specify the action space, which can be done
# either by passing a string or an action-spec. This allows us to use
# Categorical (sometimes called "sparse") encoding or the one-hot version of it.
#
qval = QValueModule(spec=env.action_spec)
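
######################################################################
# The short example below is not part of the original tutorial; it is a minimal
# sketch of what the module does: given some action values, it picks the greedy
# action (encoded according to the environment's action spec, one-hot here) and
# the corresponding value.
from tensordict import TensorDict

dummy_values = TensorDict(
    {"action_value": torch.tensor([1.0, 10.0], device=device)}, []
)
dummy_values = qval(dummy_values)
print(dummy_values["action"], dummy_values["chosen_action_value"])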

######################################################################
# .. note::
#   TorchRL also provides a wrapper class :class:`torchrl.modules.QValueActor` that
#   wraps a module in a Sequential together with a :class:`~torchrl.modules.tensordict_module.QValueModule`
#   like we are doing explicitly here. There is little advantage to doing this
#   and the process is less transparent, but the end results will be similar to
#   what we do here.
#
# We can now put things together in a :class:`~tensordict.nn.TensorDictSequential`
#
stoch_policy = Seq(feature, lstm, mlp, qval)

######################################################################
# Since DQN is a deterministic algorithm, exploration is a crucial part of it.
# We'll be using an :math:`\epsilon`-greedy policy with an epsilon of 0.2 decaying
# progressively to 0.
# This decay is achieved via a call to :meth:`~torchrl.modules.EGreedyModule.step`
# (see training loop below).
#
exploration_module = EGreedyModule(
    annealing_num_steps=1_000_000, spec=env.action_spec, eps_init=0.2
)
stoch_policy = Seq(
    stoch_policy,
    exploration_module,
)

######################################################################
# Using the model for the loss
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The model as we've built it is well equipped to be used in sequential settings.
# However, the class :class:`torch.nn.LSTM` can use a cuDNN-optimized backend
# to run the RNN sequence faster on GPU devices. We would not want to miss
# such an opportunity to speed up our training loop!
# To use it, we just need to tell the LSTM module to run in "recurrent mode"
# when used by the loss.
# As we'll usually want to have two copies of the LSTM module, we do this by
# calling a :meth:`~torchrl.modules.LSTMModule.set_recurrent_mode` method that
# will return a new instance of the LSTM (with shared weights) that will
# assume that the input data is sequential in nature.
#
policy = Seq(feature, lstm.set_recurrent_mode(True), mlp, qval)
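
######################################################################
# A quick check that is not part of the original tutorial, assuming the shared-weight
# copy reuses the very same parameter tensors: every parameter of ``lstm`` should
# also be a parameter of ``policy``.
print({id(p) for p in lstm.parameters()} <= {id(p) for p in policy.parameters()})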

######################################################################
# Because we still have a couple of uninitialized parameters, we should
# initialize them before creating an optimizer and such.
#
policy(env.reset())

######################################################################
# DQN Loss
# --------
#
# Our DQN loss requires us to pass the policy and, again, the action-space.
# While this may seem redundant, it is important as we want to make sure that
# the :class:`~torchrl.objectives.DQNLoss` and the :class:`~torchrl.modules.tensordict_module.QValueModule`
# classes are compatible, but aren't strongly dependent on each other.
#
# To use the Double-DQN, we ask for a ``delay_value`` argument that will
# create a non-differentiable copy of the network parameters to be used
# as a target network.
loss_fn = DQNLoss(policy, action_space=env.action_spec, delay_value=True)

######################################################################
# Since we are using a double DQN, we need to update the target parameters.
# We'll use a :class:`~torchrl.objectives.SoftUpdate` instance to carry out
# this work.
#
updater = SoftUpdate(loss_fn, eps=0.95)
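
######################################################################
# For reference (this note is not part of the original tutorial), the soft update
# performed at each ``updater.step()`` call follows the usual Polyak rule, with
# ``eps`` acting as the retention factor for the target parameters:
#
# .. math::
#
#    \theta_{\text{target}} \leftarrow \epsilon \, \theta_{\text{target}} + (1 - \epsilon) \, \theta
#
# With ``eps=0.95``, the target network therefore tracks the online network slowly.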
optim = torch.optim.Adam(policy.parameters(), lr=3e-4)

######################################################################
# Collector and replay buffer
# ---------------------------
#
# We build the simplest data collector there is. We'll try to train our algorithm
# with a million frames, extending the buffer with 50 frames at a time. The buffer
# will be designed to store 20 thousand trajectories of 50 steps each.
# At each optimization step (16 per data collection), we'll sample 4 items
# from our buffer, for a total of 200 transitions.
# We'll use a :class:`~torchrl.data.replay_buffers.LazyMemmapStorage` storage to keep the data
# on disk.
#
# .. note::
#   For the sake of efficiency, we're only running a few thousand iterations
#   here. In a real setting, the total number of frames should be set to 1M.
#
collector = SyncDataCollector(
    env, stoch_policy, frames_per_batch=50, total_frames=200, device=device
)
rb = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(20_000), batch_size=4, prefetch=10
)

######################################################################
# Training loop
# -------------
#
# To keep track of the progress, we will run the policy in the environment once
# every 50 collected frames (that is, after each data collection), and plot the
# results after training.
#
utd = 16  # number of optimization steps per collected batch (update-to-data ratio)
pbar = tqdm.tqdm(total=1_000_000)
longest = 0

traj_lens = []
for i, data in enumerate(collector):
    if i == 0:
        print(
            "Let us print the first batch of data.\nPay attention to the key names "
            "which will reflect what can be found in this data structure, in particular: "
            "the output of the QValueModule (action_values, action and chosen_action_value), "
            "the 'is_init' key that will tell us if a step is initial or not, and the "
            "recurrent_state keys.\n",
            data,
        )
    pbar.update(data.numel())
    # it is important to pass data that is not flattened
    rb.extend(data.unsqueeze(0).to_tensordict().cpu())
    for _ in range(utd):
        s = rb.sample().to(device, non_blocking=True)
        loss_vals = loss_fn(s)
        loss_vals["loss"].backward()
        optim.step()
        optim.zero_grad()
    longest = max(longest, data["step_count"].max().item())
    pbar.set_description(
        f"steps: {longest}, loss_val: {loss_vals['loss'].item(): 4.4f}, action_spread: {data['action'].sum(0)}"
    )
    exploration_module.step(data.numel())  # anneal the exploration epsilon
    updater.step()  # soft-update the target network parameters

    # evaluation: run the policy without exploration and record the trajectory length
    with set_exploration_type(ExplorationType.MODE), torch.no_grad():
        rollout = env.rollout(10000, stoch_policy)
        traj_lens.append(rollout.get(("next", "step_count")).max().item())

######################################################################
# Let's plot our results:
#
if traj_lens:
    from matplotlib import pyplot as plt

    plt.plot(traj_lens)
    plt.xlabel("Test collection")
    plt.title("Test trajectory lengths")

######################################################################
# Conclusion
# ----------
#
# We have seen how an RNN can be incorporated in a policy in TorchRL.
# You should now be able to:
#
# - Create an LSTM module that acts as a :class:`~tensordict.nn.TensorDictModule`
# - Indicate to the LSTM module that a reset is needed via an :class:`~torchrl.envs.transforms.InitTracker`
#   transform
# - Incorporate this module in a policy and in a loss module
# - Make sure that the collector is made aware of the recurrent state entries
#   such that they can be stored in the replay buffer along with the rest of
#   the data
#
# Further Reading
# ---------------
#
# - The TorchRL documentation can be found `here <https://pytorch.org/rl/>`_.