GitHub Repository: pytorch/tutorials
Path: blob/main/advanced_source/pendulum.py
# -*- coding: utf-8 -*-

"""
Pendulum: Writing your environment and transforms with TorchRL
==============================================================

**Author**: `Vincent Moens <https://github.com/vmoens>`_

Creating an environment (a simulator or an interface to a physical control system)
is an integral part of reinforcement learning and control engineering.

TorchRL provides a set of tools to do this in multiple contexts.
This tutorial demonstrates how to use PyTorch and TorchRL to code a pendulum
simulator from the ground up.
It is loosely based on the Pendulum-v1 implementation from the `OpenAI-Gym/Farama-Gymnasium
control library <https://github.com/Farama-Foundation/Gymnasium>`__.

.. figure:: /_static/img/pendulum.gif
   :alt: Pendulum
   :align: center

   Simple Pendulum

Key learnings:

- How to design an environment in TorchRL:

  - Writing specs (input, observation and reward);
  - Implementing behavior: seeding, reset and step.

- Transforming your environment inputs and outputs, and writing your own
  transforms;
- How to use :class:`~tensordict.TensorDict` to carry arbitrary data structures
  through the codebase.

In the process, we will touch on three crucial components of TorchRL:

* `environments <https://pytorch.org/rl/reference/envs.html>`__
* `transforms <https://pytorch.org/rl/reference/envs.html#transforms>`__
* `models (policy and value function) <https://pytorch.org/rl/reference/modules.html>`__

"""
######################################################################
# To give a sense of what can be achieved with TorchRL's environments, we will
# be designing a *stateless* environment. While stateful environments keep track of
# the latest physical state encountered and rely on this to simulate the state-to-state
# transition, stateless environments expect the current state to be provided to
# them at each step, along with the action undertaken. TorchRL supports both
# types of environments, but stateless environments are more generic and hence
# cover a broader range of features of the environment API in TorchRL.
#
# Modeling stateless environments gives users full control over the inputs and
# outputs of the simulator: one can reset an experiment at any stage or actively
# modify the dynamics from the outside. However, it assumes that we have some control
# over a task, which may not always be the case: solving a problem where we cannot
# control the current state is more challenging but has a much wider set of applications.
#
# Another advantage of stateless environments is that they can enable
# batched execution of transition simulations. If the backend and the
# implementation allow it, an algebraic operation can be executed seamlessly on
# scalars, vectors, or tensors. This tutorial gives such examples.
#
# This tutorial will be structured as follows:
#
# * We will first get acquainted with the environment properties:
#   its shape (``batch_size``), its methods (mainly :meth:`~torchrl.envs.EnvBase.step`,
#   :meth:`~torchrl.envs.EnvBase.reset` and :meth:`~torchrl.envs.EnvBase.set_seed`)
#   and finally its specs.
# * After having coded our simulator, we will demonstrate how it can be used
#   during training with transforms.
# * We will explore new avenues that follow from TorchRL's API,
#   including the possibility of transforming inputs, the vectorized execution
#   of the simulation and the possibility of backpropagation through the
#   simulation graph.
# * Finally, we will train a simple policy to solve the system we implemented.
#
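######################################################################
# To make the stateful/stateless distinction concrete, here is a minimal
# pure-Python sketch that does not use TorchRL at all. The names
# ``StatefulToySim`` and ``stateless_step`` are hypothetical, and the
# one-dimensional "position plus action" system is only a stand-in for a real
# simulator: the stateless version receives the state together with the
# action, so it can restart from any state and be mapped over a batch of states.

```python
from typing import Dict, List


class StatefulToySim:
    """Hypothetical stateful simulator: the state lives inside the object."""

    def __init__(self, x0: float = 0.0) -> None:
        self.x = x0

    def step(self, action: float) -> float:
        # The transition reads (and overwrites) the internally stored state.
        self.x = self.x + action
        return self.x


def stateless_step(state: Dict[str, float], action: float) -> Dict[str, float]:
    """Hypothetical stateless transition: the state is an explicit input."""
    return {"x": state["x"] + action}


# Stateful: the result depends on the hidden history of calls.
sim = StatefulToySim()
sim.step(1.0)
stateful_result = sim.step(2.0)

# Stateless: any state can be supplied, e.g. to restart from an arbitrary point.
restarted = stateless_step({"x": 10.0}, 2.0)

# Batched execution is the same transition applied to many states
# (a tensor backend would vectorize this instead of looping).
batch_states: List[Dict[str, float]] = [{"x": 0.0}, {"x": 1.0}, {"x": -5.0}]
batch_next = [stateless_step(s, 0.5) for s in batch_states]
print(stateful_result, restarted, [s["x"] for s in batch_next])
```

# The stateless function never touches hidden attributes, which is what makes
# resetting "at any stage" and batching possible without extra machinery.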
# sphinx_gallery_start_ignore
import warnings

warnings.filterwarnings("ignore")
from torch import multiprocessing

# TorchRL prefers the ``spawn`` start method, which restricts the creation of
# ``torchrl.envs.ParallelEnv`` to code guarded by ``if __name__ == "__main__"``.
# For ease of reading, we switch to ``fork`` instead, which is also the default
# start method in Google Colaboratory.
try:
    multiprocessing.set_start_method("fork")
except RuntimeError:
    pass

# sphinx_gallery_end_ignore

from collections import defaultdict
from typing import Optional

import numpy as np
import torch
import tqdm
from tensordict import TensorDict, TensorDictBase
from tensordict.nn import TensorDictModule
from torch import nn

from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs import (
    CatTensors,
    EnvBase,
    Transform,
    TransformedEnv,
    UnsqueezeTransform,
)
from torchrl.envs.transforms.transforms import _apply_to_composite
from torchrl.envs.utils import check_env_specs, step_mdp

DEFAULT_X = np.pi
DEFAULT_Y = 1.0
######################################################################
# There are four things you must take care of when designing a new environment
# class:
#
# * :meth:`EnvBase._reset`, which codes for the resetting of the simulator
#   at a (potentially random) initial state;
# * :meth:`EnvBase._step`, which codes for the state transition dynamic;
# * :meth:`EnvBase._set_seed`, which implements the seeding mechanism;
# * the environment specs.
#
# Let us first describe the problem at hand: we would like to model a simple
# pendulum over which we can control the torque applied at its fixed point.
# Our goal is to place the pendulum in the upward position (angular position at 0
# by convention) and keep it standing still in that position.
# To design our dynamic system, we need to define two equations: the motion
# equation following an action (the torque applied) and the reward equation
# that will constitute our objective function.
#
# For the motion equation, we will update the angular velocity following:
#
# .. math::
#
#    \dot{\theta}_{t+1} = \dot{\theta}_t + \left(\frac{3g}{2L} \sin(\theta_t) + \frac{3}{mL^2} u\right) dt
#
# where :math:`\dot{\theta}` is the angular velocity in rad/sec, :math:`g` is the
# gravitational acceleration, :math:`L` is the pendulum length, :math:`m` is its mass,
# :math:`\theta` is its angular position, :math:`u` is the torque and :math:`dt`
# is the time step. The angular position is then updated according to
#
# .. math::
#
#    \theta_{t+1} = \theta_{t} + \dot{\theta}_{t+1} dt
#
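######################################################################
# The two update equations can be transcribed directly for a single scalar
# state. This is a minimal sketch using only Python's ``math`` module; the
# function name ``pendulum_dynamics`` is hypothetical, the default constants
# mirror the values used by ``gen_params`` later in this tutorial, and the
# torque and speed clipping performed by the real ``_step`` is deliberately
# left out. Note that the position update uses the *new* velocity.

```python
import math
from typing import Tuple


def pendulum_dynamics(
    th: float,
    thdot: float,
    u: float,
    g: float = 10.0,
    m: float = 1.0,
    L: float = 1.0,
    dt: float = 0.05,
) -> Tuple[float, float]:
    """One application of the two update equations (no clipping)."""
    # Angular velocity: gravity term plus torque term, integrated over dt.
    new_thdot = thdot + (3 * g / (2 * L) * math.sin(th) + 3 / (m * L**2) * u) * dt
    # Angular position: advanced with the freshly updated velocity.
    new_th = th + new_thdot * dt
    return new_th, new_thdot


# Upright (th = 0), motionless and unactuated: sin(0) == 0, so nothing moves.
upright = pendulum_dynamics(0.0, 0.0, 0.0)

# A unit torque from rest at th = 0 changes the velocity by
# 3 / (m * L**2) * u * dt = 3 * 1.0 * 0.05 = 0.15 rad/s.
pushed = pendulum_dynamics(0.0, 0.0, 1.0)

# Slightly off the top, gravity pulls the pendulum further away (positive sin term).
falling = pendulum_dynamics(0.1, 0.0, 0.0)
```

# The upright position is therefore an equilibrium, but an unstable one, which
# is why the controller has to keep applying corrective torques.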
# We define our reward as
#
# .. math::
#
#    r = -\left(\theta^2 + 0.1 \, \dot{\theta}^2 + 0.001 \, u^2\right)
#
# which will be maximized when the angle is close to 0 (pendulum in upward
# position), the angular velocity is close to 0 (no motion) and the torque is
# 0 too.
#
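######################################################################
# The reward can likewise be evaluated by hand. In the sketch below (plain
# Python, hypothetical function names), the angle is first wrapped to
# :math:`[-\pi, \pi)`, as the ``angle_normalize`` helper of this tutorial does,
# so that :math:`\theta = 2\pi` counts as upright. With the limits used later in
# this tutorial (``max_speed=8``, ``max_torque=2.0``), the reward lies between
# about :math:`-(\pi^2 + 6.4 + 0.004) \approx -16.27` and 0.

```python
import math


def wrap_angle(x: float) -> float:
    """Wrap an angle to [-pi, pi) (same formula as ``angle_normalize``)."""
    return ((x + math.pi) % (2 * math.pi)) - math.pi


def pendulum_reward(th: float, thdot: float, u: float) -> float:
    """r = -(theta^2 + 0.1 * thetadot^2 + 0.001 * u^2), with theta wrapped first."""
    th = wrap_angle(th)
    return -(th**2 + 0.1 * thdot**2 + 0.001 * u**2)


best = pendulum_reward(0.0, 0.0, 0.0)               # upright, still, no torque
full_turn = pendulum_reward(2 * math.pi, 0.0, 0.0)  # also upright after wrapping
hanging = pendulum_reward(math.pi, 0.0, 0.0)        # pointing straight down
worst = pendulum_reward(math.pi, 8.0, 2.0)          # down, fastest, strongest torque
```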
# Coding the effect of an action: :func:`~torchrl.envs.EnvBase._step`
# -------------------------------------------------------------------
#
# The step method is the first thing to consider, as it will encode
# the simulation that is of interest to us. In TorchRL, the
# :class:`~torchrl.envs.EnvBase` class has a :meth:`EnvBase.step`
# method that receives a :class:`tensordict.TensorDict`
# instance with an ``"action"`` entry indicating what action is to be taken.
#
# To facilitate the reading and writing from that ``tensordict`` and to make sure
# that the keys are consistent with what's expected from the library, the
# simulation part has been delegated to a private abstract method :meth:`_step`
# which reads input data from a ``tensordict``, and writes a *new* ``tensordict``
# with the output data.
#
# The :func:`_step` method should do the following:
#
# 1. Read the input keys (such as ``"action"``) and execute the simulation
#    based on these;
# 2. Retrieve observations, done state and reward;
# 3. Write the set of observation values along with the reward and done state
#    at the corresponding entries in a new :class:`TensorDict`.
#
# Next, the :meth:`~torchrl.envs.EnvBase.step` method will merge the output
# of :meth:`~torchrl.envs.EnvBase._step` into the input ``tensordict`` to enforce
# input/output consistency.
#
# Typically, for stateful environments, this will look like this:
#
# .. code-block::
#
#    >>> tensordict = policy(env.reset())
#    >>> print(tensordict)
#    TensorDict(
#        fields={
#            action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
#            done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
#            observation: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
#        batch_size=torch.Size([]),
#        device=cpu,
#        is_shared=False)
#    >>> env.step(tensordict)
#    >>> print(tensordict)
#    TensorDict(
#        fields={
#            action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
#            done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
#            next: TensorDict(
#                fields={
#                    done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
#                    observation: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
#                    reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False)},
#                batch_size=torch.Size([]),
#                device=cpu,
#                is_shared=False),
#            observation: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
#        batch_size=torch.Size([]),
#        device=cpu,
#        is_shared=False)
#
# Notice that the root ``tensordict`` has not changed; the only modification is the
# appearance of a new ``"next"`` entry that contains the new information.
#
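######################################################################
# The data flow between the public ``step``, the private ``_step`` and the
# ``"next"`` entry can be mimicked with plain nested dictionaries. This is a
# simplified, hypothetical model (``toy_step`` and ``_toy_step`` are not
# TorchRL functions, and TorchRL additionally validates the outputs against
# the specs): the root entries are left untouched, the new values land under
# ``"next"``, and promoting ``"next"`` to the root prepares the following step,
# roughly the role that :func:`~torchrl.envs.utils.step_mdp` plays later in this
# tutorial.

```python
from typing import Any, Dict


def _toy_step(td: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical simulation: read the inputs and return a *new* dict.
    new_obs = td["observation"] + td["action"]
    return {"observation": new_obs, "reward": -abs(new_obs), "done": False}


def toy_step(td: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical public step: nest the simulation output under "next".
    td["next"] = _toy_step(td)
    return td


toy_td = {"observation": 1.0, "action": 0.5, "done": False}
root_before = dict(toy_td)  # copy of the root entries before stepping
toy_td = toy_step(toy_td)

# Only a "next" entry was added; the existing root entries did not change.
root_unchanged = all(toy_td[key] == value for key, value in root_before.items())
added_keys = set(toy_td) - set(root_before)

# Promote "next" to the root (dropping the reward), pick a new action, step again.
toy_td = {key: value for key, value in toy_td["next"].items() if key != "reward"}
toy_td["action"] = -0.5
toy_td = toy_step(toy_td)
```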
# In the Pendulum example, our :meth:`_step` method will read the relevant
# entries from the input ``tensordict`` and compute the position and velocity of
# the pendulum after the torque encoded by the ``"action"`` key has been applied
# to it. We compute the new angular position of the pendulum
# ``new_th`` as the previous position ``"th"`` plus the new
# velocity ``new_thdot`` multiplied by the time interval ``dt``.
#
# Since our goal is to turn the pendulum up and keep it still in that
# position, our ``cost`` (negative reward) function is lower for positions
# close to the target and for low speeds.
# Indeed, we want to discourage positions that are far from being "upward"
# and/or speeds that are far from 0.
#
# In our example, :meth:`EnvBase._step` is encoded as a static method since our
# environment is stateless. In stateful settings, the ``self`` argument is
# needed as the state needs to be read from the environment.
#


def _step(tensordict):
    th, thdot = tensordict["th"], tensordict["thdot"]  # th := theta

    # The physical parameters travel with the data in the nested "params" entry.
    g_force = tensordict["params", "g"]
    mass = tensordict["params", "m"]
    length = tensordict["params", "l"]
    dt = tensordict["params", "dt"]
    # The action is the torque: drop its trailing dimension and clip it.
    u = tensordict["action"].squeeze(-1)
    u = u.clamp(-tensordict["params", "max_torque"], tensordict["params", "max_torque"])
    # Cost of the current state and action (the reward is its negative).
    costs = angle_normalize(th) ** 2 + 0.1 * thdot**2 + 0.001 * (u**2)

    # Motion equations: update the velocity first, then the position.
    new_thdot = (
        thdot
        + (3 * g_force / (2 * length) * th.sin() + 3.0 / (mass * length**2) * u) * dt
    )
    new_thdot = new_thdot.clamp(
        -tensordict["params", "max_speed"], tensordict["params", "max_speed"]
    )
    new_th = th + new_thdot * dt
    # Reward and done flags carry a trailing singleton dimension.
    reward = -costs.view(*tensordict.shape, 1)
    # The environment never terminates on its own.
    done = torch.zeros_like(reward, dtype=torch.bool)
    out = TensorDict(
        {
            "th": new_th,
            "thdot": new_thdot,
            "params": tensordict["params"],
            "reward": reward,
            "done": done,
        },
        tensordict.shape,
    )
    return out


def angle_normalize(x):
    # Wrap angles to [-pi, pi) so that, e.g., 2 * pi is treated like 0.
    return ((x + torch.pi) % (2 * torch.pi)) - torch.pi
######################################################################
# Resetting the simulator: :func:`~torchrl.envs.EnvBase._reset`
# -------------------------------------------------------------
#
# The second method we need to care about is the
# :meth:`~torchrl.envs.EnvBase._reset` method. Like
# :meth:`~torchrl.envs.EnvBase._step`, it should write the observation entries
# and possibly a done state in the ``tensordict`` it outputs (if the done state is
# omitted, it will be filled as ``False`` by the parent method
# :meth:`~torchrl.envs.EnvBase.reset`). In some contexts, it is required that
# the ``_reset`` method receives a command from the function that called
# it (for example, in multi-agent settings we may want to indicate which agents need
# to be reset). This is why the :meth:`~torchrl.envs.EnvBase._reset` method
# also expects a ``tensordict`` as input, although it may be empty or
# ``None``.
#
# The parent :meth:`EnvBase.reset` performs some simple checks, as
# :meth:`EnvBase.step` does, such as making sure that a ``"done"`` state
# is returned in the output ``tensordict`` and that the shapes match what is
# expected from the specs.
#
# For us, the only important thing to consider is whether
# :meth:`EnvBase._reset` returns all the expected observations. Once more,
# since we are working with a stateless environment, we pass the configuration
# of the pendulum in a nested ``tensordict`` named ``"params"``.
#
# In this example, we do not pass a done state as this is not mandatory
# for :meth:`_reset` and our environment is non-terminating, so we always
# expect it to be ``False``.
#


def _reset(self, tensordict):
    if tensordict is None or tensordict.is_empty():
        # If no ``tensordict`` is passed, we generate a single set of hyperparameters.
        # Otherwise, we assume that the input ``tensordict`` contains all the relevant
        # parameters to get started.
        tensordict = self.gen_params(batch_size=self.batch_size)

    # Bounds of the uniform distributions used for the initial state.
    high_th = torch.tensor(DEFAULT_X, device=self.device)
    high_thdot = torch.tensor(DEFAULT_Y, device=self.device)
    low_th = -high_th
    low_thdot = -high_thdot

    # For non batch-locked environments, the input ``tensordict`` shape dictates the number
    # of simulators run simultaneously. In other contexts, the initial
    # random state's shape will depend upon the environment batch size instead.
    th = (
        torch.rand(tensordict.shape, generator=self.rng, device=self.device)
        * (high_th - low_th)
        + low_th
    )
    thdot = (
        torch.rand(tensordict.shape, generator=self.rng, device=self.device)
        * (high_thdot - low_thdot)
        + low_thdot
    )
    out = TensorDict(
        {
            "th": th,
            "thdot": thdot,
            "params": tensordict["params"],
        },
        batch_size=tensordict.shape,
    )
    return out
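######################################################################
# The sampling performed by ``_reset`` can be reproduced in plain Python to see
# how the seed, the bounds and the batch shape interact. In this hypothetical
# sketch, ``random.Random`` stands in for the ``torch.Generator`` stored in
# ``self.rng``, and an optional integer ``batch_size`` plays the role of the
# input ``tensordict`` shape. Each coordinate is drawn as
# ``rand * (high - low) + low`` with the same default bounds as above.

```python
import math
import random
from typing import List, Optional, Tuple, Union

# Same bounds as DEFAULT_X / DEFAULT_Y in this tutorial.
TOY_HIGH_TH = math.pi
TOY_HIGH_THDOT = 1.0

State = Tuple[float, float]


def toy_reset(rng: random.Random, batch_size: Optional[int] = None) -> Union[State, List[State]]:
    """Hypothetical reset: sample (th, thdot) uniformly within the default bounds."""

    def sample() -> State:
        # rand * (high - low) + low, with low = -high
        th = rng.random() * (2 * TOY_HIGH_TH) - TOY_HIGH_TH
        thdot = rng.random() * (2 * TOY_HIGH_THDOT) - TOY_HIGH_THDOT
        return th, thdot

    if batch_size is None:
        return sample()  # a single simulator
    return [sample() for _ in range(batch_size)]  # one state per simulator


one_state = toy_reset(random.Random(0))
states = toy_reset(random.Random(0), batch_size=4)
states_again = toy_reset(random.Random(0), batch_size=4)
```

# Re-seeding the generator reproduces the same initial states, and the number
# of sampled states follows the requested batch shape.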
######################################################################
# Environment metadata: ``env.*_spec``
# ------------------------------------
#
# The specs define the input and output domain of the environment.
# It is important that the specs accurately define the tensors that will be
# received at runtime, as they are often used to carry information about
# environments in multiprocessing and distributed settings. They can also be
# used to instantiate lazily defined neural networks and test scripts without
# actually querying the environment (which can be costly with real-world
# physical systems for instance).
#
# There are four specs that we must code in our environment:
#
# * :obj:`EnvBase.observation_spec`: This will be a :class:`~torchrl.data.CompositeSpec`
#   instance where each key is an observation (a :class:`CompositeSpec` can be
#   viewed as a dictionary of specs).
# * :obj:`EnvBase.action_spec`: It can be any type of spec, but it is required
#   that it corresponds to the ``"action"`` entry in the input ``tensordict``;
# * :obj:`EnvBase.reward_spec`: provides information about the reward space;
# * :obj:`EnvBase.done_spec`: provides information about the space of the done
#   flag.
#
# TorchRL specs are organized in two general containers: ``input_spec`` which
# contains the specs of the information that the step function reads (divided
# between ``action_spec`` containing the action and ``state_spec`` containing
# all the rest), and ``output_spec`` which encodes the specs that the
# step outputs (``observation_spec``, ``reward_spec`` and ``done_spec``).
# In general, you should not interact directly with ``output_spec`` and
# ``input_spec`` but only with their content: ``observation_spec``,
# ``reward_spec``, ``done_spec``, ``action_spec`` and ``state_spec``.
# The reason is that the specs are organized in a non-trivial way
# within ``output_spec`` and ``input_spec``, and neither of these should be
# modified directly.
#
# In other words, the ``observation_spec`` and related properties are
# convenient shortcuts to the content of the output and input spec containers.
#
# TorchRL offers multiple :class:`~torchrl.data.TensorSpec`
# `subclasses <https://pytorch.org/rl/reference/data.html#tensorspec>`_ to
# encode the environment's input and output characteristics.
#
# Specs shape
# ^^^^^^^^^^^
#
# The leading dimensions of the environment specs must match the
# environment batch size. This is done to enforce that every component of an
# environment (including its transforms) has an accurate representation of
# the expected input and output shapes. This is something that should be
# accurately coded in stateful settings.
#
# For non batch-locked environments, such as the one in our example (see below),
# this is irrelevant as the environment batch size will most likely be empty.
#
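######################################################################
# To make "input and output domain" concrete, here is a hypothetical,
# dependency-free stand-in for a bounded continuous spec. ``ToyBoundedSpec``
# is not a TorchRL class: it only records bounds and a shape, can draw random
# samples (as ``env.action_spec.rand()`` does later in this tutorial) and can
# check membership. Batch dimensions are prepended, which is what "the leading
# dimensions must match the environment batch size" means.

```python
import random
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class ToyBoundedSpec:
    """Hypothetical stand-in for a bounded continuous spec (not TorchRL's API)."""

    low: float
    high: float
    shape: Tuple[int, ...]

    def expand(self, *batch_dims: int) -> "ToyBoundedSpec":
        # Batch dimensions become the *leading* dimensions of the shape.
        return ToyBoundedSpec(self.low, self.high, tuple(batch_dims) + self.shape)

    def numel(self) -> int:
        n = 1
        for dim in self.shape:
            n *= dim
        return n

    def rand(self, rng: random.Random) -> List[float]:
        # Flat list of uniform samples in [low, high], one per element.
        return [rng.uniform(self.low, self.high) for _ in range(self.numel())]

    def is_in(self, values: List[float]) -> bool:
        # Right number of elements, all within bounds.
        return len(values) == self.numel() and all(
            self.low <= v <= self.high for v in values
        )


toy_action_spec = ToyBoundedSpec(low=-2.0, high=2.0, shape=(1,))  # max_torque = 2.0
toy_batched_spec = toy_action_spec.expand(10)                      # batch size of 10
toy_sample = toy_batched_spec.rand(random.Random(0))
```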
def _make_spec(self, td_params):
    # Under the hood, this will populate self.output_spec["observation"]
    self.observation_spec = CompositeSpec(
        th=BoundedTensorSpec(
            low=-torch.pi,
            high=torch.pi,
            shape=(),
            dtype=torch.float32,
        ),
        thdot=BoundedTensorSpec(
            low=-td_params["params", "max_speed"],
            high=td_params["params", "max_speed"],
            shape=(),
            dtype=torch.float32,
        ),
        # we need to add the ``params`` to the observation specs, as we want
        # to pass it at each step during a rollout
        params=make_composite_from_td(td_params["params"]),
        shape=(),
    )
    # since the environment is stateless, we expect the previous output as input.
    # For this, ``EnvBase`` expects some state_spec to be available
    self.state_spec = self.observation_spec.clone()
    # The action spec is automatically wrapped in ``input_spec`` when
    # ``self.action_spec = spec`` is assigned
    self.action_spec = BoundedTensorSpec(
        low=-td_params["params", "max_torque"],
        high=td_params["params", "max_torque"],
        shape=(1,),
        dtype=torch.float32,
    )
    self.reward_spec = UnboundedContinuousTensorSpec(shape=(*td_params.shape, 1))


def make_composite_from_td(td):
    # custom function to convert a ``tensordict`` into a spec with the same
    # (possibly nested) structure, where every leaf is an unbounded spec.
    composite = CompositeSpec(
        {
            key: make_composite_from_td(tensor)
            if isinstance(tensor, TensorDictBase)
            else UnboundedContinuousTensorSpec(
                dtype=tensor.dtype, device=tensor.device, shape=tensor.shape
            )
            for key, tensor in td.items()
        },
        shape=td.shape,
    )
    return composite
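######################################################################
# ``make_composite_from_td`` mirrors the nesting of its input: sub-tensordicts
# become nested composite specs and every leaf becomes an unbounded spec that
# keeps the leaf's dtype, device and shape. The same recursion on plain nested
# dictionaries looks as follows; ``make_spec_tree`` and its dictionary "specs"
# are hypothetical stand-ins that only keep the leaf type and infinite bounds.

```python
import math
from typing import Any, Dict


def make_spec_tree(tree: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical analogue of ``make_composite_from_td`` for nested dicts."""
    return {
        # Recurse into sub-dicts, mirroring the nested structure...
        key: make_spec_tree(value)
        if isinstance(value, dict)
        # ...and describe each leaf as unbounded, keeping only its type.
        else {"dtype": type(value).__name__, "low": -math.inf, "high": math.inf}
        for key, value in tree.items()
    }


toy_params = {"params": {"max_speed": 8, "max_torque": 2.0, "dt": 0.05}}
spec_tree = make_spec_tree(toy_params)
```

# Because the type is read from each leaf, an integer entry such as
# ``"max_speed": 8`` keeps an integer description in this sketch.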
######################################################################
# Reproducible experiments: seeding
# ---------------------------------
#
# Seeding an environment is a common operation when initializing an experiment.
# The only goal of :func:`EnvBase._set_seed` is to set the seed of the contained
# simulator. If possible, this operation should not call ``reset()`` or interact
# with the environment execution. The parent :func:`EnvBase.set_seed` method
# incorporates a mechanism that allows seeding multiple environments with
# different pseudo-random and reproducible seeds.
#


def _set_seed(self, seed: Optional[int]):
    # ``torch.manual_seed`` seeds PyTorch's global generator and returns it;
    # ``_reset`` then draws the initial states from ``self.rng``.
    rng = torch.manual_seed(seed)
    self.rng = rng
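######################################################################
# Seeding several environments reproducibly usually means deriving one
# distinct seed per environment from a single user-provided seed. The sketch
# below shows one such scheme with Python's ``random`` module; it is a
# hypothetical illustration of the idea, not the algorithm used by
# :func:`EnvBase.set_seed`.

```python
import random
from typing import List


def derive_seeds(seed: int, num_envs: int) -> List[int]:
    """Hypothetical scheme: distinct, reproducible per-environment seeds."""
    master = random.Random(seed)
    # Sampling without replacement guarantees that no two environments share a seed.
    return master.sample(range(2**32), num_envs)


seeds_a = derive_seeds(42, num_envs=4)
seeds_b = derive_seeds(42, num_envs=4)  # same master seed: same per-env seeds
seeds_c = derive_seeds(7, num_envs=4)   # different master seed
```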
######################################################################
# Wrapping things together: the :class:`~torchrl.envs.EnvBase` class
# ------------------------------------------------------------------
#
# We can finally put together the pieces and design our environment class.
# The specs initialization needs to be performed during the environment
# construction, so we must take care of calling the :func:`_make_spec` method
# within :func:`PendulumEnv.__init__`.
#
# We add a static method :meth:`PendulumEnv.gen_params` which deterministically
# generates a set of hyperparameters to be used during execution:
#


def gen_params(g=10.0, batch_size=None) -> TensorDictBase:
    """Returns a ``tensordict`` containing the physical parameters such as gravitational acceleration and torque or speed limits."""
    if batch_size is None:
        batch_size = []
    td = TensorDict(
        {
            "params": TensorDict(
                {
                    "max_speed": 8,
                    "max_torque": 2.0,
                    "dt": 0.05,
                    "g": g,
                    "m": 1.0,
                    "l": 1.0,
                },
                [],
            )
        },
        [],
    )
    if batch_size:
        # Replicate the same parameters along the requested batch dimensions.
        td = td.expand(batch_size).contiguous()
    return td
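######################################################################
# ``gen_params`` returns a single set of parameters by default; with a
# ``batch_size`` it replicates that set along new leading dimensions, and
# ``contiguous()`` materializes the expanded entries so that each environment's
# parameters could later be edited separately. The dictionary version below
# (``toy_gen_params`` is hypothetical) mimics that behavior and gives the first
# environment a different gravity.

```python
import copy
from typing import Dict, List, Optional, Union

Params = Dict[str, Dict[str, float]]


def toy_gen_params(g: float = 10.0, batch_size: Optional[int] = None) -> Union[Params, List[Params]]:
    """Hypothetical dict-based analogue of ``gen_params``."""
    out = {
        "params": {"max_speed": 8, "max_torque": 2.0, "dt": 0.05, "g": g, "m": 1.0, "l": 1.0}
    }
    if batch_size:
        # Independent copies, similar in spirit to ``expand(...).contiguous()``.
        return [copy.deepcopy(out) for _ in range(batch_size)]
    return out


single_params = toy_gen_params()
batched_params = toy_gen_params(batch_size=3)
batched_params[0]["params"]["g"] = 9.81  # only the first environment changes
```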
######################################################################
# We define the environment as non-``batch_locked`` by setting its
# ``batch_locked`` attribute to ``False``. This means that we will **not** require
# the input ``tensordict`` to have a batch size that matches the one of the environment.
#
# The following code will just put together the pieces we have coded above.
#


class PendulumEnv(EnvBase):
    metadata = {
        "render_modes": ["human", "rgb_array"],
        "render_fps": 30,
    }
    batch_locked = False

    def __init__(self, td_params=None, seed=None, device="cpu"):
        if td_params is None:
            td_params = self.gen_params()

        super().__init__(device=device, batch_size=[])
        self._make_spec(td_params)
        if seed is None:
            # Draw a random non-negative integer seed when none is provided.
            seed = torch.empty((), dtype=torch.int64).random_().item()
        self.set_seed(seed)

    # Helpers: _make_spec and gen_params
    gen_params = staticmethod(gen_params)
    _make_spec = _make_spec

    # Mandatory methods: _step, _reset and _set_seed
    _reset = _reset
    _step = staticmethod(_step)
    _set_seed = _set_seed
######################################################################
# Testing our environment
# -----------------------
#
# TorchRL provides a simple function :func:`~torchrl.envs.utils.check_env_specs`
# to check that a (transformed) environment has an input/output structure that
# matches the one dictated by its specs.
# Let us try it out:
#

env = PendulumEnv()
check_env_specs(env)

######################################################################
# We can have a look at our specs to get a visual representation of the environment
# signature:
#

print("observation_spec:", env.observation_spec)
print("state_spec:", env.state_spec)
print("reward_spec:", env.reward_spec)

######################################################################
# We can also execute a couple of commands to check that the output structure
# matches what is expected.

td = env.reset()
print("reset tensordict", td)

######################################################################
# We can run :meth:`~torchrl.envs.EnvBase.rand_step` to generate
# an action randomly from the ``action_spec`` domain. A ``tensordict`` containing
# the hyperparameters and the current state **must** be passed since our
# environment is stateless. In stateful contexts, ``env.rand_step()`` works
# perfectly too.
#
td = env.rand_step(td)
print("random step tensordict", td)
######################################################################
# Transforming an environment
# ---------------------------
#
# Writing environment transforms for stateless simulators is slightly more
# complicated than for stateful ones: transforming an output entry that needs
# to be read at the following iteration requires applying the inverse transform
# before calling :meth:`~torchrl.envs.EnvBase.step` at the next step.
# This is an ideal scenario to showcase all the features of TorchRL's
# transforms!
#
# For instance, in the following transformed environment we ``unsqueeze`` the entries
# ``["th", "thdot"]`` to be able to stack them along the last
# dimension. We also pass them as ``in_keys_inv`` to squeeze them back to their
# original shape once they are passed as input in the next iteration.
#
env = TransformedEnv(
    env,
    # ``Unsqueeze`` the observations that we will concatenate
    UnsqueezeTransform(
        unsqueeze_dim=-1,
        in_keys=["th", "thdot"],
        in_keys_inv=["th", "thdot"],
    ),
)
######################################################################
# Writing custom transforms
# ^^^^^^^^^^^^^^^^^^^^^^^^^
#
# TorchRL's transforms may not cover all the operations one wants to execute
# after an environment has been executed.
# Writing a transform does not require much effort. As for the environment
# design, there are two steps in writing a transform:
#
# - Getting the dynamics right (forward and inverse);
# - Adapting the environment specs.
#
# A transform can be used in two settings: on its own, it can be used as a
# :class:`~torch.nn.Module`. It can also be appended to a
# :class:`~torchrl.envs.transforms.TransformedEnv`. The structure of the class
# allows customizing the behavior in the different contexts.
#
# A :class:`~torchrl.envs.transforms.Transform` skeleton can be summarized as follows:
#
# .. code-block::
#
#    class Transform(nn.Module):
#        def forward(self, tensordict):
#            ...
#        def _apply_transform(self, tensordict):
#            ...
#        def _step(self, tensordict):
#            ...
#        def _call(self, tensordict):
#            ...
#        def inv(self, tensordict):
#            ...
#        def _inv_apply_transform(self, tensordict):
#            ...
#
# There are three entry points (:func:`forward`, :func:`_step` and :func:`inv`)
# which all receive :class:`tensordict.TensorDict` instances. The first two
# will eventually go through the keys indicated by :obj:`~torchrl.envs.transforms.Transform.in_keys`
# and call :meth:`~torchrl.envs.transforms.Transform._apply_transform` on each of these. The results will
# be written in the entries pointed by :obj:`Transform.out_keys` if provided
# (if not, the ``in_keys`` will be updated with the transformed values).
# If inverse transforms need to be executed, a similar data flow will be
# executed but with the :func:`Transform.inv` and
# :func:`Transform._inv_apply_transform` methods and across the ``in_keys_inv``
# and ``out_keys_inv`` list of keys.
# The following figure summarizes this flow for environments and replay
# buffers.
#
# [Figure: Transform API]
#
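######################################################################
# The forward and inverse key flows described above can be mimicked on plain
# dictionaries. ``ToyTransform`` is a hypothetical, heavily simplified sketch,
# not TorchRL's :class:`~torchrl.envs.transforms.Transform`: ``forward`` maps
# each ``in_keys`` entry through a function and writes it to the matching
# ``out_keys`` entry (or back in place when ``out_keys`` is omitted), while
# ``inv`` undoes the change on ``in_keys_inv`` before data goes back into the
# environment. One-element lists stand in for an extra tensor dimension, which
# reproduces the unsqueeze/squeeze round trip used above.

```python
import math
from typing import Any, Callable, Dict, Optional, Sequence


class ToyTransform:
    """Hypothetical dict-based sketch of the forward/inverse key flow."""

    def __init__(
        self,
        fn: Callable[[Any], Any],
        in_keys: Sequence[str],
        out_keys: Optional[Sequence[str]] = None,
        inv_fn: Optional[Callable[[Any], Any]] = None,
        in_keys_inv: Sequence[str] = (),
    ) -> None:
        self.fn, self.inv_fn = fn, inv_fn
        self.in_keys = list(in_keys)
        # Without out_keys, results overwrite the in_keys entries.
        self.out_keys = list(out_keys) if out_keys is not None else list(in_keys)
        self.in_keys_inv = list(in_keys_inv)

    def forward(self, td: Dict[str, Any]) -> Dict[str, Any]:
        # Environment outputs: apply fn to every in_key and write to its out_key.
        td = dict(td)
        for key_in, key_out in zip(self.in_keys, self.out_keys):
            td[key_out] = self.fn(td[key_in])
        return td

    def inv(self, td: Dict[str, Any]) -> Dict[str, Any]:
        # Environment inputs: undo the transform before the next step.
        td = dict(td)
        for key in self.in_keys_inv:
            td[key] = self.inv_fn(td[key])
        return td


# "Unsqueeze" scalars into 1-element lists, and "squeeze" them back on the way in.
unsqueeze = ToyTransform(
    fn=lambda x: [x],
    in_keys=["th", "thdot"],
    inv_fn=lambda x: x[0],
    in_keys_inv=["th", "thdot"],
)
toy_out = unsqueeze.forward({"th": 0.3, "thdot": -0.1})
toy_back = unsqueeze.inv(toy_out)

# With out_keys, the original entry is kept and the result goes to a new key.
with_sin = ToyTransform(fn=math.sin, in_keys=["th"], out_keys=["sin"]).forward({"th": 0.0})
```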
# In some cases, a transform will not work on a subset of keys in a unitary
# manner, but will execute some operation on the parent environment or
# work with the entire input ``tensordict``.
# In those cases, the :func:`_call` and :func:`forward` methods should be
# re-written, and the :func:`_apply_transform` method can be skipped.
#
# Let us code new transforms that will compute the ``sine`` and ``cosine``
# values of the position angle, as these values are more useful to us to learn
# a policy than the raw angle value:


class SinTransform(Transform):
    def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
        return obs.sin()

    # The transform must also modify the data at reset time
    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        return self._call(tensordict_reset)

    # _apply_to_composite will execute the observation spec transform across all
    # in_keys/out_keys pairs and write the result in the observation_spec which
    # is of type ``CompositeSpec``
    @_apply_to_composite
    def transform_observation_spec(self, observation_spec):
        # The sine of any angle lies in [-1, 1]
        return BoundedTensorSpec(
            low=-1,
            high=1,
            shape=observation_spec.shape,
            dtype=observation_spec.dtype,
            device=observation_spec.device,
        )


class CosTransform(Transform):
    def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
        return obs.cos()

    # The transform must also modify the data at reset time
    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        return self._call(tensordict_reset)

    # _apply_to_composite will execute the observation spec transform across all
    # in_keys/out_keys pairs and write the result in the observation_spec which
    # is of type ``CompositeSpec``
    @_apply_to_composite
    def transform_observation_spec(self, observation_spec):
        # The cosine of any angle lies in [-1, 1]
        return BoundedTensorSpec(
            low=-1,
            high=1,
            shape=observation_spec.shape,
            dtype=observation_spec.dtype,
            device=observation_spec.device,
        )


t_sin = SinTransform(in_keys=["th"], out_keys=["sin"])
t_cos = CosTransform(in_keys=["th"], out_keys=["cos"])
env.append_transform(t_sin)
env.append_transform(t_cos)

######################################################################
# Concatenates the observations onto an "observation" entry.
# ``del_keys=False`` ensures that we keep these values for the next
# iteration.
cat_transform = CatTensors(
    in_keys=["sin", "cos", "thdot"], dim=-1, out_key="observation", del_keys=False
)
env.append_transform(cat_transform)
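######################################################################
# Why feed :math:`(\sin\theta, \cos\theta, \dot{\theta})` rather than
# :math:`\theta` itself? Angles on either side of the :math:`\pm\pi` wrap-around
# are far apart as raw numbers but physically almost identical, and the
# sine/cosine pair removes that discontinuity while still determining the angle.
# The plain-Python sketch below (``toy_observation`` is hypothetical) builds the
# same ``[sin, cos, thdot]`` layout as the ``CatTensors`` call above.

```python
import math
from typing import List


def toy_observation(th: float, thdot: float) -> List[float]:
    """Hypothetical analogue of SinTransform + CosTransform + CatTensors."""
    return [math.sin(th), math.cos(th), thdot]


th_left, th_right = math.pi - 1e-3, -math.pi + 1e-3
obs_left = toy_observation(th_left, 0.0)
obs_right = toy_observation(th_right, 0.0)

raw_gap = abs(th_left - th_right)  # close to 2 * pi
obs_gap = max(abs(a - b) for a, b in zip(obs_left, obs_right))  # close to 2e-3

# atan2(sin, cos) recovers the angle, wrapped to [-pi, pi].
recovered = math.atan2(obs_left[0], obs_left[1])
```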
736
737
######################################################################
738
# Once more, let us check that our environment specs match what is received:
739
check_env_specs(env)
740
741
######################################################################
742
# Executing a rollout
743
# -------------------
744
#
745
# Executing a rollout is a succession of simple steps:
746
#
747
# * reset the environment
748
# * while some condition is not met:
749
#
750
# * compute an action given a policy
751
# * execute a step given this action
752
# * collect the data
753
# * make a ``MDP`` step
754
#
755
# * gather the data and return
756
#
757
# These operations have been conveniently wrapped in the :meth:`~torchrl.envs.EnvBase.rollout`
758
# method, from which we provide a simplified version here below.
759
760
761
def simple_rollout(steps=100):
762
# preallocate:
763
data = TensorDict({}, [steps])
764
# reset
765
_data = env.reset()
766
for i in range(steps):
767
_data["action"] = env.action_spec.rand()
768
_data = env.step(_data)
769
data[i] = _data
770
_data = step_mdp(_data, keep_other=True)
771
return data
772
773
774
print("data from rollout:", simple_rollout(100))
775
776
######################################################################
# Batching computations
# ---------------------
#
# The last unexplored aspect of this tutorial is our ability to
# batch computations in TorchRL. Because our environment makes no
# assumptions about the shape of its input data, we can seamlessly
# execute it over batches of data. Even better: for non-batch-locked
# environments such as our Pendulum, we can change the batch size on the fly
# without recreating the environment.
# To do this, we simply generate parameters with the desired shape.
#

batch_size = 10  # number of environments to be executed in batch
td = env.reset(env.gen_params(batch_size=[batch_size]))
print("reset (batch size of 10)", td)
td = env.rand_step(td)
print("rand step (batch size of 10)", td)

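######################################################################
# Batching works because every operation in the dynamics is elementwise, so
# leading dimensions are simply carried along. The sketch below illustrates
# this with a pendulum update written in the style of Gymnasium's
# Pendulum-v1 (velocity updated from gravity and torque, clipped, then the
# angle integrated). The function ``pendulum_step`` and its default constants
# are assumptions for illustration, not the code of this tutorial's
# environment.

```python
import math

import torch


def pendulum_step(th, thdot, u, g=10.0, m=1.0, length=1.0, dt=0.05, max_speed=8.0):
    # hypothetical update in the style of Gymnasium's Pendulum-v1: every
    # operation is elementwise, so any leading batch shape is carried along
    accel = 3 * g / (2 * length) * torch.sin(th) + 3.0 / (m * length**2) * u
    new_thdot = (thdot + accel * dt).clamp(-max_speed, max_speed)
    new_th = th + new_thdot * dt
    return new_th, new_thdot


for batch_shape in [(), (10,), (4, 8)]:
    toy_th = (torch.rand(batch_shape) * 2 - 1) * math.pi  # angles in [-pi, pi)
    toy_thdot = torch.zeros(batch_shape)
    toy_u = torch.zeros(batch_shape)
    new_th, new_thdot = pendulum_step(toy_th, toy_thdot, toy_u)
    print(batch_shape, tuple(new_th.shape))  # output shape matches the batch shape
```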
######################################################################
# Executing a rollout with a batch of data requires us to reset the environment
# outside the rollout function, since we need to define the ``batch_size``
# dynamically and this is not supported by :meth:`~torchrl.envs.EnvBase.rollout`:
#

rollout = env.rollout(
    3,
    auto_reset=False,  # we're executing the reset out of the ``rollout`` call
    tensordict=env.reset(env.gen_params(batch_size=[batch_size])),
)
print("rollout of len 3 (batch size of 10):", rollout)


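######################################################################
# The rollout above should have a batch size of ``[10, 3]``: the environment
# dimension comes first and the time dimension is appended last. On a
# TensorDict, indexing only acts on these batch dimensions, which is why
# ``rollout[..., -1]`` in the training loop picks the final time step. The
# sketch below mimics that layout with plain tensors, assuming a per-step
# reward of shape ``[10, 1]``; on a raw tensor of shape ``[10, 3, 1]`` the
# equivalent selection is ``[:, -1]``.

```python
import torch

n_envs, n_steps = 10, 3
# stand-in for the per-step rewards returned by the environment, shape [n_envs, 1]
step_rewards = [torch.full((n_envs, 1), -float(t)) for t in range(n_steps)]
# steps are stacked along a new dimension placed after the environment dimension
stacked_rewards = torch.stack(step_rewards, dim=1)
print(stacked_rewards.shape)  # torch.Size([10, 3, 1])
# on the raw tensor, [:, -1] selects the last time step of every environment
last_step_rewards = stacked_rewards[:, -1]
print(last_step_rewards.shape)  # torch.Size([10, 1])
```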
######################################################################
# Training a simple policy
# ------------------------
#
# In this example, we will train a simple policy using the reward as a
# differentiable objective (that is, as the negative of a loss).
# We will take advantage of the fact that our dynamic system is fully
# differentiable to backpropagate through the trajectory return and adjust the
# weights of our policy to maximize this value directly. Of course, in many
# settings the assumptions we make here do not hold, such as a
# differentiable system and full access to the underlying mechanics.
#
# Still, this is a very simple example that showcases how a training loop can
# be coded with a custom environment in TorchRL.
#
# Let us first write the policy network:
#
torch.manual_seed(0)
env.set_seed(0)

# LazyLinear layers infer their input size on the first forward pass
net = nn.Sequential(
    nn.LazyLinear(64),
    nn.Tanh(),
    nn.LazyLinear(64),
    nn.Tanh(),
    nn.LazyLinear(64),
    nn.Tanh(),
    nn.LazyLinear(1),
)
# read "observation" from the input tensordict and write the output to "action"
policy = TensorDictModule(
    net,
    in_keys=["observation"],
    out_keys=["action"],
)

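######################################################################
# Two details of this policy are worth spelling out: ``nn.LazyLinear`` defers
# the choice of its input size until the first forward pass, and
# :class:`~tensordict.nn.TensorDictModule` reads its inputs from, and writes
# its outputs to, named entries. The sketch below shows both with a plain
# dictionary. ``DictModule`` is a hypothetical, much simplified stand-in and
# not the TensorDict API.

```python
import torch
from torch import nn


class DictModule(nn.Module):
    """Hypothetical stand-in for TensorDictModule working on a plain dict."""

    def __init__(self, module: nn.Module, in_keys, out_keys):
        super().__init__()
        self.module = module
        self.in_keys = list(in_keys)
        self.out_keys = list(out_keys)

    def forward(self, data: dict) -> dict:
        # read the inputs by name, run the wrapped module, write outputs by name
        out = self.module(*(data[key] for key in self.in_keys))
        outs = out if isinstance(out, tuple) else (out,)
        data.update(zip(self.out_keys, outs))
        return data


# LazyLinear layers infer their input size on the first forward pass
toy_net = nn.Sequential(nn.LazyLinear(64), nn.Tanh(), nn.LazyLinear(1))
toy_policy = DictModule(toy_net, in_keys=["observation"], out_keys=["action"])

toy_data = {"observation": torch.ones(5, 3)}  # 5 environments, 3 features each
toy_policy(toy_data)
print(toy_data["action"].shape)  # torch.Size([5, 1])
print(toy_net[0].weight.shape)  # torch.Size([64, 3])
```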
######################################################################
# and our optimizer:
#

optim = torch.optim.Adam(policy.parameters(), lr=2e-3)

######################################################################
# Training loop
# ^^^^^^^^^^^^^
#
# We will successively:
#
# * generate a trajectory
# * average the rewards over time steps and batch elements
# * backpropagate through the graph defined by these operations
# * clip the gradient norm and make an optimization step
# * repeat
#
# By the end of training, the final reward should be close to 0, which
# indicates that the pendulum is upright and still, as desired.
#
batch_size = 32
pbar = tqdm.tqdm(range(20_000 // batch_size))
# cosine learning-rate decay with T_max=20_000 scheduler steps
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, 20_000)
logs = defaultdict(list)

for _ in pbar:
    # sample a fresh batch of initial states and parameters
    init_td = env.reset(env.gen_params(batch_size=[batch_size]))
    # roll out the policy for 100 steps without resetting inside the rollout
    rollout = env.rollout(100, policy, tensordict=init_td, auto_reset=False)
    # average reward over all environments and time steps
    traj_return = rollout["next", "reward"].mean()
    # maximize the return by minimizing its negative
    (-traj_return).backward()
    gn = torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
    optim.step()
    optim.zero_grad()
    pbar.set_description(
        f"reward: {traj_return: 4.4f}, "
        f"last reward: {rollout[..., -1]['next', 'reward'].mean(): 4.4f}, gradient norm: {gn: 4.4}"
    )
    logs["return"].append(traj_return.item())
    logs["last_reward"].append(rollout[..., -1]["next", "reward"].mean().item())
    scheduler.step()


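######################################################################
# The key idea of this loop, backpropagating through the simulator itself,
# can be reproduced on a much smaller problem. The sketch below replaces the
# pendulum with a hypothetical one-dimensional system
# ``x[t + 1] = x[t] + dt * u[t]`` rewarded by ``-x[t] ** 2`` and uses a linear
# policy, but follows the same steps: roll out, average the rewards,
# backpropagate, clip the gradient norm, step the optimizer. The system,
# horizon and hyperparameters are arbitrary choices for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)
lin_policy = nn.Linear(1, 1)  # hypothetical linear policy u = w * x + b
lin_optim = torch.optim.Adam(lin_policy.parameters(), lr=0.05)
toy_dt, toy_horizon = 0.1, 20


def toy_return(x0: torch.Tensor) -> torch.Tensor:
    # roll out the differentiable toy system and average the rewards
    x, rewards = x0, []
    for _ in range(toy_horizon):
        u = lin_policy(x)
        x = x + toy_dt * u  # differentiable "simulator" step
        rewards.append(-x.pow(2))  # reward is highest when x stays at 0
    return torch.stack(rewards, dim=-1).mean()


toy_returns = []
for _ in range(200):
    x0 = torch.randn(32, 1)  # a batch of 32 random initial states
    ret = toy_return(x0)
    (-ret).backward()  # maximize the return by minimizing its negative
    torch.nn.utils.clip_grad_norm_(lin_policy.parameters(), 1.0)
    lin_optim.step()
    lin_optim.zero_grad()
    toy_returns.append(ret.item())

print(f"first return: {toy_returns[0]:.3f}, last return: {toy_returns[-1]:.3f}")
```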
def plot():
    import matplotlib
    from matplotlib import pyplot as plt

    is_ipython = "inline" in matplotlib.get_backend()
    if is_ipython:
        from IPython import display

    with plt.ion():
        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        plt.plot(logs["return"])
        plt.title("returns")
        plt.xlabel("iteration")
        plt.subplot(1, 2, 2)
        plt.plot(logs["last_reward"])
        plt.title("last reward")
        plt.xlabel("iteration")
        if is_ipython:
            display.display(plt.gcf())
            display.clear_output(wait=True)
        plt.show()


plot()


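######################################################################
# The learning-rate schedule used in the training loop is a cosine annealing
# whose period ``T_max`` is its second argument: the learning rate reaches
# ``eta_min`` (0 by default) after ``T_max`` calls to ``scheduler.step()``.
# The sketch below compares the learning rate reached after ``20_000 // 32``
# steps (the number of iterations of the loop) with ``T_max=20_000`` and with
# ``T_max`` equal to that number of iterations. It uses a throwaway parameter
# and optimizer so the tutorial's own objects are left untouched; ``lr_after``
# is a hypothetical helper.

```python
import torch

n_iter = 20_000 // 32  # number of iterations of the training loop


def lr_after(t_max: int, n_steps: int, base_lr: float = 2e-3) -> float:
    # throwaway parameter and optimizer so that only the schedule is exercised
    param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.Adam([param], lr=base_lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, t_max)
    for _ in range(n_steps):
        opt.step()  # no gradients, so the parameter is left unchanged
        sched.step()
    return opt.param_groups[0]["lr"]


print(f"T_max=20_000: lr after {n_iter} steps = {lr_after(20_000, n_iter):.6f}")
print(f"T_max={n_iter}: lr after {n_iter} steps = {lr_after(n_iter, n_iter):.6f}")
```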
######################################################################
# Conclusion
# ----------
#
# In this tutorial, we have learned how to code a stateless environment from
# scratch. We covered:
#
# * The four essential components that need to be taken care of when coding
#   an environment (``step``, ``reset``, seeding and building specs).
#   We saw how these methods and classes interact with the
#   :class:`~tensordict.TensorDict` class;
# * How to test that an environment is properly coded using
#   :func:`~torchrl.envs.utils.check_env_specs`;
# * How to append transforms in the context of stateless environments and how
#   to write custom transformations;
# * How to train a policy on a fully differentiable simulator.
#
