# -*- coding: utf-8 -*-
"""
Reinforcement Learning (PPO) with TorchRL Tutorial
==================================================
**Author**: `Vincent Moens <https://github.com/vmoens>`_

This tutorial demonstrates how to use PyTorch and :py:mod:`torchrl` to train a parametric policy
network to solve the Inverted Pendulum task from the `OpenAI-Gym/Farama-Gymnasium
control library <https://github.com/Farama-Foundation/Gymnasium>`__.

.. figure:: /_static/img/invpendulum.gif
   :alt: Inverted pendulum

   Inverted pendulum

Key learnings:

- How to create an environment in TorchRL, transform its outputs, and collect data from this environment;
- How to make your classes talk to each other using :class:`~tensordict.TensorDict`;
- The basics of building your training loop with TorchRL:

  - How to compute the advantage signal for policy gradient methods;
  - How to create a stochastic policy using a probabilistic neural network;
  - How to create a dynamic replay buffer and sample from it without repetition.

We will cover six crucial components of TorchRL:

* `environments <https://pytorch.org/rl/reference/envs.html>`__
* `transforms <https://pytorch.org/rl/reference/envs.html#transforms>`__
* `models (policy and value function) <https://pytorch.org/rl/reference/modules.html>`__
* `loss modules <https://pytorch.org/rl/reference/objectives.html>`__
* `data collectors <https://pytorch.org/rl/reference/collectors.html>`__
* `replay buffers <https://pytorch.org/rl/reference/data.html#replay-buffers>`__

"""

######################################################################
# If you are running this in Google Colab, make sure you install the following dependencies:
#
# .. code-block:: bash
#
#    !pip3 install torchrl
#    !pip3 install gym[mujoco]
#    !pip3 install tqdm
#
# Proximal Policy Optimization (PPO) is a policy-gradient algorithm in which a
# batch of data is collected and directly consumed to train the policy to maximise
# the expected return given some proximality constraints. You can think of it
# as a sophisticated version of `REINFORCE <https://link.springer.com/content/pdf/10.1007/BF00992696.pdf>`_,
# the foundational policy-optimization algorithm. For more information, see the
# `Proximal Policy Optimization Algorithms <https://arxiv.org/abs/1707.06347>`_ paper.
#
# PPO is usually regarded as a fast and efficient method for online, on-policy
# reinforcement learning. TorchRL provides a loss module that does all the work
# for you, so that you can rely on this implementation and focus on solving your
# problem rather than re-inventing the wheel every time you want to train a policy.
#
# For completeness, here is a brief overview of what the loss computes, even though
# this is taken care of by our :class:`~torchrl.objectives.ClipPPOLoss` module. The algorithm works as follows:
#
# 1. We will sample a batch of data by playing the
#    policy in the environment for a given number of steps.
# 2. Then, we will perform a given number of optimization steps with random sub-samples of this batch using
#    a clipped version of the REINFORCE loss.
# 3. The clipping will put a pessimistic bound on our loss: lower return estimates will
#    be favored compared to higher ones.
#
# The precise formula of the loss is:
#
# .. math::
#
#     L(s,a,\theta_k,\theta) = \min\left(
#         \frac{\pi_{\theta}(a|s)}{\pi_{\theta_k}(a|s)}  A^{\pi_{\theta_k}}(s,a), \;\;
#         g(\epsilon, A^{\pi_{\theta_k}}(s,a))
#     \right),
#
# There are two components in that loss: in the first part of the minimum operator,
# we simply compute an importance-weighted version of the REINFORCE loss (that is, a
# REINFORCE loss corrected for the fact that the current policy
# configuration lags the one that was used for the data collection).
# The second part of that minimum operator is a similar loss where we have clipped
# the ratios when they exceeded or were below a given pair of thresholds.
#
# This loss ensures that whether the advantage is positive or negative, policy
# updates that would produce significant shifts from the previous configuration
# are discouraged.
#
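# As a minimal, hedged sketch (an illustration of the clipping idea only, not the
# actual :class:`~torchrl.objectives.ClipPPOLoss` implementation; the tensor names
# below are made up for the example), the clipped surrogate objective can be
# computed like this:
#
# .. code-block:: python
#
#    # log_prob_new / log_prob_old: log-probabilities of the sampled actions under
#    # the current and the data-collection policy; advantage: the advantage estimate;
#    # clip_epsilon: the clipping threshold
#    ratio = torch.exp(log_prob_new - log_prob_old)
#    surrogate1 = ratio * advantage
#    surrogate2 = torch.clamp(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantage
#    loss_objective = -torch.min(surrogate1, surrogate2).mean()
#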
# This tutorial is structured as follows:
#
# 1. First, we will define a set of hyperparameters we will be using for training.
#
# 2. Next, we will focus on creating our environment, or simulator, using TorchRL's
#    wrappers and transforms.
#
# 3. Next, we will design the policy network and the value model,
#    which is indispensable to the loss function. These modules will be used
#    to configure our loss module.
#
# 4. Next, we will create the replay buffer and data loader.
#
# 5. Finally, we will run our training loop and analyze the results.
#
# Throughout this tutorial, we'll be using the :mod:`tensordict` library.
# :class:`~tensordict.TensorDict` is the lingua franca of TorchRL: it helps us abstract
# what a module reads and writes, so that we can worry less about the specific data
# description and more about the algorithm itself.
#
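# As a minimal, hedged illustration (the key and shapes here are made up for the
# example), a :class:`~tensordict.TensorDict` behaves like a dictionary of tensors
# that carries a batch size:
#
# .. code-block:: python
#
#    import torch
#    from tensordict import TensorDict
#
#    td = TensorDict({"observation": torch.randn(3, 11)}, batch_size=[3])
#    print(td["observation"].shape)  # torch.Size([3, 11])
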
import warnings
warnings.filterwarnings("ignore")
from torch import multiprocessing

# sphinx_gallery_start_ignore

# TorchRL prefers the spawn method, which restricts creation of ``~torchrl.envs.ParallelEnv``
# to the ``__main__`` guard. For ease of reading, this code switches to fork,
# which is also the default start method in Google's Colaboratory.
try:
    multiprocessing.set_start_method("fork")
except RuntimeError:
    pass

# sphinx_gallery_end_ignore

from collections import defaultdict

import matplotlib.pyplot as plt
import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import (Compose, DoubleToFloat, ObservationNorm, StepCounter,
                          TransformedEnv)
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
from tqdm import tqdm

######################################################################
# Define Hyperparameters
# ----------------------
#
# We set the hyperparameters for our algorithm. Depending on the resources
# available, one may choose to execute the policy on GPU or on another
# device.
# The ``frame_skip`` controls for how many frames a single action is
# executed. The rest of the arguments that count frames
# must be corrected for this value (since one environment step will
# actually return ``frame_skip`` frames).
#

is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)
num_cells = 256  # number of cells in each layer i.e. output dim.
lr = 3e-4
max_grad_norm = 1.0

######################################################################
# Data collection parameters
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# When collecting data, we will be able to choose how big each batch will be
# by defining a ``frames_per_batch`` parameter. We will also define how many
# frames (that is, the number of interactions with the simulator) we will allow ourselves to
# use. In general, the goal of an RL algorithm is to learn to solve the task
# as fast as it can in terms of environment interactions: the lower the ``total_frames``
# the better.
#
frames_per_batch = 1000
# For a complete training, bring the number of frames up to 1M
total_frames = 50_000

######################################################################
# PPO parameters
# ~~~~~~~~~~~~~~
#
# At each data collection (or batch collection) we will run the optimization
# over a certain number of *epochs*, each time consuming the entire data we just
# acquired in a nested training loop. Here, the ``sub_batch_size`` is different from the
# ``frames_per_batch`` above: recall that we are working with a "batch of data"
# coming from our collector, whose size is defined by ``frames_per_batch``, and that
# we will further split into smaller sub-batches during the inner training loop.
# The size of these sub-batches is controlled by ``sub_batch_size``.
#
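# With the values below (assuming they are left unchanged), each collected batch of
# 1000 frames is split into ``1000 // 64 = 15`` sub-batches per epoch, so every
# batch of data triggers ``10 * 15 = 150`` optimization steps.
#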
sub_batch_size = 64  # cardinality of the sub-samples gathered from the current data in the inner loop
num_epochs = 10  # optimization steps per batch of data collected
clip_epsilon = (
    0.2  # clip value for PPO loss: see the equation in the intro for more context.
)
gamma = 0.99
lmbda = 0.95
entropy_eps = 1e-4

######################################################################
# Define an environment
# ---------------------
#
# In RL, an *environment* is usually the way we refer to a simulator or a
# control system. Various libraries provide simulation environments for reinforcement
# learning, including Gymnasium (previously OpenAI Gym), DeepMind control suite, and
# many others.
# As a general library, TorchRL's goal is to provide an interchangeable interface
# to a large panel of RL simulators, allowing you to easily swap one environment
# with another. For example, creating a wrapped gym environment can be achieved with a few characters:
#

base_env = GymEnv("InvertedDoublePendulum-v4", device=device)

######################################################################
# There are a few things to notice in this code: first, we created
# the environment by calling the ``GymEnv`` wrapper. If extra keyword arguments
# are passed, they will be transmitted to the ``gym.make`` method, hence covering
# the most common environment construction commands.
# Alternatively, one could also directly create a gym environment using ``gym.make(env_name, **kwargs)``
# and wrap it in a ``GymWrapper`` class.
#
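# As a hedged sketch of that alternative (assuming the ``gymnasium`` package is
# installed; the environment name is the same one used above):
#
# .. code-block:: python
#
#    import gymnasium as gym
#    from torchrl.envs.libs.gym import GymWrapper
#
#    base_env = GymWrapper(gym.make("InvertedDoublePendulum-v4"), device=device)
#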
# Note also the ``device`` argument: for gym, this only controls the device where
# the input actions and observed states will be stored, but the execution will always
# be done on CPU. The reason for this is simply that gym does not support on-device
# execution, unless specified otherwise. For other libraries, we have control over
# the execution device and, as much as we can, we try to stay consistent in terms of
# storing and execution backends.
#
# Transforms
# ~~~~~~~~~~
#
# We will append some transforms to our environments to prepare the data for
# the policy. In Gym, this is usually achieved via wrappers. TorchRL takes a different
# approach, more similar to other PyTorch domain libraries, through the use of transforms.
# To add transforms to an environment, one should simply wrap it in a :class:`~torchrl.envs.transforms.TransformedEnv`
# instance and append the sequence of transforms to it. The transformed environment will inherit
# the device and meta-data of the wrapped environment, and transform these depending on the sequence
# of transforms it contains.
#
# Normalization
# ~~~~~~~~~~~~~
#
# The first transform to add is a normalization transform.
# As a rule of thumb, it is preferable to have data that loosely
# matches a unit Gaussian distribution: to obtain this, we will
# run a certain number of random steps in the environment and compute
# the summary statistics of these observations.
#
# We'll append two other transforms: the :class:`~torchrl.envs.transforms.DoubleToFloat` transform will
# convert double entries to single-precision numbers, ready to be read by the
# policy. The :class:`~torchrl.envs.transforms.StepCounter` transform will be used to count the steps before
# the environment is terminated. We will use this count as a supplementary measure
# of performance.
#
# As we will see later, many of TorchRL's classes rely on :class:`~tensordict.TensorDict`
# to communicate. You can think of it as a Python dictionary with some extra
# tensor features. In practice, this means that many modules we will be working
# with need to be told what key to read (``in_keys``) and what key to write
# (``out_keys``) in the ``tensordict`` they will receive. Usually, if ``out_keys``
# is omitted, it is assumed that the ``in_keys`` entries will be updated
# in-place. For our transforms, the only entry we are interested in is referred
# to as ``"observation"`` and our transform layers will be told to modify this
# entry and this entry only:
#

env = TransformedEnv(
    base_env,
    Compose(
        # normalize observations
        ObservationNorm(in_keys=["observation"]),
        DoubleToFloat(),
        StepCounter(),
    ),
)

######################################################################
# As you may have noticed, we have created a normalization layer but we did not
# set its normalization parameters. To do this, :class:`~torchrl.envs.transforms.ObservationNorm` can
# automatically gather the summary statistics of our environment:
#
env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)

######################################################################
# The :class:`~torchrl.envs.transforms.ObservationNorm` transform has now been populated with a
# location and a scale that will be used to normalize the data.
#
# Let us do a little sanity check for the shape of our summary stats:
#
print("normalization constant shape:", env.transform[0].loc.shape)

######################################################################
# An environment is not only defined by its simulator and transforms, but also
# by a series of metadata that describe what can be expected during its
# execution.
# For efficiency purposes, TorchRL is quite stringent when it comes to
# environment specs, but you can easily check that your environment specs are
# adequate.
# In our example, the :class:`~torchrl.envs.libs.gym.GymWrapper` and
# :class:`~torchrl.envs.libs.gym.GymEnv` that inherits
# from it already take care of setting the proper specs for your environment, so
# you should not have to worry about this.
#
# Nevertheless, let's see a concrete example using our transformed
# environment by looking at its specs.
# There are three specs to look at: ``observation_spec``, which defines what
# is to be expected when executing an action in the environment;
# ``reward_spec``, which indicates the reward domain; and finally the
# ``input_spec`` (which contains the ``action_spec``), which represents
# everything an environment requires to execute a single step.
#
print("observation_spec:", env.observation_spec)
print("reward_spec:", env.reward_spec)
print("input_spec:", env.input_spec)
print("action_spec (as defined by input_spec):", env.action_spec)

######################################################################
# The :func:`check_env_specs` function runs a small rollout and compares its output against the environment
# specs. If no error is raised, we can be confident that the specs are properly defined:
#
check_env_specs(env)

######################################################################
# For fun, let's see what a simple random rollout looks like. You can
# call ``env.rollout(n_steps)`` and get an overview of what the environment inputs
# and outputs look like. Actions will automatically be drawn from the action spec
# domain, so you don't need to care about designing a random sampler.
#
# Typically, at each step, an RL environment receives an
# action as input, and outputs an observation, a reward and a done state. The
# observation may be composite, meaning that it could be composed of more than one
# tensor. This is not a problem for TorchRL, since the whole set of observations
# is automatically packed in the output :class:`~tensordict.TensorDict`. After executing a rollout
# (that is, a sequence of environment steps and random action generations) over a given
# number of steps, we will retrieve a :class:`~tensordict.TensorDict` instance with a shape
# that matches this trajectory length:
#
rollout = env.rollout(3)
print("rollout of three steps:", rollout)
print("Shape of the rollout TensorDict:", rollout.batch_size)

######################################################################
# Our rollout data has a shape of ``torch.Size([3])``, which matches the number of steps
# we ran it for. The ``"next"`` entry points to the data coming after the current step.
# In most cases, the ``"next"`` data at time ``t`` matches the data at ``t+1``, but this
# may not be the case if we are using some specific transformations (for example, multi-step).
#
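# As a brief, hedged illustration of how that nested entry is accessed (using the
# ``rollout`` object created above):
#
# .. code-block:: python
#
#    # the "next" sub-tensordict holds the data observed after each step
#    print(rollout["next", "observation"].shape)
#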
# Policy
# ------
#
# PPO utilizes a stochastic policy to handle exploration. This means that our
# neural network will have to output the parameters of a distribution, rather
# than a single value corresponding to the action taken.
#
# As the data is continuous, we use a Tanh-Normal distribution to respect the
# action space boundaries. TorchRL provides such a distribution, and the only
# thing we need to care about is to build a neural network that outputs the
# right number of parameters for the policy to work with (a location, or mean,
# and a scale):
#
# .. math::
#
#     f_{\theta}(\text{observation}) = \mu_{\theta}(\text{observation}), \sigma^{+}_{\theta}(\text{observation})
#
# The only extra difficulty here is to split our output in two
# equal parts and map the second to a strictly positive space.
#
# We design the policy in three steps:
#
# 1. Define a neural network ``D_obs`` -> ``2 * D_action``. Indeed, our ``loc`` (mu) and ``scale`` (sigma) both have dimension ``D_action``.
#
# 2. Append a :class:`~tensordict.nn.distributions.NormalParamExtractor` to extract a location and a scale (that is, it splits the input in two equal parts and applies a positive transformation to the scale parameter).
#
# 3. Create a probabilistic :class:`~tensordict.nn.TensorDictModule` that can generate this distribution and sample from it.
#
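# As a hedged illustration of step 2 (the batch size and action dimension here are
# made up for the example), :class:`~tensordict.nn.distributions.NormalParamExtractor`
# splits its input in two along the last dimension and returns a location and a
# strictly positive scale:
#
# .. code-block:: python
#
#    extractor = NormalParamExtractor()
#    loc, scale = extractor(torch.randn(4, 2))   # last dim 2 -> D_action = 1
#    print(loc.shape, scale.shape)               # torch.Size([4, 1]) torch.Size([4, 1])
#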

actor_net = nn.Sequential(
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(2 * env.action_spec.shape[-1], device=device),
    NormalParamExtractor(),
)

######################################################################
# To enable the policy to "talk" with the environment through the ``tensordict``
# data carrier, we wrap the ``nn.Module`` in a :class:`~tensordict.nn.TensorDictModule`. This
# class will simply read the ``in_keys`` it is provided with and write the
# outputs in-place at the registered ``out_keys``.
#
policy_module = TensorDictModule(
    actor_net, in_keys=["observation"], out_keys=["loc", "scale"]
)

######################################################################
# We now need to build a distribution out of the location and scale of our
# normal distribution. To do so, we instruct the
# :class:`~torchrl.modules.tensordict_module.ProbabilisticActor`
# class to build a :class:`~torchrl.modules.TanhNormal` out of the location and scale
# parameters. We also provide the minimum and maximum values of this
# distribution, which we gather from the environment specs.
#
# The name of the ``in_keys`` (and hence the name of the ``out_keys`` from
# the :class:`~tensordict.nn.TensorDictModule` above) cannot be set to any value one may
# like, as the :class:`~torchrl.modules.TanhNormal` distribution constructor will expect the
# ``loc`` and ``scale`` keyword arguments. That being said,
# :class:`~torchrl.modules.tensordict_module.ProbabilisticActor` also accepts
# ``Dict[str, str]`` typed ``in_keys`` where the key-value pair indicates
# what ``in_key`` string should be used for every keyword argument that is to be used.
#
policy_module = ProbabilisticActor(
    module=policy_module,
    spec=env.action_spec,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    distribution_kwargs={
        "min": env.action_spec.space.low,
        "max": env.action_spec.space.high,
    },
    return_log_prob=True,
    # we'll need the log-prob for the numerator of the importance weights
)

######################################################################
# Value network
# -------------
#
# The value network is a crucial component of the PPO algorithm, even though it
# won't be used at inference time. This module will read the observations and
# return an estimation of the discounted return for the following trajectory.
# This allows us to amortize learning by relying on some utility estimation
# that is learned on-the-fly during training. Our value network shares the same
# structure as the policy, but for simplicity we assign it its own set of
# parameters.
#
value_net = nn.Sequential(
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(1, device=device),
)

value_module = ValueOperator(
    module=value_net,
    in_keys=["observation"],
)

######################################################################
# Let's try our policy and value modules. As we said earlier, the usage of
# :class:`~tensordict.nn.TensorDictModule` makes it possible to directly read the output
# of the environment to run these modules, as they know what information to read
# and where to write it:
#
print("Running policy:", policy_module(env.reset()))
print("Running value:", value_module(env.reset()))

######################################################################
# Data collector
# --------------
#
# TorchRL provides a set of `DataCollector classes <https://pytorch.org/rl/reference/collectors.html>`__.
# Briefly, these classes execute three operations: reset an environment,
# compute an action given the latest observation, execute a step in the environment,
# and repeat the last two steps until the environment signals a stop (or reaches
# a done state).
#
# They allow you to control how many frames to collect at each iteration
# (through the ``frames_per_batch`` parameter),
# when to reset the environment (through the ``max_frames_per_traj`` argument),
# on which ``device`` the policy should be executed, etc. They are also
# designed to work efficiently with batched and multiprocessed environments.
#
# The simplest data collector is the :class:`~torchrl.collectors.collectors.SyncDataCollector`:
# it is an iterator that you can use to get batches of data of a given length, and
# that will stop once a total number of frames (``total_frames``) have been
# collected.
# Other data collectors (:class:`~torchrl.collectors.collectors.MultiSyncDataCollector` and
# :class:`~torchrl.collectors.collectors.MultiaSyncDataCollector`) will execute
# the same operations synchronously and asynchronously, respectively, over a
# set of multiprocessed workers.
#
# As for the policy and environment before, the data collector will return
# :class:`~tensordict.TensorDict` instances with a total number of elements that will
# match ``frames_per_batch``. Using :class:`~tensordict.TensorDict` to pass data to the
# training loop allows you to write data loading pipelines
# that are 100% oblivious to the actual specificities of the rollout content.
#
collector = SyncDataCollector(
    env,
    policy_module,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    split_trajs=False,
    device=device,
)

######################################################################
# Replay buffer
# -------------
#
# Replay buffers are a common building block of off-policy RL algorithms.
# In on-policy contexts, a replay buffer is refilled every time a batch of
# data is collected, and its data is repeatedly consumed for a certain number
# of epochs.
#
# TorchRL's replay buffers are built using a common container,
# :class:`~torchrl.data.ReplayBuffer`, which takes as arguments the components
# of the buffer: a storage, a writer, a sampler and possibly some transforms.
# Only the storage (which indicates the replay buffer capacity) is mandatory.
# We also specify a sampler without repetition to avoid sampling the same item
# multiple times in one epoch.
# Using a replay buffer for PPO is not mandatory and we could simply
# sample the sub-batches from the collected batch, but using these classes
# makes it easy for us to build the inner training loop in a reproducible way.
#
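# As a hedged, self-contained sketch of the extend/sample pattern used in the
# training loop below (the buffer, key and sizes here are made up for the example):
#
# .. code-block:: python
#
#    from tensordict import TensorDict
#
#    buf = ReplayBuffer(
#        storage=LazyTensorStorage(max_size=100),
#        sampler=SamplerWithoutReplacement(),
#    )
#    buf.extend(TensorDict({"x": torch.arange(100).unsqueeze(-1)}, [100]))
#    sample = buf.sample(10)  # 10 distinct items; no repetition within an epoch
#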

replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(max_size=frames_per_batch),
    sampler=SamplerWithoutReplacement(),
)

######################################################################
# Loss function
# -------------
#
# The PPO loss can be directly imported from TorchRL for convenience using the
# :class:`~torchrl.objectives.ClipPPOLoss` class. This is the easiest way of utilizing PPO:
# it hides away the mathematical operations of PPO and the control flow that
# goes with it.
#
# PPO requires some "advantage estimation" to be computed. In short, an advantage
# is a value that reflects an expectation of the return while dealing with
# the bias / variance tradeoff.
# To compute the advantage, one just needs to (1) build the advantage module, which
# utilizes our value operator, and (2) pass each batch of data through it before each
# epoch.
# The GAE module will update the input ``tensordict`` with new ``"advantage"`` and
# ``"value_target"`` entries.
# The ``"value_target"`` is a gradient-free tensor that represents the empirical
# value that the value network should predict given the input observation.
# Both of these will be used by :class:`~torchrl.objectives.ClipPPOLoss` to
# return the policy and value losses.
#

advantage_module = GAE(
    gamma=gamma, lmbda=lmbda, value_network=value_module, average_gae=True
)

loss_module = ClipPPOLoss(
    actor_network=policy_module,
    critic_network=value_module,
    clip_epsilon=clip_epsilon,
    entropy_bonus=bool(entropy_eps),
    entropy_coef=entropy_eps,
    # these keys match by default but we set this for completeness
    critic_coef=1.0,
    loss_critic_type="smooth_l1",
)

optim = torch.optim.Adam(loss_module.parameters(), lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optim, total_frames // frames_per_batch, 0.0
)

######################################################################
# Training loop
# -------------
# We now have all the pieces needed to code our training loop.
# The steps include:
#
# * Collect data
#
#   * Compute advantage
#
#     * Loop over the collected data to compute loss values
#     * Back propagate
#     * Optimize
#     * Repeat
#
#   * Repeat
#
# * Repeat
#


logs = defaultdict(list)
pbar = tqdm(total=total_frames)
eval_str = ""

# We iterate over the collector until it reaches the total number of frames it was
# designed to collect:
for i, tensordict_data in enumerate(collector):
    # we now have a batch of data to work with. Let's learn something from it.
    for _ in range(num_epochs):
        # We'll need an "advantage" signal to make PPO work.
        # We re-compute it at each epoch as its value depends on the value
        # network which is updated in the inner loop.
        advantage_module(tensordict_data)
        data_view = tensordict_data.reshape(-1)
        replay_buffer.extend(data_view.cpu())
        for _ in range(frames_per_batch // sub_batch_size):
            subdata = replay_buffer.sample(sub_batch_size)
            loss_vals = loss_module(subdata.to(device))
            loss_value = (
                loss_vals["loss_objective"]
                + loss_vals["loss_critic"]
                + loss_vals["loss_entropy"]
            )

            # Optimization: backward, grad clipping and optimization step
            loss_value.backward()
            # this is not strictly mandatory but it's good practice to keep
            # your gradient norm bounded
            torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_grad_norm)
            optim.step()
            optim.zero_grad()

    logs["reward"].append(tensordict_data["next", "reward"].mean().item())
    pbar.update(tensordict_data.numel())
    cum_reward_str = (
        f"average reward={logs['reward'][-1]: 4.4f} (init={logs['reward'][0]: 4.4f})"
    )
    logs["step_count"].append(tensordict_data["step_count"].max().item())
    stepcount_str = f"step count (max): {logs['step_count'][-1]}"
    logs["lr"].append(optim.param_groups[0]["lr"])
    lr_str = f"lr policy: {logs['lr'][-1]: 4.4f}"
    if i % 10 == 0:
        # We evaluate the policy once every 10 batches of data.
        # Evaluation is rather simple: execute the policy without exploration
        # (take the expected value of the action distribution) for a given
        # number of steps (1000, which is our ``env`` horizon).
        # The ``rollout`` method of the ``env`` can take a policy as argument:
        # it will then execute this policy at each step.
        with set_exploration_type(ExplorationType.MEAN), torch.no_grad():
            # execute a rollout with the trained policy
            eval_rollout = env.rollout(1000, policy_module)
            logs["eval reward"].append(eval_rollout["next", "reward"].mean().item())
            logs["eval reward (sum)"].append(
                eval_rollout["next", "reward"].sum().item()
            )
            logs["eval step_count"].append(eval_rollout["step_count"].max().item())
            eval_str = (
                f"eval cumulative reward: {logs['eval reward (sum)'][-1]: 4.4f} "
                f"(init: {logs['eval reward (sum)'][0]: 4.4f}), "
                f"eval step-count: {logs['eval step_count'][-1]}"
            )
            del eval_rollout
    pbar.set_description(", ".join([eval_str, cum_reward_str, stepcount_str, lr_str]))

    # We're also using a learning rate scheduler. Like the gradient clipping,
    # this is a nice-to-have but not strictly necessary for PPO to work.
    scheduler.step()

######################################################################
# Results
# -------
#
# Before the 1M step cap is reached, the algorithm should have reached a max
# step count of 1000 steps, which is the maximum number of steps before the
# trajectory is truncated.
#
plt.figure(figsize=(10, 10))
plt.subplot(2, 2, 1)
plt.plot(logs["reward"])
plt.title("training rewards (average)")
plt.subplot(2, 2, 2)
plt.plot(logs["step_count"])
plt.title("Max step count (training)")
plt.subplot(2, 2, 3)
plt.plot(logs["eval reward (sum)"])
plt.title("Return (test)")
plt.subplot(2, 2, 4)
plt.plot(logs["eval step_count"])
plt.title("Max step count (test)")
plt.show()

######################################################################
# Conclusion and next steps
# -------------------------
#
# In this tutorial, we have learned:
#
# 1. How to create and customize an environment with :py:mod:`torchrl`;
# 2. How to write a model and a loss function;
# 3. How to set up a typical training loop.
#
# If you want to experiment with this tutorial a bit more, you can apply the following modifications:
#
# * From an efficiency perspective,
#   we could run several simulations in parallel to speed up data collection.
#   Check :class:`~torchrl.envs.ParallelEnv` for further information, and see
#   the sketch after this list.
#
# * From a logging perspective, one could add a :class:`torchrl.record.VideoRecorder` transform to
#   the environment after asking for rendering to get a visual rendering of the
#   inverted pendulum in action. Check :py:mod:`torchrl.record` to
#   know more.
#
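# As a minimal, hedged sketch of the parallel-collection idea mentioned above
# (the number of workers is arbitrary; see the :class:`~torchrl.envs.ParallelEnv`
# documentation for the exact constructor arguments):
#
# .. code-block:: python
#
#    from torchrl.envs import ParallelEnv
#
#    # run 4 copies of the base environment in separate worker processes
#    parallel_env = ParallelEnv(4, lambda: GymEnv("InvertedDoublePendulum-v4"))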