# -*- coding: utf-8 -*-
"""
Pendulum: Writing your environment and transforms with TorchRL
===============================================================

**Author**: `Vincent Moens <https://github.com/vmoens>`_

Creating an environment (a simulator or an interface to a physical control system)
is an integral part of reinforcement learning and control engineering.

TorchRL provides a set of tools to do this in multiple contexts.
This tutorial demonstrates how to use PyTorch and TorchRL to code a pendulum
simulator from the ground up.
It is freely inspired by the Pendulum-v1 implementation from the `OpenAI-Gym/Farama-Gymnasium
control library <https://github.com/Farama-Foundation/Gymnasium>`__.

.. figure:: /_static/img/pendulum.gif
   :alt: Pendulum
   :align: center

   Simple Pendulum

Key learnings:

- How to design an environment in TorchRL:
  - Writing specs (input, observation and reward);
  - Implementing behavior: seeding, reset and step.
- Transforming your environment inputs and outputs, and writing your own
  transforms;
- How to use :class:`~tensordict.TensorDict` to carry arbitrary data structures
  through the ``codebase``.

In the process, we will touch three crucial components of TorchRL:

* `environments <https://pytorch.org/rl/reference/envs.html>`__
* `transforms <https://pytorch.org/rl/reference/envs.html#transforms>`__
* `models (policy and value function) <https://pytorch.org/rl/reference/modules.html>`__

"""

######################################################################
# To give a sense of what can be achieved with TorchRL's environments, we will
# be designing a *stateless* environment. While stateful environments keep track of
# the latest physical state encountered and rely on this to simulate the state-to-state
# transition, stateless environments expect the current state to be provided to
# them at each step, along with the action undertaken. TorchRL supports both
# types of environments, but stateless environments are more generic and hence
# cover a broader range of features of the environment API in TorchRL.
#
# Modeling stateless environments gives users full control over the inputs and
# outputs of the simulator: one can reset an experiment at any stage or actively
# modify the dynamics from the outside. However, it assumes that we have some control
# over a task, which may not always be the case: solving a problem where we cannot
# control the current state is more challenging but has a much wider set of applications.
#
# Another advantage of stateless environments is that they can enable
# batched execution of transition simulations. If the backend and the
# implementation allow it, an algebraic operation can be executed seamlessly on
# scalars, vectors, or tensors. This tutorial gives such examples.
#
# This tutorial will be structured as follows:
#
# * We will first get acquainted with the environment properties:
#   its shape (``batch_size``), its methods (mainly :meth:`~torchrl.envs.EnvBase.step`,
#   :meth:`~torchrl.envs.EnvBase.reset` and :meth:`~torchrl.envs.EnvBase.set_seed`)
#   and finally its specs.
# * After having coded our simulator, we will demonstrate how it can be used
#   during training with transforms.
# * We will explore new avenues that follow from TorchRL's API,
#   including: the possibility of transforming inputs, the vectorized execution
#   of the simulation and the possibility of backpropagation through the
#   simulation graph.
# * Finally, we will train a simple policy to solve the system we implemented.
#

# sphinx_gallery_start_ignore
import warnings

warnings.filterwarnings("ignore")
from torch import multiprocessing

# TorchRL prefers the spawn method, which restricts creation of ``~torchrl.envs.ParallelEnv``
# to code inside the ``__main__`` guard. For ease of reading, we switch here to fork,
# which is also the default start method in Google's Colaboratory.
try:
    multiprocessing.set_start_method("fork")
except RuntimeError:
    pass

# sphinx_gallery_end_ignore

from collections import defaultdict
from typing import Optional

import numpy as np
import torch
import tqdm
from tensordict import TensorDict, TensorDictBase
from tensordict.nn import TensorDictModule
from torch import nn

from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs import (
    CatTensors,
    EnvBase,
    Transform,
    TransformedEnv,
    UnsqueezeTransform,
)
from torchrl.envs.transforms.transforms import _apply_to_composite
from torchrl.envs.utils import check_env_specs, step_mdp

DEFAULT_X = np.pi
DEFAULT_Y = 1.0

######################################################################
# There are four things you must take care of when designing a new environment
# class:
#
# * :meth:`EnvBase._reset`, which codes for the resetting of the simulator
#   at a (potentially random) initial state;
# * :meth:`EnvBase._step`, which codes for the state transition dynamic;
# * :meth:`EnvBase._set_seed`, which implements the seeding mechanism;
# * the environment specs.
#
# Let us first describe the problem at hand: we would like to model a simple
# pendulum over which we can control the torque applied on its fixed point.
# Our goal is to place the pendulum in an upward position (angular position at 0
# by convention) and have it standing still in that position.
# To design our dynamic system, we need to define two equations: the motion
# equation following an action (the torque applied) and the reward equation
# that will constitute our objective function.
#
# For the motion equation, we will update the angular velocity following:
#
# .. math::
#
#    \dot{\theta}_{t+1} = \dot{\theta}_t + (3 * g / (2 * L) * \sin(\theta_t) + 3 / (m * L^2) * u) * dt
#
# where :math:`\dot{\theta}` is the angular velocity in rad/sec, :math:`g` is the
# gravitational acceleration, :math:`L` is the pendulum length, :math:`m` is its mass,
# :math:`\theta` is its angular position and :math:`u` is the torque. The
# angular position is then updated according to
#
# .. math::
#
#    \theta_{t+1} = \theta_{t} + \dot{\theta}_{t+1} dt
#
# We define our reward as
#
# .. math::
#
#    r = -(\theta^2 + 0.1 * \dot{\theta}^2 + 0.001 * u^2)
#
# which will be maximized when the angle is close to 0 (pendulum in an upward
# position), the angular velocity is close to 0 (no motion) and the torque is
# 0 too.
#
# Coding the effect of an action: :func:`~torchrl.envs.EnvBase._step`
# -------------------------------------------------------------------
#
# The step method is the first thing to consider, as it will encode
# the simulation that is of interest to us. In TorchRL, the
# :class:`~torchrl.envs.EnvBase` class has a :meth:`EnvBase.step`
# method that receives a :class:`tensordict.TensorDict`
# instance with an ``"action"`` entry indicating what action is to be taken.
#
# To facilitate the reading and writing from that ``tensordict`` and to make sure
# that the keys are consistent with what's expected from the library, the
# simulation part has been delegated to a private abstract method :meth:`_step`
# which reads input data from a ``tensordict``, and writes a *new* ``tensordict``
# with the output data.
#
# The :func:`_step` method should do the following:
#
#   1. Read the input keys (such as ``"action"``) and execute the simulation
#      based on these;
#   2. Retrieve observations, done state and reward;
#   3. Write the set of observation values along with the reward and done state
#      at the corresponding entries in a new :class:`TensorDict`.
#
# Next, the :meth:`~torchrl.envs.EnvBase.step` method will merge the output
# of :meth:`~torchrl.envs.EnvBase._step` into the input ``tensordict`` to enforce
# input/output consistency.
#
# Typically, for stateful environments, this will look like this:
#
# .. code-block::
#
#    >>> tensordict = policy(env.reset())
#    >>> print(tensordict)
#    TensorDict(
#        fields={
#            action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
#            done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
#            observation: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
#        batch_size=torch.Size([]),
#        device=cpu,
#        is_shared=False)
#    >>> env.step(tensordict)
#    >>> print(tensordict)
#    TensorDict(
#        fields={
#            action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
#            done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
#            next: TensorDict(
#                fields={
#                    done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
#                    observation: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
#                    reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False)},
#                batch_size=torch.Size([]),
#                device=cpu,
#                is_shared=False),
#            observation: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
#        batch_size=torch.Size([]),
#        device=cpu,
#        is_shared=False)
#
# Notice that the root ``tensordict`` has not changed, the only modification is the
# appearance of a new ``"next"`` entry that contains the new information.
#
# In the Pendulum example, our :meth:`_step` method will read the relevant
# entries from the input ``tensordict`` and compute the position and velocity of
# the pendulum after the force encoded by the ``"action"`` key has been applied
# onto it.
# We compute the new angular position of the pendulum
# ``"new_th"`` as the result of the previous position ``"th"`` plus the new
# velocity ``"new_thdot"`` over a time interval ``dt``.
#
# Since our goal is to turn the pendulum up and maintain it still in that
# position, our ``cost`` (negative reward) function is lower for positions
# close to the target and for low speeds.
# Indeed, we want to discourage positions that are far from being "upward"
# and/or speeds that are far from 0.
#
# In our example, :meth:`EnvBase._step` is encoded as a static method since our
# environment is stateless. In stateful settings, the ``self`` argument is
# needed as the state needs to be read from the environment.
#


def _step(tensordict):
    th, thdot = tensordict["th"], tensordict["thdot"]  # th := theta

    g_force = tensordict["params", "g"]
    mass = tensordict["params", "m"]
    length = tensordict["params", "l"]
    dt = tensordict["params", "dt"]
    u = tensordict["action"].squeeze(-1)
    u = u.clamp(-tensordict["params", "max_torque"], tensordict["params", "max_torque"])
    costs = angle_normalize(th) ** 2 + 0.1 * thdot**2 + 0.001 * (u**2)

    new_thdot = (
        thdot
        + (3 * g_force / (2 * length) * th.sin() + 3.0 / (mass * length**2) * u) * dt
    )
    new_thdot = new_thdot.clamp(
        -tensordict["params", "max_speed"], tensordict["params", "max_speed"]
    )
    new_th = th + new_thdot * dt
    reward = -costs.view(*tensordict.shape, 1)
    done = torch.zeros_like(reward, dtype=torch.bool)
    out = TensorDict(
        {
            "th": new_th,
            "thdot": new_thdot,
            "params": tensordict["params"],
            "reward": reward,
            "done": done,
        },
        tensordict.shape,
    )
    return out


def angle_normalize(x):
    return ((x + torch.pi) % (2 * torch.pi)) - torch.pi
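

######################################################################
# Because ``_step`` is a pure function of its input ``tensordict``, we can already
# exercise it on a hand-built example. The following quick check is an addition to
# the tutorial (the parameter values simply mirror the defaults used further below)
# and is only meant to show that the stateless step can be called without any
# environment object:

# a single transition starting from the downward position with a small torque
_example_td = TensorDict(
    {
        "th": torch.tensor(np.pi, dtype=torch.float32),
        "thdot": torch.tensor(0.0),
        "action": torch.tensor([0.5]),
        "params": TensorDict(
            {
                "max_speed": 8,
                "max_torque": 2.0,
                "dt": 0.05,
                "g": 10.0,
                "m": 1.0,
                "l": 1.0,
            },
            [],
        ),
    },
    [],
)
print("manual _step output:", _step(_example_td))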


######################################################################
# Resetting the simulator: :func:`~torchrl.envs.EnvBase._reset`
# -------------------------------------------------------------
#
# The second method we need to care about is the
# :meth:`~torchrl.envs.EnvBase._reset` method. Like
# :meth:`~torchrl.envs.EnvBase._step`, it should write the observation entries
# and possibly a done state in the ``tensordict`` it outputs (if the done state is
# omitted, it will be filled as ``False`` by the parent method
# :meth:`~torchrl.envs.EnvBase.reset`). In some contexts, it is required that
# the ``_reset`` method receives a command from the function that called
# it (for example, in multi-agent settings we may want to indicate which agents need
# to be reset). This is why the :meth:`~torchrl.envs.EnvBase._reset` method
# also expects a ``tensordict`` as input, albeit it may perfectly be empty or
# ``None``.
#
# The parent :meth:`EnvBase.reset` does some simple checks like the
# :meth:`EnvBase.step` does, such as making sure that a ``"done"`` state
# is returned in the output ``tensordict`` and that the shapes match what is
# expected from the specs.
#
# For us, the only important thing to consider is whether
# :meth:`EnvBase._reset` contains all the expected observations. Once more,
# since we are working with a stateless environment, we pass the configuration
# of the pendulum in a nested ``tensordict`` named ``"params"``.
#
# In this example, we do not pass a done state as this is not mandatory
# for :meth:`_reset` and our environment is non-terminating, so we always
# expect it to be ``False``.
#


def _reset(self, tensordict):
    if tensordict is None or tensordict.is_empty():
        # if no ``tensordict`` is passed, we generate a single set of hyperparameters.
        # Otherwise, we assume that the input ``tensordict`` contains all the relevant
        # parameters to get started.
        tensordict = self.gen_params(batch_size=self.batch_size)

    high_th = torch.tensor(DEFAULT_X, device=self.device)
    high_thdot = torch.tensor(DEFAULT_Y, device=self.device)
    low_th = -high_th
    low_thdot = -high_thdot

    # for non batch-locked environments, the input ``tensordict`` shape dictates the number
    # of simulators run simultaneously. In other contexts, the initial
    # random state's shape will depend upon the environment batch-size instead.
    th = (
        torch.rand(tensordict.shape, generator=self.rng, device=self.device)
        * (high_th - low_th)
        + low_th
    )
    thdot = (
        torch.rand(tensordict.shape, generator=self.rng, device=self.device)
        * (high_thdot - low_thdot)
        + low_thdot
    )
    out = TensorDict(
        {
            "th": th,
            "thdot": thdot,
            "params": tensordict["params"],
        },
        batch_size=tensordict.shape,
    )
    return out


######################################################################
# Environment metadata: ``env.*_spec``
# ------------------------------------
#
# The specs define the input and output domain of the environment.
# It is important that the specs accurately define the tensors that will be
# received at runtime, as they are often used to carry information about
# environments in multiprocessing and distributed settings. They can also be
# used to instantiate lazily defined neural networks and test scripts without
# actually querying the environment (which can be costly with real-world
# physical systems, for instance).
#
# There are four specs that we must code in our environment:
#
# * :obj:`EnvBase.observation_spec`: This will be a :class:`~torchrl.data.CompositeSpec`
#   instance where each key is an observation (a :class:`CompositeSpec` can be
#   viewed as a dictionary of specs).
# * :obj:`EnvBase.action_spec`: It can be any type of spec, but it is required
#   that it corresponds to the ``"action"`` entry in the input ``tensordict``;
# * :obj:`EnvBase.reward_spec`: provides information about the reward space;
# * :obj:`EnvBase.done_spec`: provides information about the space of the done
#   flag.
#
# TorchRL specs are organized in two general containers: ``input_spec`` which
# contains the specs of the information that the step function reads (divided
# between ``action_spec`` containing the action and ``state_spec`` containing
# all the rest), and ``output_spec`` which encodes the specs that the
# step outputs (``observation_spec``, ``reward_spec`` and ``done_spec``).
# In general, you should not interact directly with ``output_spec`` and
# ``input_spec`` but only with their content: ``observation_spec``,
# ``reward_spec``, ``done_spec``, ``action_spec`` and ``state_spec``.
# The reason is that the specs are organized in a non-trivial way
# within ``output_spec`` and
# ``input_spec`` and neither of these should be directly modified.
#
# In other words, the ``observation_spec`` and related properties are
# convenient shortcuts to the content of the output and input spec containers.
#
# TorchRL offers multiple :class:`~torchrl.data.TensorSpec`
# `subclasses <https://pytorch.org/rl/reference/data.html#tensorspec>`_ to
# encode the environment's input and output characteristics.
#
# Specs shape
# ^^^^^^^^^^^
#
# The environment specs' leading dimensions must match the
# environment batch-size. This is done to enforce that every component of an
# environment (including its transforms) has an accurate representation of
# the expected input and output shapes.
# This is something that should be
# accurately coded in stateful settings.
#
# For non batch-locked environments, such as the one in our example (see below),
# this is irrelevant as the environment batch size will most likely be empty.
#


def _make_spec(self, td_params):
    # Under the hood, this will populate self.output_spec["observation"]
    self.observation_spec = CompositeSpec(
        th=BoundedTensorSpec(
            low=-torch.pi,
            high=torch.pi,
            shape=(),
            dtype=torch.float32,
        ),
        thdot=BoundedTensorSpec(
            low=-td_params["params", "max_speed"],
            high=td_params["params", "max_speed"],
            shape=(),
            dtype=torch.float32,
        ),
        # we need to add the ``params`` to the observation specs, as we want
        # to pass it at each step during a rollout
        params=make_composite_from_td(td_params["params"]),
        shape=(),
    )
    # since the environment is stateless, we expect the previous output as input.
    # For this, ``EnvBase`` expects some state_spec to be available
    self.state_spec = self.observation_spec.clone()
    # the action-spec will be automatically wrapped in input_spec when
    # ``self.action_spec = spec`` is called
    self.action_spec = BoundedTensorSpec(
        low=-td_params["params", "max_torque"],
        high=td_params["params", "max_torque"],
        shape=(1,),
        dtype=torch.float32,
    )
    self.reward_spec = UnboundedContinuousTensorSpec(shape=(*td_params.shape, 1))


def make_composite_from_td(td):
    # custom function to convert a ``tensordict`` into a similar spec structure
    # of unbounded values.
    composite = CompositeSpec(
        {
            key: make_composite_from_td(tensor)
            if isinstance(tensor, TensorDictBase)
            else UnboundedContinuousTensorSpec(
                dtype=tensor.dtype, device=tensor.device, shape=tensor.shape
            )
            for key, tensor in td.items()
        },
        shape=td.shape,
    )
    return composite
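

######################################################################
# As a small aside (an addition to the tutorial, not part of the original code),
# we can check what the helper above produces for a tiny hand-made ``tensordict``:
# the nested structure is mirrored as a nested spec of unbounded values.

print(
    "example composite spec:",
    make_composite_from_td(TensorDict({"params": TensorDict({"g": 10.0}, [])}, [])),
)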


######################################################################
# Reproducible experiments: seeding
# ---------------------------------
#
# Seeding an environment is a common operation when initializing an experiment.
# The only goal of :func:`EnvBase._set_seed` is to set the seed of the contained
# simulator. If possible, this operation should not call ``reset()`` or interact
# with the environment execution. The parent :func:`EnvBase.set_seed` method
# incorporates a mechanism that allows seeding multiple environments with a
# different pseudo-random and reproducible seed.
#


def _set_seed(self, seed: Optional[int]):
    rng = torch.manual_seed(seed)
    self.rng = rng


######################################################################
# Wrapping things together: the :class:`~torchrl.envs.EnvBase` class
# -------------------------------------------------------------------
#
# We can finally put together the pieces and design our environment class.
# The specs initialization needs to be performed during the environment
# construction, so we must take care of calling the :func:`_make_spec` method
# within :func:`PendulumEnv.__init__`.
#
# We add a static method :meth:`PendulumEnv.gen_params` which deterministically
# generates a set of hyperparameters to be used during execution:
#


def gen_params(g=10.0, batch_size=None) -> TensorDictBase:
    """Returns a ``tensordict`` containing the physical parameters such as gravitational acceleration and torque or speed limits."""
    if batch_size is None:
        batch_size = []
    td = TensorDict(
        {
            "params": TensorDict(
                {
                    "max_speed": 8,
                    "max_torque": 2.0,
                    "dt": 0.05,
                    "g": g,
                    "m": 1.0,
                    "l": 1.0,
                },
                [],
            )
        },
        [],
    )
    if batch_size:
        td = td.expand(batch_size).contiguous()
    return td


######################################################################
# We define the environment as non-``batch_locked`` by setting the ``batch_locked``
# attribute to ``False``. This means that we will **not** enforce the input
# ``tensordict`` to have a ``batch_size`` that matches the one of the environment.
#
# The following code will just put together the pieces we have coded above.
#


class PendulumEnv(EnvBase):
    metadata = {
        "render_modes": ["human", "rgb_array"],
        "render_fps": 30,
    }
    batch_locked = False

    def __init__(self, td_params=None, seed=None, device="cpu"):
        if td_params is None:
            td_params = self.gen_params()

        super().__init__(device=device, batch_size=[])
        self._make_spec(td_params)
        if seed is None:
            seed = torch.empty((), dtype=torch.int64).random_().item()
        self.set_seed(seed)

    # Helpers: _make_spec and gen_params
    gen_params = staticmethod(gen_params)
    _make_spec = _make_spec

    # Mandatory methods: _step, _reset and _set_seed
    _reset = _reset
    _step = staticmethod(_step)
    _set_seed = _set_seed


######################################################################
# Testing our environment
# -----------------------
#
# TorchRL provides a simple function :func:`~torchrl.envs.utils.check_env_specs`
# to check that a (transformed) environment has an input/output structure that
# matches the one dictated by its specs.
# Let us try it out:
#

env = PendulumEnv()
check_env_specs(env)

######################################################################
# We can have a look at our specs to have a visual representation of the environment
# signature:
#

print("observation_spec:", env.observation_spec)
print("state_spec:", env.state_spec)
print("reward_spec:", env.reward_spec)

######################################################################
# We can also execute a couple of commands to check that the output structure
# matches what is expected.

td = env.reset()
print("reset tensordict", td)
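
######################################################################
# As a side note (this snippet is an addition to the tutorial), specs can also
# generate random data that lies within their domain. This is a convenient way to
# build dummy inputs for tests without querying the simulator itself:

print("a fake observation generated from the spec:", env.observation_spec.rand())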

######################################################################
# We can run :func:`env.rand_step` to generate
# an action randomly from the ``action_spec`` domain. A ``tensordict`` containing
# the hyperparameters and the current state **must** be passed since our
# environment is stateless. In stateful contexts, ``env.rand_step()`` works
# perfectly too.
#
td = env.rand_step(td)
print("random step tensordict", td)

######################################################################
# Transforming an environment
# ---------------------------
#
# Writing environment transforms for stateless simulators is slightly more
# complicated than for stateful ones: transforming an output entry that needs
# to be read at the following iteration requires applying the inverse transform
# before calling :meth:`~torchrl.envs.EnvBase.step` at the next step.
# This is an ideal scenario to showcase all the features of TorchRL's
# transforms!
#
# For instance, in the following transformed environment we ``unsqueeze`` the entries
# ``["th", "thdot"]`` to be able to stack them along the last
# dimension. We also pass them as ``in_keys_inv`` to squeeze them back to their
# original shape once they are passed as input in the next iteration.
#
env = TransformedEnv(
    env,
    # ``Unsqueeze`` the observations that we will concatenate
    UnsqueezeTransform(
        unsqueeze_dim=-1,
        in_keys=["th", "thdot"],
        in_keys_inv=["th", "thdot"],
    ),
)
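
######################################################################
# (This quick peek is an addition to the tutorial.) After the transform, the
# ``"th"`` and ``"thdot"`` entries gain a trailing singleton dimension, which is
# what will let us concatenate them below:

td_unsqueezed = env.reset()
print("shape of th after UnsqueezeTransform:", td_unsqueezed["th"].shape)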

######################################################################
# Writing custom transforms
# ^^^^^^^^^^^^^^^^^^^^^^^^^
#
# TorchRL's transforms may not cover all the operations one wants to execute
# after an environment has been executed.
# Writing a transform does not require much effort. As for the environment
# design, there are two steps in writing a transform:
#
# - Getting the dynamics right (forward and inverse);
# - Adapting the environment specs.
#
# A transform can be used in two settings: on its own, it can be used as a
# :class:`~torch.nn.Module`. It can also be appended to a
# :class:`~torchrl.envs.transforms.TransformedEnv`. The structure of the class allows
# customizing the behavior in the different contexts.
#
# A :class:`~torchrl.envs.transforms.Transform` skeleton can be summarized as follows:
#
# .. code-block::
#
#   class Transform(nn.Module):
#       def forward(self, tensordict):
#           ...
#       def _apply_transform(self, tensordict):
#           ...
#       def _step(self, tensordict):
#           ...
#       def _call(self, tensordict):
#           ...
#       def inv(self, tensordict):
#           ...
#       def _inv_apply_transform(self, tensordict):
#           ...
#
# There are three entry points (:func:`forward`, :func:`_step` and :func:`inv`)
# which all receive :class:`tensordict.TensorDict` instances. The first two
# will eventually go through the keys indicated by :obj:`~torchrl.envs.transforms.Transform.in_keys`
# and call :meth:`~torchrl.envs.transforms.Transform._apply_transform` on each of these. The results will
# be written in the entries pointed to by :obj:`Transform.out_keys` if provided
# (if not, the ``in_keys`` will be updated with the transformed values).
# If inverse transforms need to be executed, a similar data flow will be
# executed but with the :func:`Transform.inv` and
# :func:`Transform._inv_apply_transform` methods and across the ``in_keys_inv``
# and ``out_keys_inv`` list of keys.
# The following figure summarizes this flow for environments and replay
# buffers.
#
#    Transform API
#
# In some cases, a transform will not work on a subset of keys in a unitary
# manner, but will execute some operation on the parent environment or
# work with the entire input ``tensordict``.
# In those cases, the :func:`_call` and :func:`forward` methods should be
# re-written, and the :func:`_apply_transform` method can be skipped.
#
# Let us code new transforms that will compute the ``sine`` and ``cosine``
# values of the position angle, as these values are more useful to us to learn
# a policy than the raw angle value:


class SinTransform(Transform):
    def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
        return obs.sin()

    # The transform must also modify the data at reset time
    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        return self._call(tensordict_reset)

    # _apply_to_composite will execute the observation spec transform across all
    # in_keys/out_keys pairs and write the result in the observation_spec which
    # is of type ``Composite``
    @_apply_to_composite
    def transform_observation_spec(self, observation_spec):
        return BoundedTensorSpec(
            low=-1,
            high=1,
            shape=observation_spec.shape,
            dtype=observation_spec.dtype,
            device=observation_spec.device,
        )


class CosTransform(Transform):
    def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
        return obs.cos()

    # The transform must also modify the data at reset time
    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        return self._call(tensordict_reset)

    # _apply_to_composite will execute the observation spec transform across all
    # in_keys/out_keys pairs and write the result in the observation_spec which
    # is of type ``Composite``
    @_apply_to_composite
    def transform_observation_spec(self, observation_spec):
        return BoundedTensorSpec(
            low=-1,
            high=1,
            shape=observation_spec.shape,
            dtype=observation_spec.dtype,
            device=observation_spec.device,
        )


t_sin = SinTransform(in_keys=["th"], out_keys=["sin"])
t_cos = CosTransform(in_keys=["th"], out_keys=["cos"])
env.append_transform(t_sin)
env.append_transform(t_cos)

######################################################################
# Concatenate the observations onto an "observation" entry.
# ``del_keys=False`` ensures that we keep these values for the next
# iteration.
cat_transform = CatTensors(
    in_keys=["sin", "cos", "thdot"], dim=-1, out_key="observation", del_keys=False
)
env.append_transform(cat_transform)

######################################################################
# Once more, let us check that our environment specs match what is received:
check_env_specs(env)
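
######################################################################
# (The following peek is an addition to the tutorial.) Resetting the transformed
# environment now yields a concatenated ``"observation"`` entry holding
# ``[sin(th), cos(th), thdot]``, which is the input our policy will consume below:

td_transformed = env.reset()
print("transformed observation:", td_transformed["observation"])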

######################################################################
# Executing a rollout
# -------------------
#
# Executing a rollout is a succession of simple steps:
#
# * reset the environment
# * while some condition is not met:
#
#   * compute an action given a policy
#   * execute a step given this action
#   * collect the data
#   * make an ``MDP`` step
#
# * gather the data and return
#
# These operations have been conveniently wrapped in the :meth:`~torchrl.envs.EnvBase.rollout`
# method, of which we provide a simplified version below.


def simple_rollout(steps=100):
    # preallocate:
    data = TensorDict({}, [steps])
    # reset
    _data = env.reset()
    for i in range(steps):
        _data["action"] = env.action_spec.rand()
        _data = env.step(_data)
        data[i] = _data
        _data = step_mdp(_data, keep_other=True)
    return data


print("data from rollout:", simple_rollout(100))

######################################################################
# Batching computations
# ---------------------
#
# The last unexplored aspect of this tutorial is the ability we have to
# batch computations in TorchRL. Because our environment does not
# make any assumptions regarding the input data shape, we can seamlessly
# execute it over batches of data. Even better: for non-batch-locked
# environments such as our Pendulum, we can change the batch size on the fly
# without recreating the environment.
# To do this, we just generate parameters with the desired shape.
#

batch_size = 10  # number of environments to be executed in batch
td = env.reset(env.gen_params(batch_size=[batch_size]))
print("reset (batch size of 10)", td)
td = env.rand_step(td)
print("rand step (batch size of 10)", td)

######################################################################
# Executing a rollout with a batch of data requires us to reset the environment
# outside of the rollout function, since we need to define the batch size
# dynamically and this is not supported by :meth:`~torchrl.envs.EnvBase.rollout`:
#

rollout = env.rollout(
    3,
    auto_reset=False,  # we're executing the reset outside of the ``rollout`` call
    tensordict=env.reset(env.gen_params(batch_size=[batch_size])),
)
print("rollout of len 3 (batch size of 10):", rollout)
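

######################################################################
# (The following check is an addition to the tutorial.) Because the transition is
# written in pure PyTorch, gradients can flow from the reward back to the action;
# this is the property the training loop below relies on. A minimal sketch:

action = torch.ones(batch_size, 1, requires_grad=True)
td_grad = env.reset(env.gen_params(batch_size=[batch_size]))
td_grad["action"] = action
td_grad = env.step(td_grad)
td_grad["next", "reward"].sum().backward()
print("gradient of the reward with respect to the action:", action.grad)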


######################################################################
# Training a simple policy
# ------------------------
#
# In this example, we will train a simple policy using the reward as a
# differentiable objective, taking its negative as the loss.
# We will take advantage of the fact that our dynamic system is fully
# differentiable to backpropagate through the trajectory return and adjust the
# weights of our policy to maximize this value directly. Of course, in many
# settings many of the assumptions we make do not hold, such as a
# differentiable system and full access to the underlying mechanics.
#
# Still, this is a very simple example that showcases how a training loop can
# be coded with a custom environment in TorchRL.
#
# Let us first write the policy network:
#
torch.manual_seed(0)
env.set_seed(0)

net = nn.Sequential(
    nn.LazyLinear(64),
    nn.Tanh(),
    nn.LazyLinear(64),
    nn.Tanh(),
    nn.LazyLinear(64),
    nn.Tanh(),
    nn.LazyLinear(1),
)
policy = TensorDictModule(
    net,
    in_keys=["observation"],
    out_keys=["action"],
)

######################################################################
# and our optimizer:
#

optim = torch.optim.Adam(policy.parameters(), lr=2e-3)

######################################################################
# Training loop
# ^^^^^^^^^^^^^
#
# We will successively:
#
# * generate a trajectory
# * sum the rewards
# * backpropagate through the graph defined by these operations
# * clip the gradient norm and make an optimization step
# * repeat
#
# At the end of the training loop, we should have a final reward close to 0
# which demonstrates that the pendulum is upward and still as desired.
#
batch_size = 32
pbar = tqdm.tqdm(range(20_000 // batch_size))
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, 20_000)
logs = defaultdict(list)

for _ in pbar:
    init_td = env.reset(env.gen_params(batch_size=[batch_size]))
    rollout = env.rollout(100, policy, tensordict=init_td, auto_reset=False)
    traj_return = rollout["next", "reward"].mean()
    (-traj_return).backward()
    gn = torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
    optim.step()
    optim.zero_grad()
    pbar.set_description(
        f"reward: {traj_return: 4.4f}, "
        f"last reward: {rollout[..., -1]['next', 'reward'].mean(): 4.4f}, gradient norm: {gn: 4.4}"
    )
    logs["return"].append(traj_return.item())
    logs["last_reward"].append(rollout[..., -1]["next", "reward"].mean().item())
    scheduler.step()


def plot():
    import matplotlib
    from matplotlib import pyplot as plt

    is_ipython = "inline" in matplotlib.get_backend()
    if is_ipython:
        from IPython import display

    with plt.ion():
        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        plt.plot(logs["return"])
        plt.title("returns")
        plt.xlabel("iteration")
        plt.subplot(1, 2, 2)
        plt.plot(logs["last_reward"])
        plt.title("last reward")
        plt.xlabel("iteration")
        if is_ipython:
            display.display(plt.gcf())
            display.clear_output(wait=True)
        plt.show()


plot()


######################################################################
# Conclusion
# ----------
#
# In this tutorial, we have learned how to code a stateless environment from
# scratch. We touched on the subjects of:
#
# * The four essential components that need to be taken care of when coding
#   an environment (``step``, ``reset``, seeding and building specs).
#   We saw how these methods and classes interact with the
#   :class:`~tensordict.TensorDict` class;
# * How to test that an environment is properly coded using
#   :func:`~torchrl.envs.utils.check_env_specs`;
# * How to append transforms in the context of stateless environments and how
#   to write custom transformations;
# * How to train a policy on a fully differentiable simulator.
#