# -*- coding: utf-8 -*-
2
"""
3
TorchRL objectives: Coding a DDPG loss
4
======================================
5
**Author**: `Vincent Moens <https://github.com/vmoens>`_
6
7
"""
8
9
##############################################################################
10
# Overview
11
# --------
12
#
13
# TorchRL separates the training of RL algorithms into various pieces that will be
# assembled in your training script: the environment, the data collection and
# storage, the model and finally the loss function.
16
#
17
# TorchRL losses (or "objectives") are stateful objects that contain the
18
# trainable parameters (policy and value models).
19
# This tutorial will guide you through the steps to code a loss from the ground up
20
# using TorchRL.
21
#
22
# To this end, we will be focusing on DDPG, which is a relatively straightforward
# algorithm to code.
24
# `Deep Deterministic Policy Gradient <https://arxiv.org/abs/1509.02971>`_ (DDPG)
25
# is a simple continuous control algorithm. It consists in learning a
26
# parametric value function for an action-observation pair, and
27
# then learning a policy that outputs actions that maximize this value
28
# function given a certain observation.
29
#
30
# What you will learn:
31
#
32
# - how to write a loss module and customize its value estimator;
33
# - how to build an environment in TorchRL, including transforms
34
# (for example, data normalization) and parallel execution;
35
# - how to design a policy and value network;
36
# - how to collect data from your environment efficiently and store them
37
# in a replay buffer;
38
# - how to store trajectories (and not transitions) in your replay buffer;
39
# - how to evaluate your model.
40
#
41
# Prerequisites
42
# ~~~~~~~~~~~~~
43
#
44
# This tutorial assumes that you have completed the
45
# `PPO tutorial <reinforcement_ppo.html>`_ which gives
46
# an overview of the TorchRL components and dependencies, such as
47
# :class:`tensordict.TensorDict` and :class:`tensordict.nn.TensorDictModules`,
48
# although it should be
# transparent enough to follow without a deep understanding of
# these classes.
51
#
52
# .. note::
53
#   We do not aim to give a SOTA implementation of the algorithm, but rather
#   to provide a high-level illustration of TorchRL's loss implementations
#   and of the library features that are used in the context of
#   this algorithm.
57
#
58
# Imports and setup
59
# -----------------
60
#
61
# .. code-block:: bash
62
#
63
# %%bash
64
# pip3 install torchrl mujoco glfw
65
66
# sphinx_gallery_start_ignore
67
import warnings
68
69
warnings.filterwarnings("ignore")
70
from torch import multiprocessing
71
72
# TorchRL prefers the spawn start method, which restricts the creation of
# ``~torchrl.envs.ParallelEnv`` to code guarded by ``if __name__ == "__main__"``.
# For ease of reading, we switch to fork here, which is also the default
# start method in Google's Colaboratory.
75
try:
76
multiprocessing.set_start_method("fork")
77
except RuntimeError:
78
pass
79
80
# sphinx_gallery_end_ignore
81
82
83
import torch
84
import tqdm
85
86
87
###############################################################################
88
# We will execute the policy on CUDA if available
89
is_fork = multiprocessing.get_start_method() == "fork"
90
device = (
91
torch.device(0)
92
if torch.cuda.is_available() and not is_fork
93
else torch.device("cpu")
94
)
95
collector_device = torch.device("cpu") # Change the device to ``cuda`` to use CUDA
96
97
###############################################################################
98
# TorchRL :class:`~torchrl.objectives.LossModule`
99
# -----------------------------------------------
100
#
101
# TorchRL provides a series of losses to use in your training scripts.
102
# The aim is to have losses that are easily reusable/swappable and that have
103
# a simple signature.
104
#
105
# The main characteristics of TorchRL losses are:
106
#
107
# - They are stateful objects: they contain a copy of the trainable parameters
108
# such that ``loss_module.parameters()`` gives whatever is needed to train the
109
# algorithm.
110
# - They follow the ``TensorDict`` convention: the :meth:`torch.nn.Module.forward`
111
# method will receive a TensorDict as input that contains all the necessary
112
# information to return a loss value.
113
#
114
# >>> data = replay_buffer.sample()
115
# >>> loss_dict = loss_module(data)
116
#
117
# - They output a :class:`tensordict.TensorDict` instance with the loss values
#   written under a ``"loss_<smth>"`` entry, where ``<smth>`` is a string describing the
#   loss. Additional keys in the ``TensorDict`` may be useful metrics to log during
#   training time.
121
#
122
# .. note::
123
#   The reason we return independent losses is to let the user, for instance, use a
#   different optimizer for different sets of parameters. Summing the losses
#   can simply be done via
126
#
127
# >>> loss_val = sum(loss for key, loss in loss_dict.items() if key.startswith("loss_"))
128
#
129
# The ``__init__`` method
130
# ~~~~~~~~~~~~~~~~~~~~~~~
131
#
132
# The parent class of all losses is :class:`~torchrl.objectives.LossModule`.
# Like many other components of the library, its :meth:`~torchrl.objectives.LossModule.forward` method expects
# as input a :class:`tensordict.TensorDict` instance sampled from an experience
# replay buffer, or any similar data structure. Using this format makes it
# possible to re-use the module across
# modalities, or in complex settings where the model needs to read multiple
# entries, for instance. In other words, it allows us to code a loss module that
# is oblivious to the data type that is being given to it and that focuses on
# running the elementary steps of the loss function and only those.
141
#
142
# To keep the tutorial as didactic as we can, we'll be displaying each method
143
# of the class independently and we'll be populating the class at a later
144
# stage.
145
#
146
# Let us start with the :meth:`~torchrl.objectives.LossModule.__init__`
147
# method. DDPG aims at solving a control task with a simple strategy:
148
# training a policy to output actions that maximize the value predicted by
149
# a value network. Hence, our loss module needs to receive two networks in its
150
# constructor: an actor and a value network. We expect both of these to be
151
# TensorDict-compatible objects, such as
152
# :class:`tensordict.nn.TensorDictModule`.
153
# Our loss function will need to compute a target value and fit the value
154
# network to this, and generate an action and fit the policy such that its
155
# value estimate is maximized.
156
#
157
# The crucial step of the :meth:`LossModule.__init__` method is the call to
# :meth:`~torchrl.objectives.LossModule.convert_to_functional`. This method will extract
# the parameters from the module and convert it to a functional module.
# Strictly speaking, this is not necessary and one could perfectly well code all
# the losses without it. However, we encourage its usage for the following
# reason.
163
#
164
# The reason TorchRL does this is that RL algorithms often execute the same
165
# model with different sets of parameters, called "trainable" and "target"
166
# parameters.
167
# The "trainable" parameters are those that the optimizer needs to fit. The
168
# "target" parameters are usually a copy of the former's with some time lag
169
# (absolute or diluted through a moving average).
170
# These target parameters are used to compute the value associated with the
171
# next observation. One of the advantages of using a set of target parameters
# for the value model that does not exactly match the current configuration is
# that it provides a pessimistic bound on the value function being computed.
174
# Pay attention to the ``create_target_params`` keyword argument below: this
175
# argument tells the :meth:`~torchrl.objectives.LossModule.convert_to_functional`
176
# method to create a set of target parameters in the loss module to be used
177
# for target value computation. If this is set to ``False`` (see the actor network
178
# for instance) the ``target_actor_network_params`` attribute will still be
179
# accessible but this will just return a **detached** version of the
180
# actor parameters.
181
#
182
# Later, we will see how the target parameters should be updated in TorchRL.
183
#
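# As a rough sketch of the idea (TorchRL's target updaters, shown at the end of
# this tutorial, handle this for us), a "soft" target update blends the trainable
# parameters into the target ones with a small factor ``tau``:
#
# .. code-block:: python
#
#    # conceptual illustration only -- not how the loss module updates its targets
#    with torch.no_grad():
#        for p_target, p in zip(target_params, params):
#            p_target.mul_(1 - tau).add_(tau * p)
#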
184
185
from tensordict.nn import TensorDictModule, TensorDictSequential
186
187
188
def _init(
189
self,
190
actor_network: TensorDictModule,
191
value_network: TensorDictModule,
192
) -> None:
193
super(type(self), self).__init__()
194
195
self.convert_to_functional(
196
actor_network,
197
"actor_network",
198
create_target_params=True,
199
)
200
self.convert_to_functional(
201
value_network,
202
"value_network",
203
create_target_params=True,
204
compare_against=list(actor_network.parameters()),
205
)
206
207
self.actor_in_keys = actor_network.in_keys
208
209
# Since the value we'll be using is based on the actor and value network,
210
# we put them together in a single actor-critic container.
211
actor_critic = ActorCriticWrapper(actor_network, value_network)
212
self.actor_critic = actor_critic
213
self.loss_function = "l2"
214
215
216
###############################################################################
217
# The value estimator loss method
218
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
219
#
220
# In many RL algorithms, the value network (or Q-value network) is trained based
# on an empirical value estimate. This can be bootstrapped (TD(0), low
# variance, high bias), meaning
# that the target value is obtained using the next reward and nothing else, or
# a Monte-Carlo estimate can be obtained (TD(1)), in which case the whole
# sequence of upcoming rewards will be used (high variance, low bias). An
# intermediate estimator (TD(:math:`\lambda`)) can also be used to trade off
# bias and variance.
228
# TorchRL makes it easy to use one or the other estimator via the
229
# :class:`~torchrl.objectives.utils.ValueEstimators` Enum class, which contains
230
# pointers to all the value estimators implemented. Let us define the default
231
# value function here. We will take the simplest version (TD(0)), and show later
232
# on how this can be changed.
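#
# For reference, the TD(0) target used for a state-action pair
# :math:`(s_t, a_t)` is, in the standard DDPG formulation,
#
# .. math::
#
#    y_t = r_{t+1} + \gamma (1 - d_{t+1}) \, Q_{\theta'}\big(s_{t+1}, \pi_{\phi'}(s_{t+1})\big),
#
# where :math:`\theta'` and :math:`\phi'` denote the target value and policy
# parameters and :math:`d_{t+1}` flags the end of a trajectory.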
233
234
from torchrl.objectives.utils import ValueEstimators
235
236
default_value_estimator = ValueEstimators.TD0
237
238
###############################################################################
239
# We also need to give DDPG some instructions on how to build the value
# estimator, depending on the user query: given the requested estimator, we
# will build the corresponding module to be used at train time:
242
243
from torchrl.objectives.utils import default_value_kwargs
244
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator
245
246
247
def make_value_estimator(self, value_type: ValueEstimators, **hyperparams):
248
hp = dict(default_value_kwargs(value_type))
249
if hasattr(self, "gamma"):
250
hp["gamma"] = self.gamma
251
hp.update(hyperparams)
252
value_key = "state_action_value"
253
if value_type == ValueEstimators.TD1:
254
self._value_estimator = TD1Estimator(value_network=self.actor_critic, **hp)
255
elif value_type == ValueEstimators.TD0:
256
self._value_estimator = TD0Estimator(value_network=self.actor_critic, **hp)
257
elif value_type == ValueEstimators.GAE:
258
raise NotImplementedError(
259
f"Value type {value_type} it not implemented for loss {type(self)}."
260
)
261
elif value_type == ValueEstimators.TDLambda:
262
self._value_estimator = TDLambdaEstimator(value_network=self.actor_critic, **hp)
263
else:
264
raise NotImplementedError(f"Unknown value type {value_type}")
265
self._value_estimator.set_keys(value=value_key)
266
267
268
###############################################################################
269
# The ``make_value_estimator`` method can be called, but does not need to be: if
# it is not, the :class:`~torchrl.objectives.LossModule` will query this method with
# its default estimator.
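#
# For instance, switching to a TD(:math:`\lambda`) estimator (as we do further
# down in this tutorial) would look along these lines, with illustrative
# hyperparameter values:
#
# .. code-block:: python
#
#    loss_module.make_value_estimator(ValueEstimators.TDLambda, gamma=0.99, lmbda=0.95)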
272
#
273
# The actor loss method
274
# ~~~~~~~~~~~~~~~~~~~~~
275
#
276
# The central piece of an RL algorithm is the training loss for the actor.
277
# In the case of DDPG, this function is quite simple: we just need to compute
278
# the value associated with an action computed using the policy and optimize
279
# the actor weights to maximize this value.
280
#
281
# When computing this value, we must make sure to take the value parameters out
# of the graph, otherwise the actor and value losses will be mixed up.
# For this, the :func:`~torchrl.objectives.utils.hold_out_params` function
# can be used; below, we simply detach the value network parameters instead.
285
286
287
def _loss_actor(
288
self,
289
tensordict,
290
) -> torch.Tensor:
291
td_copy = tensordict.select(*self.actor_in_keys)
292
# Get an action from the actor network: since we made it functional, we need to pass the params
293
with self.actor_network_params.to_module(self.actor_network):
294
td_copy = self.actor_network(td_copy)
295
# get the value associated with that action
296
with self.value_network_params.detach().to_module(self.value_network):
297
td_copy = self.value_network(td_copy)
298
return -td_copy.get("state_action_value")
299
300
301
###############################################################################
302
# The value loss method
303
# ~~~~~~~~~~~~~~~~~~~~~
304
#
305
# We now need to optimize our value network parameters.
306
# To do this, we will rely on the value estimator of our class:
307
#
308
309
from torchrl.objectives.utils import distance_loss
310
311
312
def _loss_value(
313
self,
314
tensordict,
315
):
316
td_copy = tensordict.clone()
317
318
    # Q(s, a): the value network predicts the value of the state-action pair
319
with self.value_network_params.to_module(self.value_network):
320
self.value_network(td_copy)
321
pred_val = td_copy.get("state_action_value").squeeze(-1)
322
323
# we manually reconstruct the parameters of the actor-critic, where the first
324
# set of parameters belongs to the actor and the second to the value function.
325
target_params = TensorDict(
326
{
327
"module": {
328
"0": self.target_actor_network_params,
329
"1": self.target_value_network_params,
330
}
331
},
332
batch_size=self.target_actor_network_params.batch_size,
333
device=self.target_actor_network_params.device,
334
)
335
with target_params.to_module(self.actor_critic):
336
target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)
337
338
# Computes the value loss: L2, L1 or smooth L1 depending on `self.loss_function`
339
loss_value = distance_loss(pred_val, target_value, loss_function=self.loss_function)
340
td_error = (pred_val - target_value).pow(2)
341
342
return loss_value, td_error, pred_val, target_value
343
344
345
###############################################################################
346
# Putting things together in a forward call
347
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
348
#
349
# The only missing piece is the forward method, which will glue together the
350
# value and actor losses, collect the cost values and write them in a ``TensorDict``
351
# delivered to the user.
352
353
from tensordict import TensorDict, TensorDictBase
354
355
356
def _forward(self, input_tensordict: TensorDictBase) -> TensorDict:
357
loss_value, td_error, pred_val, target_value = self.loss_value(
358
input_tensordict,
359
)
360
td_error = td_error.detach()
361
td_error = td_error.unsqueeze(input_tensordict.ndimension())
362
if input_tensordict.device is not None:
363
td_error = td_error.to(input_tensordict.device)
364
input_tensordict.set(
365
"td_error",
366
td_error,
367
inplace=True,
368
)
369
loss_actor = self.loss_actor(input_tensordict)
370
return TensorDict(
371
source={
372
"loss_actor": loss_actor.mean(),
373
"loss_value": loss_value.mean(),
374
"pred_value": pred_val.mean().detach(),
375
"target_value": target_value.mean().detach(),
376
"pred_value_max": pred_val.max().detach(),
377
"target_value_max": target_value.max().detach(),
378
},
379
batch_size=[],
380
)
381
382
383
from torchrl.objectives import LossModule
384
385
386
class DDPGLoss(LossModule):
387
default_value_estimator = default_value_estimator
388
make_value_estimator = make_value_estimator
389
390
__init__ = _init
391
forward = _forward
392
loss_value = _loss_value
393
loss_actor = _loss_actor
394
395
396
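###############################################################################
# As a quick sanity check of how the assembled loss is meant to be consumed
# (a sketch only: it assumes ``loss_module``, ``replay_buffer`` and an optimizer
# over ``loss_module.parameters()`` have been created as shown later in this
# tutorial):
#
# .. code-block:: python
#
#    data = replay_buffer.sample()
#    loss_dict = loss_module(data)
#    loss = sum(v for k, v in loss_dict.items() if k.startswith("loss_"))
#    loss.backward()
#    optimizer.step()
#    optimizer.zero_grad()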
###############################################################################
397
# Now that we have our loss, we can use it to train a policy to solve a
398
# control task.
399
#
400
# Environment
401
# -----------
402
#
403
# In most algorithms, the first thing that needs to be taken care of is the
404
# construction of the environment as it conditions the remainder of the
405
# training script.
406
#
407
# For this example, we will be using the ``"cheetah"`` task. The goal is to make
408
# a half-cheetah run as fast as possible.
409
#
410
# In TorchRL, one can create such a task by relying on ``dm_control`` or ``gym``:
411
#
412
# .. code-block:: python
413
#
414
# env = GymEnv("HalfCheetah-v4")
415
#
416
# or
417
#
418
# .. code-block:: python
419
#
420
# env = DMControlEnv("cheetah", "run")
421
#
422
# By default, these environments disable rendering. Training from states is
423
# usually easier than training from images. To keep things simple, we focus
424
# on learning from states only. To pass the pixels to the ``tensordicts`` that
425
# are collected by :func:`env.step()`, simply pass the ``from_pixels=True``
426
# argument to the constructor:
427
#
428
# .. code-block:: python
429
#
430
# env = GymEnv("HalfCheetah-v4", from_pixels=True, pixels_only=True)
431
#
432
# We write a :func:`make_env` helper function that will create an environment
433
# with either one of the two backends considered above (``dm-control`` or ``gym``).
434
#
435
436
from torchrl.envs.libs.dm_control import DMControlEnv
437
from torchrl.envs.libs.gym import GymEnv
438
439
env_library = None
440
env_name = None
441
442
443
def make_env(from_pixels=False):
444
"""Create a base ``env``."""
445
global env_library
446
global env_name
447
448
if backend == "dm_control":
449
env_name = "cheetah"
450
env_task = "run"
451
env_args = (env_name, env_task)
452
env_library = DMControlEnv
453
elif backend == "gym":
454
env_name = "HalfCheetah-v4"
455
env_args = (env_name,)
456
env_library = GymEnv
457
else:
458
raise NotImplementedError
459
460
env_kwargs = {
461
"device": device,
462
"from_pixels": from_pixels,
463
"pixels_only": from_pixels,
464
"frame_skip": 2,
465
}
466
env = env_library(*env_args, **env_kwargs)
467
return env
468
469
470
###############################################################################
471
# Transforms
472
# ~~~~~~~~~~
473
#
474
# Now that we have a base environment, we may want to modify its representation
475
# to make it more policy-friendly. In TorchRL, transforms are appended to the
476
# base environment in a specialized :class:`torchrl.envs.TransformedEnv` class.
477
#
478
# - It is common in DDPG to rescale the reward using some heuristic value. We
479
# will multiply the reward by 5 in this example.
480
#
481
# - If we are using :mod:`dm_control`, it is also important to build an interface
482
# between the simulator which works with double precision numbers, and our
483
# script which presumably uses single precision ones. This transformation goes
484
# both ways: when calling :func:`env.step`, our actions will need to be
485
# represented in double precision, and the output will need to be transformed
486
# to single precision.
487
# The :class:`~torchrl.envs.DoubleToFloat` transform does exactly this: the
488
# ``in_keys`` list refers to the keys that will need to be transformed from
489
#   double to float, while the ``in_keys_inv`` refers to those that need to
#   be transformed to double before being passed to the environment (see the
#   short example after this list).
491
#
492
# - We concatenate the state keys together using the :class:`~torchrl.envs.CatTensors`
493
# transform.
494
#
495
# - Finally, we also leave the possibility of normalizing the states: we will
496
# take care of computing the normalizing constants later on.
497
#
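# As a short, illustrative example of the keyed form of this transform (the
# keys below are assumptions for a generic environment, not necessarily the
# ones used in this tutorial):
#
# .. code-block:: python
#
#    DoubleToFloat(in_keys=["observation_vector"], in_keys_inv=["action"])
#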
498
499
from torchrl.envs import (
500
CatTensors,
501
DoubleToFloat,
502
EnvCreator,
503
InitTracker,
504
ObservationNorm,
505
ParallelEnv,
506
RewardScaling,
507
StepCounter,
508
TransformedEnv,
509
)
510
511
512
def make_transformed_env(
513
env,
514
):
515
"""Apply transforms to the ``env`` (such as reward scaling and state normalization)."""
516
517
env = TransformedEnv(env)
518
519
# we append transforms one by one, although we might as well create the
520
# transformed environment using the `env = TransformedEnv(base_env, transforms)`
521
# syntax.
522
env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling))
523
524
# We concatenate all states into a single "observation_vector"
525
    # even if there is a single tensor, it'll be renamed to "observation_vector".
526
# This facilitates the downstream operations as we know the name of the
527
# output tensor.
528
# In some environments (not half-cheetah), there may be more than one
529
# observation vector: in this case this code snippet will concatenate them
530
# all.
531
selected_keys = list(env.observation_spec.keys())
532
out_key = "observation_vector"
533
env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key))
534
535
# we normalize the states, but for now let's just instantiate a stateless
536
# version of the transform
537
env.append_transform(ObservationNorm(in_keys=[out_key], standard_normal=True))
538
539
env.append_transform(DoubleToFloat())
540
541
env.append_transform(StepCounter(max_frames_per_traj))
542
543
# We need a marker for the start of trajectories for our Ornstein-Uhlenbeck (OU)
544
# exploration:
545
env.append_transform(InitTracker())
546
547
return env
548
549
550
###############################################################################
551
# Parallel execution
552
# ~~~~~~~~~~~~~~~~~~
553
#
554
# The following helper function allows us to run environments in parallel.
555
# Running environments in parallel can significantly speed up the collection
556
# throughput. When using transformed environments, we need to choose whether we
557
# want to execute the transform individually for each environment, or
558
# centralize the data and transform it in batch. Both approaches are easy to
559
# code:
560
#
561
# .. code-block:: python
562
#
563
# env = ParallelEnv(
564
# lambda: TransformedEnv(GymEnv("HalfCheetah-v4"), transforms),
565
# num_workers=4
566
# )
567
# env = TransformedEnv(
568
# ParallelEnv(lambda: GymEnv("HalfCheetah-v4"), num_workers=4),
569
# transforms
570
# )
571
#
572
# To leverage the vectorization capabilities of PyTorch, we adopt
# the second method:
574
#
575
576
577
def parallel_env_constructor(
578
env_per_collector,
579
transform_state_dict,
580
):
581
if env_per_collector == 1:
582
583
def make_t_env():
584
env = make_transformed_env(make_env())
585
env.transform[2].init_stats(3)
586
env.transform[2].loc.copy_(transform_state_dict["loc"])
587
env.transform[2].scale.copy_(transform_state_dict["scale"])
588
return env
589
590
env_creator = EnvCreator(make_t_env)
591
return env_creator
592
593
parallel_env = ParallelEnv(
594
num_workers=env_per_collector,
595
create_env_fn=EnvCreator(lambda: make_env()),
596
create_env_kwargs=None,
597
pin_memory=False,
598
)
599
env = make_transformed_env(parallel_env)
600
# we call `init_stats` for a limited number of steps, just to instantiate
601
# the lazy buffers.
602
env.transform[2].init_stats(3, cat_dim=1, reduce_dim=[0, 1])
603
env.transform[2].load_state_dict(transform_state_dict)
604
return env
605
606
607
# The backend can be ``gym`` or ``dm_control``
608
backend = "gym"
609
610
###############################################################################
611
# .. note::
612
#
613
#   ``frame_skip`` batches multiple steps together with a single action.
#   If > 1, the other frame counts (for example, frames_per_batch, total_frames)
#   need to be adjusted to have a consistent total number of frames collected
#   across experiments. This is important as raising the frame-skip but keeping the
#   total number of frames unchanged may seem like cheating: all things compared,
#   a dataset of 10M elements collected with a frame-skip of 2 and another with
#   a frame-skip of 1 actually have a ratio of interactions with the environment
#   of 2:1! In a nutshell, one should be cautious about the frame-count of a
#   training script when dealing with frame skipping as this may lead to
#   biased comparisons between training strategies.
#
# Scaling the reward helps us control the signal magnitude for more
# efficient learning.
626
reward_scaling = 5.0
627
628
###############################################################################
629
# We also define when a trajectory will be truncated. A thousand steps (500 if
630
# frame-skip = 2) is a good number to use for the cheetah task:
631
632
max_frames_per_traj = 500
633
634
###############################################################################
635
# Normalization of the observations
636
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
637
#
638
# To compute the normalizing statistics, we run an arbitrary number of random
639
# steps in the environment and compute the mean and standard deviation of the
640
# collected observations. The :func:`ObservationNorm.init_stats()` method can
641
# be used for this purpose. To get these summary statistics, we create a dummy
# environment, run it for a given number of steps, and compute the statistics
# of the data collected over that window.
644
#
645
646
647
def get_env_stats():
648
"""Gets the stats of an environment."""
649
proof_env = make_transformed_env(make_env())
650
t = proof_env.transform[2]
651
t.init_stats(init_env_steps)
652
transform_state_dict = t.state_dict()
653
proof_env.close()
654
return transform_state_dict
655
656
657
###############################################################################
658
# Normalization stats
659
# ~~~~~~~~~~~~~~~~~~~
660
# Number of random steps used for stats computation with ``ObservationNorm``
661
662
init_env_steps = 5000
663
664
transform_state_dict = get_env_stats()
665
666
###############################################################################
667
# Number of environments in each data collector
668
env_per_collector = 4
669
670
###############################################################################
671
# We pass the stats computed earlier to normalize the output of our
672
# environment:
673
674
parallel_env = parallel_env_constructor(
675
env_per_collector=env_per_collector,
676
transform_state_dict=transform_state_dict,
677
)
678
679
680
from torchrl.data import CompositeSpec
681
682
###############################################################################
683
# Building the model
684
# ------------------
685
#
686
# We now turn to the setup of the model. As we have seen, DDPG requires a
687
# value network, trained to estimate the value of a state-action pair, and a
688
# parametric actor that learns how to select actions that maximize this value.
689
#
690
# Recall that building a TorchRL module requires two steps:
691
#
692
# - writing the :class:`torch.nn.Module` that will be used as network,
693
# - wrapping the network in a :class:`tensordict.nn.TensorDictModule` where the
694
# data flow is handled by specifying the input and output keys.
695
#
696
# In more complex scenarios, :class:`tensordict.nn.TensorDictSequential` can
697
# also be used.
698
#
699
#
700
# The Q-Value network is wrapped in a :class:`~torchrl.modules.ValueOperator`
701
# that automatically sets the ``out_keys`` to ``"state_action_value"`` for Q-value
# networks and to ``"state_value"`` for other value networks.
703
#
704
# TorchRL provides a built-in version of the DDPG networks as presented in the
705
# original paper. These can be found under :class:`~torchrl.modules.DdpgMlpActor`
706
# and :class:`~torchrl.modules.DdpgMlpQNet`.
707
#
708
# Since we use lazy modules, it is necessary to materialize them
# before being able to move the policy from device to device and perform other
# operations. Hence, it is good practice to run the modules with a small
711
# sample of data. For this purpose, we generate fake data from the
712
# environment specs.
713
#
714
715
from torchrl.modules import (
716
ActorCriticWrapper,
717
DdpgMlpActor,
718
DdpgMlpQNet,
719
OrnsteinUhlenbeckProcessModule,
720
ProbabilisticActor,
721
TanhDelta,
722
ValueOperator,
723
)
724
725
726
def make_ddpg_actor(
727
transform_state_dict,
728
device="cpu",
729
):
730
proof_environment = make_transformed_env(make_env())
731
proof_environment.transform[2].init_stats(3)
732
proof_environment.transform[2].load_state_dict(transform_state_dict)
733
734
out_features = proof_environment.action_spec.shape[-1]
735
736
actor_net = DdpgMlpActor(
737
action_dim=out_features,
738
)
739
740
in_keys = ["observation_vector"]
741
out_keys = ["param"]
742
743
actor = TensorDictModule(
744
actor_net,
745
in_keys=in_keys,
746
out_keys=out_keys,
747
)
748
749
actor = ProbabilisticActor(
750
actor,
751
distribution_class=TanhDelta,
752
in_keys=["param"],
753
spec=CompositeSpec(action=proof_environment.action_spec),
754
).to(device)
755
756
q_net = DdpgMlpQNet()
757
758
in_keys = in_keys + ["action"]
759
qnet = ValueOperator(
760
in_keys=in_keys,
761
module=q_net,
762
).to(device)
763
764
# initialize lazy modules
765
qnet(actor(proof_environment.reset().to(device)))
766
return actor, qnet
767
768
769
actor, qnet = make_ddpg_actor(
770
transform_state_dict=transform_state_dict,
771
device=device,
772
)
773
774
###############################################################################
775
# Exploration
776
# ~~~~~~~~~~~
777
#
778
# The policy is wrapped, together with an :class:`~torchrl.modules.OrnsteinUhlenbeckProcessModule`
# exploration module, in a :class:`~tensordict.nn.TensorDictSequential`, as suggested in the original paper.
# Let's define the number of frames before OU noise reaches its minimum value:
781
annealing_frames = 1_000_000
782
783
actor_model_explore = TensorDictSequential(
784
actor,
785
OrnsteinUhlenbeckProcessModule(
786
spec=actor.spec.clone(),
787
annealing_num_steps=annealing_frames,
788
).to(device),
789
)
790
if device == torch.device("cpu"):
791
actor_model_explore.share_memory()
792
793
794
###############################################################################
795
# Data collector
796
# --------------
797
#
798
# TorchRL provides specialized classes to help you collect data by executing
799
# the policy in the environment. These "data collectors" iteratively compute
800
# the action to be executed at a given time, then execute a step in the
801
# environment and reset it when required.
802
# Data collectors are designed to give developers tight control
# over the number of frames per batch of data, over the (a)sync nature of this
# collection and over the resources allocated to the data collection (for example
805
# GPU, number of workers, and so on).
806
#
807
# Here we will use
808
# :class:`~torchrl.collectors.SyncDataCollector`, a simple, single-process
809
# data collector. TorchRL offers other collectors, such as
810
# :class:`~torchrl.collectors.MultiaSyncDataCollector`, which executes the
811
# rollouts in an asynchronous manner (for example, data will be collected while
812
# the policy is being optimized, thereby decoupling the training and
813
# data collection).
814
#
815
# The parameters to specify are:
816
#
817
# - an environment factory or an environment,
818
# - the policy,
819
# - the total number of frames before the collector is considered empty,
820
# - the maximum number of frames per trajectory (useful for non-terminating
821
# environments, like ``dm_control`` ones).
822
#
823
# .. note::
824
#
825
# The ``max_frames_per_traj`` passed to the collector will have the effect
826
# of registering a new :class:`~torchrl.envs.StepCounter` transform
827
# with the environment used for inference. We can achieve the same result
828
# manually, as we do in this script.
829
#
830
# One should also pass:
831
#
832
# - the number of frames in each batch collected,
833
# - the number of random steps executed independently from the policy,
834
# - the devices used for policy execution
835
# - the devices used to store data before the data is passed to the main
836
# process.
837
#
838
# The total frames we will use during training should be around 1M.
839
total_frames = 10_000 # 1_000_000
840
841
###############################################################################
842
# The number of frames returned by the collector at each iteration of the outer
843
# loop is equal to the length of each sub-trajectory times the number of
844
# environments run in parallel in each collector.
845
#
846
# In other words, we expect batches from the collector to have a shape
847
# ``[env_per_collector, traj_len]`` where
848
# ``traj_len=frames_per_batch/env_per_collector``:
849
#
850
traj_len = 200
851
frames_per_batch = env_per_collector * traj_len
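# with ``env_per_collector=4`` and ``traj_len=200`` this amounts to 800 frames per
# batch, delivered with shape ``[env_per_collector, traj_len] = [4, 200]``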
852
init_random_frames = 5000
853
num_collectors = 2
854
855
from torchrl.collectors import SyncDataCollector
856
from torchrl.envs import ExplorationType
857
858
collector = SyncDataCollector(
859
parallel_env,
860
policy=actor_model_explore,
861
total_frames=total_frames,
862
frames_per_batch=frames_per_batch,
863
init_random_frames=init_random_frames,
864
reset_at_each_iter=False,
865
split_trajs=False,
866
device=collector_device,
867
exploration_type=ExplorationType.RANDOM,
868
)
869
870
###############################################################################
871
# Evaluator: building your recorder object
872
# ----------------------------------------
873
#
874
# As the training data is obtained using some exploration strategy, the true
875
# performance of our algorithm needs to be assessed in deterministic mode. We
876
# do this using a dedicated class, ``Recorder``, which executes the policy in
877
# the environment at a given frequency and returns some statistics obtained
878
# from these simulations.
879
#
880
# The following helper function builds this object:
881
from torchrl.trainers import Recorder
882
883
884
def make_recorder(actor_model_explore, transform_state_dict, record_interval):
885
base_env = make_env()
886
environment = make_transformed_env(base_env)
887
environment.transform[2].init_stats(
888
3
889
) # must be instantiated to load the state dict
890
environment.transform[2].load_state_dict(transform_state_dict)
891
892
recorder_obj = Recorder(
893
record_frames=1000,
894
policy_exploration=actor_model_explore,
895
environment=environment,
896
exploration_type=ExplorationType.MEAN,
897
record_interval=record_interval,
898
)
899
return recorder_obj
900
901
902
###############################################################################
903
# We will be recording the performance every 10 batches collected
904
record_interval = 10
905
906
recorder = make_recorder(
907
actor_model_explore, transform_state_dict, record_interval=record_interval
908
)
909
910
from torchrl.data.replay_buffers import (
911
LazyMemmapStorage,
912
PrioritizedSampler,
913
RandomSampler,
914
TensorDictReplayBuffer,
915
)
916
917
###############################################################################
918
# Replay buffer
919
# -------------
920
#
921
# Replay buffers come in two flavors: prioritized (where some error signal
922
# is used to give a higher likelihood of sampling to some items than others)
923
# and regular, circular experience replay.
924
#
925
# TorchRL replay buffers are composable: one can pick the storage, sampling
# and writing strategies. It is also possible to
# store tensors on disk using a memory-mapped array. The following
928
# function takes care of creating the replay buffer with the desired
929
# hyperparameters:
930
#
931
932
from torchrl.envs import RandomCropTensorDict
933
934
935
def make_replay_buffer(buffer_size, batch_size, random_crop_len, prefetch=3, prb=False):
936
if prb:
937
sampler = PrioritizedSampler(
938
max_capacity=buffer_size,
939
alpha=0.7,
940
beta=0.5,
941
)
942
else:
943
sampler = RandomSampler()
944
replay_buffer = TensorDictReplayBuffer(
945
storage=LazyMemmapStorage(
946
buffer_size,
947
scratch_dir=buffer_scratch_dir,
948
),
949
batch_size=batch_size,
950
sampler=sampler,
951
pin_memory=False,
952
prefetch=prefetch,
953
transform=RandomCropTensorDict(random_crop_len, sample_dim=1),
954
)
955
return replay_buffer
956
957
958
###############################################################################
959
# We'll store the replay buffer in a temporary directory on disk
960
961
import tempfile
962
963
tmpdir = tempfile.TemporaryDirectory()
964
buffer_scratch_dir = tmpdir.name
965
966
###############################################################################
967
# Replay buffer storage and batch size
968
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
969
#
970
# TorchRL replay buffers count the number of elements along the first dimension.
971
# Since we'll be feeding trajectories to our buffer, we need to adapt the buffer
972
# size by dividing it by the length of the sub-trajectories yielded by our
973
# data collector.
974
# Regarding the batch-size, our sampling strategy will consist of sampling
# trajectories of length ``traj_len=200`` before selecting sub-trajectories
# of length ``random_crop_len=25`` on which the loss will be computed.
977
# This strategy balances the choice of storing whole trajectories of a certain
978
# length with the need for providing samples with a sufficient heterogeneity
979
# to our loss. The following figure shows the dataflow from a collector
980
# that gets 8 frames in each batch with 2 environments run in parallel,
981
# feeds them to a replay buffer that contains 1000 trajectories and
982
# samples sub-trajectories of 2 time steps each.
983
#
984
# .. figure:: /_static/img/replaybuffer_traj.png
985
# :alt: Storing trajectories in the replay buffer
986
#
987
# Let's start with the number of frames stored in the buffer
988
989
990
def ceil_div(x, y):
    # ceiling division: the smallest integer greater than or equal to x / y
    return -(-x // y)
992
993
994
buffer_size = 1_000_000
995
buffer_size = ceil_div(buffer_size, traj_len)
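# with a 1M-frame budget and 200-step trajectories, the buffer stores
# ceil(1_000_000 / 200) = 5_000 trajectories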
996
997
###############################################################################
998
# Prioritized replay buffer is disabled by default
999
prb = False
1000
1001
###############################################################################
1002
# We also need to define how many updates we'll be doing per batch of data
1003
# collected. This is known as the update-to-data or ``UTD`` ratio:
1004
update_to_data = 64
1005
1006
###############################################################################
1007
# We'll be feeding the loss with trajectories of length 25:
1008
random_crop_len = 25
1009
1010
###############################################################################
1011
# In the original paper, the authors perform one update with a batch of 64
1012
# elements for each frame collected. Here, we reproduce the same ratio,
# but we perform several updates for each batch collected. We
# adapt our batch-size to achieve the same update-per-frame ratio:
1015
1016
batch_size = ceil_div(64 * frames_per_batch, update_to_data * random_crop_len)
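# with frames_per_batch=800, update_to_data=64 and random_crop_len=25, this gives
# batch_size == ceil(64 * 800 / (64 * 25)) == 32 sub-trajectories per optimization step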
1017
1018
replay_buffer = make_replay_buffer(
1019
buffer_size=buffer_size,
1020
batch_size=batch_size,
1021
random_crop_len=random_crop_len,
1022
prefetch=3,
1023
prb=prb,
1024
)
1025
1026
###############################################################################
1027
# Loss module construction
1028
# ------------------------
1029
#
1030
# We build our loss module with the actor and ``qnet`` we've just created.
1031
# Because we have target parameters to update, we *must* create a target network
1032
# updater.
1033
#
1034
1035
gamma = 0.99
1036
lmbda = 0.9
1037
tau = 0.001 # Decay factor for the target network
1038
1039
loss_module = DDPGLoss(actor, qnet)
1040
1041
###############################################################################
1042
# let's use the TD(lambda) estimator!
1043
loss_module.make_value_estimator(ValueEstimators.TDLambda, gamma=gamma, lmbda=lmbda)
1044
1045
###############################################################################
1046
# .. note::
1047
#   Off-policy learning usually dictates a TD(0) estimator. Here, we use a TD(:math:`\lambda`)
#   estimator, which will introduce some bias as the trajectory that follows
#   a certain state has been collected with an outdated policy.
#   This trick, like the multi-step trick that can be used during data collection,
#   is one of those "hacks" that we usually find to work well in
#   practice despite the fact that they introduce some bias in the return
#   estimates.
1054
#
1055
# Target network updater
1056
# ~~~~~~~~~~~~~~~~~~~~~~
1057
#
1058
# Target networks are a crucial part of off-policy RL algorithms.
1059
# Updating the target network parameters is made easy thanks to the
1060
# :class:`~torchrl.objectives.HardUpdate` and :class:`~torchrl.objectives.SoftUpdate`
1061
# classes. They're built with the loss module as argument, and the update is
1062
# achieved via a call to `updater.step()` at the appropriate location in the
1063
# training loop.
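#
# Here, ``eps`` is the fraction of the target parameters kept at each update, so
# passing ``eps = 1 - tau`` corresponds to the usual soft update (a sketch of the
# convention, written out for clarity):
#
# .. math::
#
#    \theta' \leftarrow (1 - \tau)\,\theta' + \tau\,\theta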
1064
1065
from torchrl.objectives.utils import SoftUpdate
1066
1067
target_net_updater = SoftUpdate(loss_module, eps=1 - tau)
1068
1069
###############################################################################
1070
# Optimizer
1071
# ~~~~~~~~~
1072
#
1073
# Finally, we will use the Adam optimizer for the policy and value network:
1074
1075
from torch import optim
1076
1077
optimizer_actor = optim.Adam(
1078
loss_module.actor_network_params.values(True, True), lr=1e-4, weight_decay=0.0
1079
)
1080
optimizer_value = optim.Adam(
1081
loss_module.value_network_params.values(True, True), lr=1e-3, weight_decay=1e-2
1082
)
1083
total_collection_steps = total_frames // frames_per_batch
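# with total_frames=10_000 and frames_per_batch=800, this evaluates to 12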
1084
1085
###############################################################################
1086
# Time to train the policy
1087
# ------------------------
1088
#
1089
# The training loop is pretty straightforward now that we have built all the
1090
# modules we need.
1091
#
1092
1093
rewards = []
1094
rewards_eval = []
1095
1096
# Main loop
1097
1098
collected_frames = 0
1099
pbar = tqdm.tqdm(total=total_frames)
1100
r0 = None
1101
for i, tensordict in enumerate(collector):
1102
1103
# update weights of the inference policy
1104
collector.update_policy_weights_()
1105
1106
if r0 is None:
1107
r0 = tensordict["next", "reward"].mean().item()
1108
pbar.update(tensordict.numel())
1109
1110
# extend the replay buffer with the new data
1111
current_frames = tensordict.numel()
1112
collected_frames += current_frames
1113
replay_buffer.extend(tensordict.cpu())
1114
1115
# optimization steps
1116
if collected_frames >= init_random_frames:
1117
for _ in range(update_to_data):
1118
# sample from replay buffer
1119
sampled_tensordict = replay_buffer.sample().to(device)
1120
1121
# Compute loss
1122
loss_dict = loss_module(sampled_tensordict)
1123
1124
# optimize
1125
loss_dict["loss_actor"].backward()
1126
gn1 = torch.nn.utils.clip_grad_norm_(
1127
loss_module.actor_network_params.values(True, True), 10.0
1128
)
1129
optimizer_actor.step()
1130
optimizer_actor.zero_grad()
1131
1132
loss_dict["loss_value"].backward()
1133
gn2 = torch.nn.utils.clip_grad_norm_(
1134
loss_module.value_network_params.values(True, True), 10.0
1135
)
1136
optimizer_value.step()
1137
optimizer_value.zero_grad()
1138
1139
gn = (gn1**2 + gn2**2) ** 0.5
1140
1141
# update priority
1142
if prb:
1143
replay_buffer.update_tensordict_priority(sampled_tensordict)
1144
# update target network
1145
target_net_updater.step()
1146
1147
rewards.append(
1148
(
1149
i,
1150
tensordict["next", "reward"].mean().item(),
1151
)
1152
)
1153
td_record = recorder(None)
1154
if td_record is not None:
1155
rewards_eval.append((i, td_record["r_evaluation"].item()))
1156
if len(rewards_eval) and collected_frames >= init_random_frames:
1157
target_value = loss_dict["target_value"].item()
1158
loss_value = loss_dict["loss_value"].item()
1159
loss_actor = loss_dict["loss_actor"].item()
1160
rn = sampled_tensordict["next", "reward"].mean().item()
1161
rs = sampled_tensordict["next", "reward"].std().item()
1162
pbar.set_description(
1163
f"reward: {rewards[-1][1]: 4.2f} (r0 = {r0: 4.2f}), "
1164
f"reward eval: reward: {rewards_eval[-1][1]: 4.2f}, "
1165
f"reward normalized={rn :4.2f}/{rs :4.2f}, "
1166
f"grad norm={gn: 4.2f}, "
1167
f"loss_value={loss_value: 4.2f}, "
1168
f"loss_actor={loss_actor: 4.2f}, "
1169
f"target value: {target_value: 4.2f}"
1170
)
1171
1172
# update the exploration strategy
1173
actor_model_explore[1].step(current_frames)
1174
1175
collector.shutdown()
1176
del collector
1177
1178
###############################################################################
1179
# Experiment results
1180
# ------------------
1181
#
1182
# We make a simple plot of the average rewards during training. We can observe
1183
# that our policy learned quite well to solve the task.
1184
#
1185
# .. note::
1186
#   As already mentioned above, to get a more reasonable performance,
#   use a greater value for ``total_frames``, for example 1M.
1188
1189
from matplotlib import pyplot as plt
1190
1191
plt.figure()
1192
plt.plot(*zip(*rewards), label="training")
1193
plt.plot(*zip(*rewards_eval), label="eval")
1194
plt.legend()
1195
plt.xlabel("iter")
1196
plt.ylabel("reward")
1197
plt.tight_layout()
1198
1199
###############################################################################
1200
# Conclusion
1201
# ----------
1202
#
1203
# In this tutorial, we have learned how to code a loss module in TorchRL given
1204
# the concrete example of DDPG.
1205
#
1206
# The key takeaways are:
1207
#
1208
# - How to use the :class:`~torchrl.objectives.LossModule` class to code up a new
1209
# loss component;
1210
# - How to use (or not) a target network, and how to update its parameters;
1211
# - How to create an optimizer associated with a loss module.
1212
#
1213
# Next Steps
1214
# ----------
1215
#
1216
# To iterate further on this loss module, we might consider:
1217
#
1218
# - Using `@dispatch` (see `[Feature] Dispatch IQL loss module <https://github.com/pytorch/rl/pull/1230>`_.)
1219
# - Allowing flexible TensorDict keys.
1220
#
1221
1222