# -*- coding: utf-8 -*-
"""
Reinforcement Learning (PPO) with TorchRL Tutorial
==================================================
**Author**: `Vincent Moens <https://github.com/vmoens>`_

This tutorial demonstrates how to use PyTorch and :py:mod:`torchrl` to train a parametric policy
network to solve the Inverted Pendulum task from the `OpenAI-Gym/Farama-Gymnasium
control library <https://github.com/Farama-Foundation/Gymnasium>`__.

.. figure:: /_static/img/invpendulum.gif
   :alt: Inverted pendulum

   Inverted pendulum

Key learnings:

- How to create an environment in TorchRL, transform its outputs, and collect data from this environment;
- How to make your classes talk to each other using :class:`~tensordict.TensorDict`;
- The basics of building your training loop with TorchRL:

  - How to compute the advantage signal for policy gradient methods;
  - How to create a stochastic policy using a probabilistic neural network;
  - How to create a dynamic replay buffer and sample from it without repetition.

We will cover six crucial components of TorchRL:

* `environments <https://pytorch.org/rl/reference/envs.html>`__
* `transforms <https://pytorch.org/rl/reference/envs.html#transforms>`__
* `models (policy and value function) <https://pytorch.org/rl/reference/modules.html>`__
* `loss modules <https://pytorch.org/rl/reference/objectives.html>`__
* `data collectors <https://pytorch.org/rl/reference/collectors.html>`__
* `replay buffers <https://pytorch.org/rl/reference/data.html#replay-buffers>`__

"""

######################################################################
# If you are running this in Google Colab, make sure you install the following dependencies:
#
# .. code-block:: bash
#
#    !pip3 install torchrl
#    !pip3 install gym[mujoco]
#    !pip3 install tqdm
#
# Proximal Policy Optimization (PPO) is a policy-gradient algorithm in which a
# batch of data is collected and directly consumed to train the policy to maximise
# the expected return given some proximality constraints. You can think of it
# as a sophisticated version of `REINFORCE <https://link.springer.com/content/pdf/10.1007/BF00992696.pdf>`_,
# the foundational policy-optimization algorithm. For more information, see the
# `Proximal Policy Optimization Algorithms <https://arxiv.org/abs/1707.06347>`_ paper.
#
# PPO is usually regarded as a fast and efficient method for online, on-policy
# reinforcement learning. TorchRL provides a loss module that does all the work
# for you, so that you can rely on this implementation and focus on solving your
# problem rather than re-inventing the wheel every time you want to train a policy.
#
# For completeness, here is a brief overview of what the loss computes, even though
# this is taken care of by our :class:`~torchrl.objectives.ClipPPOLoss` module. The algorithm works as follows:
#
# 1. We will sample a batch of data by playing the
#    policy in the environment for a given number of steps.
# 2. Then, we will perform a given number of optimization steps with random sub-samples of this batch using
#    a clipped version of the REINFORCE loss.
# 3. The clipping will put a pessimistic bound on our loss: lower return estimates will
#    be favored compared to higher ones.
#
# The precise formula of the loss is:
#
# .. math::
#
#     L(s,a,\theta_k,\theta) = \min\left(
#         \frac{\pi_{\theta}(a|s)}{\pi_{\theta_k}(a|s)}  A^{\pi_{\theta_k}}(s,a), \;\;
#         g(\epsilon, A^{\pi_{\theta_k}}(s,a))
#     \right),
#
# There are two components in that loss: in the first part of the minimum operator,
# we simply compute an importance-weighted version of the REINFORCE loss (that is, a
# REINFORCE loss corrected for the fact that the current policy
# configuration lags the one that was used for the data collection).
# The second part of that minimum operator is a similar loss where we have clipped
# the ratios when they exceeded or were below a given pair of thresholds.
#
# This loss ensures that whether the advantage is positive or negative, policy
# updates that would produce significant shifts from the previous configuration
# are discouraged.
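#
# As a standalone illustration of this objective (not part of the tutorial's
# training code, using hypothetical tensors ``log_prob``, ``log_prob_old`` and
# ``advantage``, and the ``clip_epsilon`` threshold defined later in this
# tutorial), the clipped surrogate could be written in plain PyTorch as:
#
# .. code-block:: python
#
#    # probability ratio between the current policy and the data-collection policy
#    ratio = torch.exp(log_prob - log_prob_old)
#    surrogate1 = ratio * advantage
#    surrogate2 = torch.clamp(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantage
#    # PPO maximizes the pessimistic (minimum) term, hence the minus sign to get a loss
#    loss_objective = -torch.min(surrogate1, surrogate2).mean()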
#
# This tutorial is structured as follows:
#
# 1. First, we will define a set of hyperparameters we will be using for training.
#
# 2. Next, we will focus on creating our environment, or simulator, using TorchRL's
#    wrappers and transforms.
#
# 3. Next, we will design the policy network and the value model,
#    which is indispensable to the loss function. These modules will be used
#    to configure our loss module.
#
# 4. Next, we will create the replay buffer and data loader.
#
# 5. Finally, we will run our training loop and analyze the results.
#
# Throughout this tutorial, we'll be using the :mod:`tensordict` library.
# :class:`~tensordict.TensorDict` is the lingua franca of TorchRL: it helps us abstract
# what a module reads and writes, so that we can care less about the specific data
# description and more about the algorithm itself.
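#
# As a quick, self-contained illustration (the keys below are hypothetical and not
# used later in this tutorial), a ``TensorDict`` behaves much like a dictionary of
# tensors that also carries a batch size:
#
# .. code-block:: python
#
#    from tensordict import TensorDict
#    import torch
#
#    td = TensorDict(
#        {"observation": torch.randn(4, 3), "reward": torch.zeros(4, 1)},
#        batch_size=[4],
#    )
#    print(td["observation"].shape)  # torch.Size([4, 3])
#    print(td.batch_size)            # torch.Size([4])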
#

import warnings
warnings.filterwarnings("ignore")
from torch import multiprocessing

# sphinx_gallery_start_ignore

# TorchRL prefers the spawn method, which restricts creation of ``~torchrl.envs.ParallelEnv``
# to inside the ``__main__`` guard. For ease of reading, this tutorial switches to fork,
# which is also the default start method in Google's Colaboratory.
try:
    multiprocessing.set_start_method("fork")
except RuntimeError:
    pass

# sphinx_gallery_end_ignore

from collections import defaultdict

import matplotlib.pyplot as plt
import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import (
    Compose,
    DoubleToFloat,
    ObservationNorm,
    StepCounter,
    TransformedEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
from tqdm import tqdm

######################################################################
# Define Hyperparameters
# ----------------------
#
# We set the hyperparameters for our algorithm. Depending on the resources
# available, one may choose to execute the policy on GPU or on another
# device.
# The ``frame_skip`` will control for how many frames a single
# action is executed. The rest of the arguments that count frames
# must be corrected for this value (since one environment step will
# actually return ``frame_skip`` frames).
#

is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)
num_cells = 256  # number of cells in each layer, i.e. the output dimension
lr = 3e-4
max_grad_norm = 1.0

######################################################################
# Data collection parameters
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# When collecting data, we will be able to choose how big each batch will be
# by defining a ``frames_per_batch`` parameter. We will also define how many
# frames (that is, the number of interactions with the simulator) we will allow ourselves to
# use. In general, the goal of an RL algorithm is to learn to solve the task
# as fast as it can in terms of environment interactions: the lower the ``total_frames``
# the better.
#
frames_per_batch = 1000
# For a complete training, bring the number of frames up to 1M
total_frames = 50_000

######################################################################
# PPO parameters
# ~~~~~~~~~~~~~~
#
# At each data collection (or batch collection) we will run the optimization
# over a certain number of *epochs*, each time consuming the entire data we just
# acquired in a nested training loop. Here, the ``sub_batch_size`` is different from the
# ``frames_per_batch`` above: recall that we are working with a "batch of data"
# coming from our collector, whose size is defined by ``frames_per_batch``, and that
# we will further split into smaller sub-batches during the inner training loop.
# The size of these sub-batches is controlled by ``sub_batch_size``.
#
sub_batch_size = 64  # cardinality of the sub-samples gathered from the current data in the inner loop
num_epochs = 10  # optimization steps per batch of data collected
clip_epsilon = (
    0.2  # clip value for PPO loss: see the equation in the intro for more context.
)
gamma = 0.99
lmbda = 0.95
entropy_eps = 1e-4

######################################################################
# Define an environment
# ---------------------
#
# In RL, an *environment* is usually the way we refer to a simulator or a
# control system. Various libraries provide simulation environments for reinforcement
# learning, including Gymnasium (previously OpenAI Gym), DeepMind control suite, and
# many others.
# As a general library, TorchRL's goal is to provide an interchangeable interface
# to a large panel of RL simulators, allowing you to easily swap one environment
# for another. For example, creating a wrapped gym environment can be achieved with a few characters:
#

base_env = GymEnv("InvertedDoublePendulum-v4", device=device)

######################################################################
# There are a few things to notice in this code: first, we created
# the environment by calling the ``GymEnv`` wrapper. If extra keyword arguments
# are passed, they will be transmitted to the ``gym.make`` method, hence covering
# the most common environment construction commands.
# Alternatively, one could also directly create a gym environment using ``gym.make(env_name, **kwargs)``
# and wrap it in a ``GymWrapper`` class.
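#
# For illustration only (this wrapped object is not used in the rest of the
# tutorial), that alternative route could look roughly like this:
#
# .. code-block:: python
#
#    import gym
#    from torchrl.envs.libs.gym import GymWrapper
#
#    raw_env = gym.make("InvertedDoublePendulum-v4")
#    wrapped_env = GymWrapper(raw_env, device=device)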
#
# Also note the ``device`` argument: for gym, this only controls the device where
# the input action and observed states will be stored, but the execution will always
# be done on CPU. The reason for this is simply that gym does not support on-device
# execution, unless specified otherwise. For other libraries, we have control over
# the execution device and, as much as we can, we try to stay consistent in terms of
# storing and execution backends.
#
# Transforms
# ~~~~~~~~~~
#
# We will append some transforms to our environments to prepare the data for
# the policy. In Gym, this is usually achieved via wrappers. TorchRL takes a different
# approach, more similar to other PyTorch domain libraries, through the use of transforms.
# To add transforms to an environment, one should simply wrap it in a :class:`~torchrl.envs.transforms.TransformedEnv`
# instance and append the sequence of transforms to it. The transformed environment will inherit
# the device and meta-data of the wrapped environment, and transform these depending on the sequence
# of transforms it contains.
#
# Normalization
# ~~~~~~~~~~~~~
#
# The first transform to add is a normalization transform.
# As a rule of thumb, it is preferable to have data that loosely
# matches a unit Gaussian distribution: to obtain this, we will
# run a certain number of random steps in the environment and compute
# the summary statistics of these observations.
#
# We'll append two other transforms: the :class:`~torchrl.envs.transforms.DoubleToFloat` transform will
# convert double entries to single-precision numbers, ready to be read by the
# policy. The :class:`~torchrl.envs.transforms.StepCounter` transform will be used to count the steps before
# the environment is terminated. We will use this count as a supplementary measure
# of performance.
#
# As we will see later, many of TorchRL's classes rely on :class:`~tensordict.TensorDict`
# to communicate. You could think of it as a Python dictionary with some extra
# tensor features. In practice, this means that many modules we will be working
# with need to be told what key to read (``in_keys``) and what key to write
# (``out_keys``) in the ``tensordict`` they will receive. Usually, if ``out_keys``
# is omitted, it is assumed that the ``in_keys`` entries will be updated
# in-place. For our transforms, the only entry we are interested in is referred
# to as ``"observation"`` and our transform layers will be told to modify this
# entry and this entry only:
#

env = TransformedEnv(
    base_env,
    Compose(
        # normalize observations
        ObservationNorm(in_keys=["observation"]),
        DoubleToFloat(),
        StepCounter(),
    ),
)

######################################################################
# As you may have noticed, we have created a normalization layer but we did not
# set its normalization parameters.
# To do this, :class:`~torchrl.envs.transforms.ObservationNorm` can
# automatically gather the summary statistics of our environment:
#
env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)

######################################################################
# The :class:`~torchrl.envs.transforms.ObservationNorm` transform has now been populated with a
# location and a scale that will be used to normalize the data.
#
# Let us do a little sanity check for the shape of our summary stats:
#
print("normalization constant shape:", env.transform[0].loc.shape)

######################################################################
# An environment is not only defined by its simulator and transforms, but also
# by a series of metadata that describe what can be expected during its
# execution.
# For efficiency purposes, TorchRL is quite stringent when it comes to
# environment specs, but you can easily check that your environment specs are
# adequate.
# In our example, the :class:`~torchrl.envs.libs.gym.GymWrapper` and
# :class:`~torchrl.envs.libs.gym.GymEnv` that inherits
# from it already take care of setting the proper specs for your environment, so
# you should not have to care about this.
#
# Nevertheless, let's see a concrete example using our transformed
# environment by looking at its specs.
# There are three specs to look at: ``observation_spec``, which defines what
# is to be expected when executing an action in the environment;
# ``reward_spec``, which indicates the reward domain; and finally the
# ``input_spec`` (which contains the ``action_spec``), which represents
# everything an environment requires to execute a single step.
#
print("observation_spec:", env.observation_spec)
print("reward_spec:", env.reward_spec)
print("input_spec:", env.input_spec)
print("action_spec (as defined by input_spec):", env.action_spec)

######################################################################
# The :func:`check_env_specs` function runs a small rollout and compares its output against the environment
# specs. If no error is raised, we can be confident that the specs are properly defined:
#
check_env_specs(env)

######################################################################
# For fun, let's see what a simple random rollout looks like. You can
# call ``env.rollout(n_steps)`` and get an overview of what the environment inputs
# and outputs look like. Actions will automatically be drawn from the action spec
# domain, so you don't need to care about designing a random sampler.
#
# Typically, at each step, an RL environment receives an
# action as input, and outputs an observation, a reward and a done state. The
# observation may be composite, meaning that it could be composed of more than one
# tensor. This is not a problem for TorchRL, since the whole set of observations
# is automatically packed in the output :class:`~tensordict.TensorDict`. After executing a rollout
# (for example, a sequence of environment steps and random action generations) over a given
# number of steps, we will retrieve a :class:`~tensordict.TensorDict` instance with a shape
# that matches this trajectory length:
#
rollout = env.rollout(3)
print("rollout of three steps:", rollout)
print("Shape of the rollout TensorDict:", rollout.batch_size)

######################################################################
# Our rollout data has a shape of ``torch.Size([3])``, which matches the number of steps
# we ran it for. The ``"next"`` entry points to the data coming after the current step.
# In most cases, the ``"next"`` data at time ``t`` matches the data at ``t+1``, but this
# may not be the case if we are using some specific transformations (for example, multi-step).
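#
# As a quick sanity check of that statement (assuming the short rollout above did
# not end early), the ``"next"`` observation at step ``t`` should match the
# observation at step ``t + 1``:
#
# .. code-block:: python
#
#    assert torch.allclose(
#        rollout["next", "observation"][:-1], rollout["observation"][1:]
#    )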
#
# Policy
# ------
#
# PPO utilizes a stochastic policy to handle exploration. This means that our
# neural network will have to output the parameters of a distribution, rather
# than a single value corresponding to the action taken.
#
# As the data is continuous, we use a Tanh-Normal distribution to respect the
# action space boundaries. TorchRL provides such a distribution, and the only
# thing we need to care about is to build a neural network that outputs the
# right number of parameters for the policy to work with (a location, or mean,
# and a scale):
#
# .. math::
#
#     f_{\theta}(\text{observation}) = \mu_{\theta}(\text{observation}), \sigma^{+}_{\theta}(\text{observation})
#
# The only extra difficulty that is brought up here is to split our output into two
# equal parts and map the second to a strictly positive space.
#
# We design the policy in three steps:
#
# 1. Define a neural network ``D_obs`` -> ``2 * D_action``. Indeed, our ``loc`` (mu) and ``scale`` (sigma) both have dimension ``D_action``.
#
# 2. Append a :class:`~tensordict.nn.distributions.NormalParamExtractor` to extract a location and a scale (it splits the input into two equal parts and applies a positive transformation to the scale parameter); see the short sketch after this list.
#
# 3. Create a probabilistic :class:`~tensordict.nn.TensorDictModule` that can generate this distribution and sample from it.
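#
# To make step 2 concrete, here is a small standalone sketch (the shapes are
# arbitrary and chosen only for illustration): given an input whose last dimension
# has ``2 * D_action`` features, the extractor returns a ``loc`` and a strictly
# positive ``scale``, each with ``D_action`` features.
#
# .. code-block:: python
#
#    extractor = NormalParamExtractor()
#    loc, scale = extractor(torch.randn(10, 2 * 4))
#    print(loc.shape, scale.shape)  # torch.Size([10, 4]) torch.Size([10, 4])
#    print((scale > 0).all())       # tensor(True)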
#

actor_net = nn.Sequential(
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(2 * env.action_spec.shape[-1], device=device),
    NormalParamExtractor(),
)

######################################################################
# To enable the policy to "talk" with the environment through the ``tensordict``
# data carrier, we wrap the ``nn.Module`` in a :class:`~tensordict.nn.TensorDictModule`. This
# class will simply read the ``in_keys`` it is provided with and write the
# outputs in-place at the registered ``out_keys``.
#
policy_module = TensorDictModule(
    actor_net, in_keys=["observation"], out_keys=["loc", "scale"]
)

######################################################################
# We now need to build a distribution out of the location and scale of our
# normal distribution. To do so, we instruct the
# :class:`~torchrl.modules.tensordict_module.ProbabilisticActor`
# class to build a :class:`~torchrl.modules.TanhNormal` out of the location and scale
# parameters. We also provide the minimum and maximum values of this
# distribution, which we gather from the environment specs.
#
# The name of the ``in_keys`` (and hence the name of the ``out_keys`` from
# the :class:`~tensordict.nn.TensorDictModule` above) cannot be set to any value one may
# like, as the :class:`~torchrl.modules.TanhNormal` distribution constructor will expect the
# ``loc`` and ``scale`` keyword arguments. That being said,
# :class:`~torchrl.modules.tensordict_module.ProbabilisticActor` also accepts
# ``Dict[str, str]`` typed ``in_keys``, where the key-value pair indicates
# what ``in_key`` string should be used for every keyword argument that is to be used.
#
policy_module = ProbabilisticActor(
    module=policy_module,
    spec=env.action_spec,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    distribution_kwargs={
        "min": env.action_spec.space.low,
        "max": env.action_spec.space.high,
    },
    return_log_prob=True,
    # we'll need the log-prob for the numerator of the importance weights
)

######################################################################
# Value network
# -------------
#
# The value network is a crucial component of the PPO algorithm, even though it
# won't be used at inference time. This module will read the observations and
# return an estimation of the discounted return for the following trajectory.
# This allows us to amortize learning by relying on some utility estimation
# that is learned on the fly during training. Our value network shares the same
# structure as the policy, but for simplicity we assign it its own set of
# parameters.
#
value_net = nn.Sequential(
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(1, device=device),
)

value_module = ValueOperator(
    module=value_net,
    in_keys=["observation"],
)

######################################################################
# Let's try our policy and value modules. As we said earlier, the usage of
# :class:`~tensordict.nn.TensorDictModule` makes it possible to directly read the output
# of the environment to run these modules, as they know what information to read
# and where to write it:
#
print("Running policy:", policy_module(env.reset()))
print("Running value:", value_module(env.reset()))

######################################################################
# Data collector
# --------------
#
# TorchRL provides a set of `DataCollector classes <https://pytorch.org/rl/reference/collectors.html>`__.
# Briefly, these classes execute three operations: reset an environment,
# compute an action given the latest observation, execute a step in the environment,
# and repeat the last two steps until the environment signals a stop (or reaches
# a done state).
#
# They allow you to control how many frames to collect at each iteration
# (through the ``frames_per_batch`` parameter),
# when to reset the environment (through the ``max_frames_per_traj`` argument),
# on which ``device`` the policy should be executed, etc. They are also
# designed to work efficiently with batched and multiprocessed environments.
#
# The simplest data collector is the :class:`~torchrl.collectors.collectors.SyncDataCollector`:
# it is an iterator that you can use to get batches of data of a given length, and
# that will stop once a total number of frames (``total_frames``) has been
# collected.
# Other data collectors (:class:`~torchrl.collectors.collectors.MultiSyncDataCollector` and
# :class:`~torchrl.collectors.collectors.MultiaSyncDataCollector`) will execute
# the same operations in a synchronous and asynchronous manner over a
# set of multiprocessed workers.
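#
# As a rough sketch of the multiprocessed variant (``make_env`` is a hypothetical
# environment-creating callable, and this collector is not used in the rest of
# the tutorial):
#
# .. code-block:: python
#
#    from torchrl.collectors import MultiSyncDataCollector
#
#    multi_collector = MultiSyncDataCollector(
#        [make_env, make_env],  # one environment constructor per worker
#        policy_module,
#        frames_per_batch=frames_per_batch,
#        total_frames=total_frames,
#    )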
#
# As for the policy and environment before, the data collector will return
# :class:`~tensordict.TensorDict` instances with a total number of elements that will
# match ``frames_per_batch``. Using :class:`~tensordict.TensorDict` to pass data to the
# training loop allows you to write data loading pipelines
# that are 100% oblivious to the actual specificities of the rollout content.
#
collector = SyncDataCollector(
    env,
    policy_module,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    split_trajs=False,
    device=device,
)

######################################################################
# Replay buffer
# -------------
#
# Replay buffers are a common building piece of off-policy RL algorithms.
# In on-policy contexts, a replay buffer is refilled every time a batch of
# data is collected, and its data is repeatedly consumed for a certain number
# of epochs.
#
# TorchRL's replay buffers are built using a common container,
# :class:`~torchrl.data.ReplayBuffer`, which takes as arguments the components
# of the buffer: a storage, a writer, a sampler and possibly some transforms.
# Only the storage (which indicates the replay buffer capacity) is mandatory.
# We also specify a sampler without repetition to avoid sampling the same item
# multiple times in one epoch.
# Using a replay buffer for PPO is not mandatory and we could simply
# sample the sub-batches from the collected batch, but using these classes
# makes it easy for us to build the inner training loop in a reproducible way.
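#
# As a tiny standalone sketch of this refill-then-sample pattern (the buffer and
# key below are hypothetical and separate from the one used for training):
#
# .. code-block:: python
#
#    from tensordict import TensorDict
#
#    demo_buffer = ReplayBuffer(
#        storage=LazyTensorStorage(max_size=100),
#        sampler=SamplerWithoutReplacement(),
#    )
#    demo_buffer.extend(TensorDict({"x": torch.randn(100, 3)}, batch_size=[100]))
#    batch = demo_buffer.sample(10)  # 10 items, never repeated within one pass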
#

replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(max_size=frames_per_batch),
    sampler=SamplerWithoutReplacement(),
)

######################################################################
# Loss function
# -------------
#
# The PPO loss can be directly imported from TorchRL for convenience using the
# :class:`~torchrl.objectives.ClipPPOLoss` class. This is the easiest way of utilizing PPO:
# it hides away the mathematical operations of PPO and the control flow that
# goes with it.
#
# PPO requires some "advantage estimation" to be computed. In short, an advantage
# is a value that reflects an expectation over the return value while dealing with
# the bias / variance tradeoff.
# To compute the advantage, one just needs to (1) build the advantage module, which
# utilizes our value operator, and (2) pass each batch of data through it before each
# epoch.
# The GAE module will update the input ``tensordict`` with new ``"advantage"`` and
# ``"value_target"`` entries.
# The ``"value_target"`` is a gradient-free tensor that represents the empirical
# value that the value network should approximate given the input observation.
# Both of these will be used by :class:`~torchrl.objectives.ClipPPOLoss` to
# return the policy and value losses.
#

advantage_module = GAE(
    gamma=gamma, lmbda=lmbda, value_network=value_module, average_gae=True
)

loss_module = ClipPPOLoss(
    actor_network=policy_module,
    critic_network=value_module,
    clip_epsilon=clip_epsilon,
    entropy_bonus=bool(entropy_eps),
    entropy_coef=entropy_eps,
    # these keys match by default but we set this for completeness
    critic_coef=1.0,
    loss_critic_type="smooth_l1",
)

optim = torch.optim.Adam(loss_module.parameters(), lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optim, total_frames // frames_per_batch, 0.0
)

######################################################################
# Training loop
# -------------
# We now have all the pieces needed to code our training loop.
# The steps include:
#
# * Collect data
#
#   * Compute advantage
#
#     * Loop over the collected data to compute the loss values
#     * Back propagate
#     * Optimize
#     * Repeat
#
#   * Repeat
#
# * Repeat
#


logs = defaultdict(list)
pbar = tqdm(total=total_frames)
eval_str = ""

# We iterate over the collector until it reaches the total number of frames it was
# designed to collect:
for i, tensordict_data in enumerate(collector):
    # we now have a batch of data to work with. Let's learn something from it.
    for _ in range(num_epochs):
        # We'll need an "advantage" signal to make PPO work.
        # We re-compute it at each epoch as its value depends on the value
        # network, which is updated in the inner loop.
        advantage_module(tensordict_data)
        data_view = tensordict_data.reshape(-1)
        replay_buffer.extend(data_view.cpu())
        for _ in range(frames_per_batch // sub_batch_size):
            subdata = replay_buffer.sample(sub_batch_size)
            loss_vals = loss_module(subdata.to(device))
            loss_value = (
                loss_vals["loss_objective"]
                + loss_vals["loss_critic"]
                + loss_vals["loss_entropy"]
            )

            # Optimization: backward, grad clipping and optimization step
            loss_value.backward()
            # this is not strictly mandatory but it's good practice to keep
            # your gradient norm bounded
            torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_grad_norm)
            optim.step()
            optim.zero_grad()

    logs["reward"].append(tensordict_data["next", "reward"].mean().item())
    pbar.update(tensordict_data.numel())
    cum_reward_str = (
        f"average reward={logs['reward'][-1]: 4.4f} (init={logs['reward'][0]: 4.4f})"
    )
    logs["step_count"].append(tensordict_data["step_count"].max().item())
    stepcount_str = f"step count (max): {logs['step_count'][-1]}"
    logs["lr"].append(optim.param_groups[0]["lr"])
    lr_str = f"lr policy: {logs['lr'][-1]: 4.4f}"
    if i % 10 == 0:
        # We evaluate the policy once every 10 batches of data.
        # Evaluation is rather simple: execute the policy without exploration
        # (take the expected value of the action distribution) for a given
        # number of steps (1000, which is our ``env`` horizon).
        # The ``rollout`` method of the ``env`` can take a policy as argument:
        # it will then execute this policy at each step.
        with set_exploration_type(ExplorationType.MEAN), torch.no_grad():
            # execute a rollout with the trained policy
            eval_rollout = env.rollout(1000, policy_module)
            logs["eval reward"].append(eval_rollout["next", "reward"].mean().item())
            logs["eval reward (sum)"].append(
                eval_rollout["next", "reward"].sum().item()
            )
            logs["eval step_count"].append(eval_rollout["step_count"].max().item())
            eval_str = (
                f"eval cumulative reward: {logs['eval reward (sum)'][-1]: 4.4f} "
                f"(init: {logs['eval reward (sum)'][0]: 4.4f}), "
                f"eval step-count: {logs['eval step_count'][-1]}"
            )
            del eval_rollout
    pbar.set_description(", ".join([eval_str, cum_reward_str, stepcount_str, lr_str]))

    # We're also using a learning rate scheduler. Like the gradient clipping,
    # this is a nice-to-have but nothing necessary for PPO to work.
    scheduler.step()

######################################################################
# Results
# -------
#
# Before the 1M step cap is reached, the algorithm should have reached a max
# step count of 1000 steps, which is the maximum number of steps before the
# trajectory is truncated.
#
plt.figure(figsize=(10, 10))
plt.subplot(2, 2, 1)
plt.plot(logs["reward"])
plt.title("training rewards (average)")
plt.subplot(2, 2, 2)
plt.plot(logs["step_count"])
plt.title("Max step count (training)")
plt.subplot(2, 2, 3)
plt.plot(logs["eval reward (sum)"])
plt.title("Return (test)")
plt.subplot(2, 2, 4)
plt.plot(logs["eval step_count"])
plt.title("Max step count (test)")
plt.show()

######################################################################
# Conclusion and next steps
# -------------------------
#
# In this tutorial, we have learned:
#
# 1. How to create and customize an environment with :py:mod:`torchrl`;
# 2. How to write a model and a loss function;
# 3. How to set up a typical training loop.
#
# If you want to experiment with this tutorial a bit more, you can apply the following modifications:
#
# * From an efficiency perspective,
#   we could run several simulations in parallel to speed up data collection.
#   Check :class:`~torchrl.envs.ParallelEnv` for further information, and see the
#   sketch after this list.
#
# * From a logging perspective, one could add a :class:`torchrl.record.VideoRecorder` transform to
#   the environment after asking for rendering to get a visual rendering of the
#   inverted pendulum in action. Check :py:mod:`torchrl.record` to
#   know more.
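#
# As a rough, hedged sketch of the first suggestion (the worker count is arbitrary
# and this snippet is not wired into the training code above):
#
# .. code-block:: python
#
#    from torchrl.envs import ParallelEnv
#
#    def make_parallel_base_env():
#        # each worker builds its own copy of the base environment
#        return GymEnv("InvertedDoublePendulum-v4")
#
#    parallel_env = ParallelEnv(4, make_parallel_base_env)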