# -*- coding: utf-8 -*-
"""
Reinforcement Learning (DQN) Tutorial
=====================================
**Author**: `Adam Paszke <https://github.com/apaszke>`_
            `Mark Towers <https://github.com/pseudo-rnd-thoughts>`_


This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent
on the CartPole-v1 task from `Gymnasium <https://gymnasium.farama.org>`__.

You might find it helpful to read the original `Deep Q Learning (DQN) <https://arxiv.org/abs/1312.5602>`__ paper.

**Task**

The agent has to decide between two actions - moving the cart left or
right - so that the pole attached to it stays upright. You can find more
information about the environment and other more challenging environments at
`Gymnasium's website <https://gymnasium.farama.org/environments/classic_control/cart_pole/>`__.

.. figure:: /_static/img/cartpole.gif
   :alt: CartPole

   CartPole

As the agent observes the current state of the environment and chooses
an action, the environment *transitions* to a new state, and also
returns a reward that indicates the consequences of the action. In this
task, rewards are +1 for every incremental timestep and the environment
terminates if the pole falls over too far or the cart moves more than 2.4
units away from center. This means better performing scenarios will run
for a longer duration, accumulating a larger return.

The CartPole task is designed so that the inputs to the agent are 4 real
values representing the environment state (position, velocity, etc.).
We take these 4 inputs without any scaling and pass them through a
small fully-connected network with 2 outputs, one for each action.
The network is trained to predict the expected value for each action,
given the input state. The action with the highest expected value is
then chosen.


**Packages**


First, let's import needed packages. Firstly, we need
`gymnasium <https://gymnasium.farama.org/>`__ for the environment,
installed by using `pip`. This is a fork of the original OpenAI
Gym project, maintained by the same team since Gym v0.19.
If you are running this in Google Colab, run:

.. code-block:: bash

   %%bash
   pip3 install gymnasium[classic_control]

We'll also use the following from PyTorch:

-  neural networks (``torch.nn``)
-  optimization (``torch.optim``)
-  automatic differentiation (``torch.autograd``)

"""

import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

env = gym.make("CartPole-v1")

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

# if GPU is to be used
device = torch.device(
    "cuda" if torch.cuda.is_available() else
    "mps" if torch.backends.mps.is_available() else
    "cpu"
)
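
######################################################################
# As a quick, optional sanity check, we can inspect the environment's
# spaces to confirm the shapes discussed above: CartPole-v1 exposes a
# 4-dimensional observation vector and 2 discrete actions.

print(env.observation_space)  # Box with shape (4,): cart position, cart velocity, pole angle, pole angular velocity
print(env.action_space)       # Discrete(2): push the cart left or right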

######################################################################
# Replay Memory
# -------------
#
# We'll be using experience replay memory for training our DQN. It stores
# the transitions that the agent observes, allowing us to reuse this data
# later. By sampling from it randomly, the transitions that build up a
# batch are decorrelated. It has been shown that this greatly stabilizes
# and improves the DQN training procedure.
#
# For this, we're going to need two classes:
#
# -  ``Transition`` - a named tuple representing a single transition in
#    our environment. It essentially maps (state, action) pairs
#    to their (next_state, reward) result, with the state being the
#    environment observation vector described above.
# -  ``ReplayMemory`` - a cyclic buffer of bounded size that holds the
#    transitions observed recently. It also implements a ``.sample()``
#    method for selecting a random batch of transitions for training.
#

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
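
######################################################################
# To build intuition for how the buffer behaves, here is a minimal,
# illustrative sketch (using throwaway placeholder tensors rather than
# real environment data) of pushing transitions and sampling a batch:

_demo_memory = ReplayMemory(capacity=5)
for i in range(8):
    # With maxlen=5, the oldest transitions are silently discarded.
    _demo_memory.push(torch.tensor([[float(i)]]), torch.tensor([[0]]),
                      torch.tensor([[float(i + 1)]]), torch.tensor([1.0]))
print(len(_demo_memory))            # 5 - the buffer is bounded
print(len(_demo_memory.sample(3)))  # 3 random, decorrelated transitions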

######################################################################
# Now, let's define our model. But first, let's quickly recap what a DQN is.
#
# DQN algorithm
# -------------
#
# Our environment is deterministic, so all equations presented here are
# also formulated deterministically for the sake of simplicity. In the
# reinforcement learning literature, they would also contain expectations
# over stochastic transitions in the environment.
#
# Our aim will be to train a policy that tries to maximize the discounted,
# cumulative reward
# :math:`R_{t_0} = \sum_{t=t_0}^{\infty} \gamma^{t - t_0} r_t`, where
# :math:`R_{t_0}` is also known as the *return*. The discount,
# :math:`\gamma`, should be a constant between :math:`0` and :math:`1`
# that ensures the sum converges. A lower :math:`\gamma` makes
# rewards from the uncertain far future less important for our agent
# than the ones in the near future that it can be fairly confident
# about. It also encourages agents to collect reward closer in time
# than equivalent rewards that are temporally far away in the future.
#
# The main idea behind Q-learning is that if we had a function
# :math:`Q^*: State \times Action \rightarrow \mathbb{R}`, that could tell
# us what our return would be, if we were to take an action in a given
# state, then we could easily construct a policy that maximizes our
# rewards:
#
# .. math:: \pi^*(s) = \arg\!\max_a \ Q^*(s, a)
#
# However, we don't know everything about the world, so we don't have
# access to :math:`Q^*`. But, since neural networks are universal function
# approximators, we can simply create one and train it to resemble
# :math:`Q^*`.
#
# For our training update rule, we'll use the fact that every :math:`Q`
# function for some policy obeys the Bellman equation:
#
# .. math:: Q^{\pi}(s, a) = r + \gamma Q^{\pi}(s', \pi(s'))
#
# The difference between the two sides of the equality is known as the
# temporal difference error, :math:`\delta`:
#
# .. math:: \delta = Q(s, a) - (r + \gamma \max_{a'} Q(s', a'))
#
# To minimize this error, we will use the `Huber
# loss <https://en.wikipedia.org/wiki/Huber_loss>`__. The Huber loss acts
# like the mean squared error when the error is small, but like the mean
# absolute error when the error is large - this makes it more robust to
# outliers when the estimates of :math:`Q` are very noisy. We calculate
# this over a batch of transitions, :math:`B`, sampled from the replay
# memory:
#
# .. math::
#
#    \mathcal{L} = \frac{1}{|B|}\sum_{(s, a, s', r) \ \in \ B} \mathcal{L}(\delta)
#
# .. math::
#
#    \text{where} \quad \mathcal{L}(\delta) = \begin{cases}
#      \frac{1}{2}{\delta^2}  & \text{for } |\delta| \le 1, \\
#      |\delta| - \frac{1}{2} & \text{otherwise.}
#    \end{cases}
#
# Q-network
# ^^^^^^^^^
#
# Our model will be a feed-forward neural network that takes in the
# current environment state (the 4 observation values described above).
# It has two outputs, representing :math:`Q(s, \mathrm{left})` and
# :math:`Q(s, \mathrm{right})` (where :math:`s` is the input to the
# network). In effect, the network is trying to predict the *expected return* of
# taking each action given the current input.
#

class DQN(nn.Module):

    def __init__(self, n_observations, n_actions):
        super(DQN, self).__init__()
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    # Called with either one element to determine next action, or a batch
    # during optimization. Returns tensor([[left0exp,right0exp]...]).
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)
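
######################################################################
# As a quick, illustrative shape check (using a throwaway network, not
# the one trained below), a batch of CartPole states of shape
# ``(batch, 4)`` should map to Q-value estimates of shape ``(batch, 2)``:

_demo_net = DQN(n_observations=4, n_actions=2)
print(_demo_net(torch.randn(1, 4)).shape)   # torch.Size([1, 2]) - a single state
print(_demo_net(torch.randn(32, 4)).shape)  # torch.Size([32, 2]) - a batch of states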

######################################################################
# Training
# --------
#
# Hyperparameters and utilities
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# This cell instantiates our model and its optimizer, and defines some
# utilities:
#
# -  ``select_action`` - will select an action according to an epsilon
#    greedy policy. Simply put, we'll sometimes use our model for choosing
#    the action, and sometimes we'll just sample one uniformly. The
#    probability of choosing a random action will start at ``EPS_START``
#    and will decay exponentially towards ``EPS_END``. ``EPS_DECAY``
#    controls the rate of the decay.
# -  ``plot_durations`` - a helper for plotting the duration of episodes,
#    along with an average over the last 100 episodes (the measure used in
#    the official evaluations). The plot will be underneath the cell
#    containing the main training loop, and will update after every
#    episode.
#

# BATCH_SIZE is the number of transitions sampled from the replay buffer
# GAMMA is the discount factor as mentioned in the previous section
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# TAU is the update rate of the target network
# LR is the learning rate of the ``AdamW`` optimizer
BATCH_SIZE = 128
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4

# Get the number of actions from the gym action space
n_actions = env.action_space.n
# Get the number of state observations
state, info = env.reset()
n_observations = len(state)

policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)


steps_done = 0


def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            # t.max(1) will return the largest column value of each row.
            # The second element of the max result is the index of where the max
            # element was found, so we pick the action with the larger expected reward.
            return policy_net(state).max(1).indices.view(1, 1)
    else:
        return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)


episode_durations = []


def plot_durations(show_result=False):
    plt.figure(1)
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    if show_result:
        plt.title('Result')
    else:
        plt.clf()
        plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    # Take 100 episode averages and plot them too
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated
    if is_ipython:
        if not show_result:
            display.display(plt.gcf())
            display.clear_output(wait=True)
        else:
            display.display(plt.gcf())
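
######################################################################
# To see how the exploration schedule above behaves, here is a small,
# illustrative calculation of the epsilon threshold at a few values of
# ``steps_done`` (it decays from ``EPS_START`` towards ``EPS_END``):

for demo_steps in (0, 1000, 5000):
    demo_eps = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * demo_steps / EPS_DECAY)
    print(f"steps_done={demo_steps:>5}: eps_threshold={demo_eps:.3f}")
# Roughly 0.900, 0.363, and 0.056 respectively.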

######################################################################
# Training loop
# ^^^^^^^^^^^^^
#
# Finally, the code for training our model.
#
# Here, you can find an ``optimize_model`` function that performs a
# single step of the optimization. It first samples a batch, concatenates
# all the tensors into a single one, computes :math:`Q(s_t, a_t)` and
# :math:`V(s_{t+1}) = \max_a Q(s_{t+1}, a)`, and combines them into our
# loss. By definition we set :math:`V(s) = 0` if :math:`s` is a terminal
# state. We also use a target network to compute :math:`V(s_{t+1})` for
# added stability. The target network is updated at every step with a
# `soft update <https://arxiv.org/pdf/1509.02971.pdf>`__ controlled by
# the hyperparameter ``TAU``, which was previously defined.
#

def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts a batch-array of Transitions
    # to a Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                            batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state
                                       if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1).values
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # In-place gradient clipping
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()


######################################################################
#
# Below, you can find the main training loop. At the beginning we reset
# the environment and obtain the initial ``state`` Tensor. Then, we sample
# an action, execute it, observe the next state and the reward (always
# 1), and optimize our model once. When the episode ends (our model
# fails), we restart the loop.
#
# Below, ``num_episodes`` is set to 600 if a GPU is available, otherwise 50
# episodes are scheduled so training does not take too long. However, 50
# episodes is insufficient to observe good performance on CartPole.
# You should see the model consistently achieve 500 steps within 600 training
# episodes. Training RL agents can be a noisy process, so restarting training
# can produce better results if convergence is not observed.
#

if torch.cuda.is_available() or torch.backends.mps.is_available():
    num_episodes = 600
else:
    num_episodes = 50

for i_episode in range(num_episodes):
    # Initialize the environment and get its state
    state, info = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    for t in count():
        action = select_action(state)
        observation, reward, terminated, truncated, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        done = terminated or truncated

        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

        # Store the transition in memory
        memory.push(state, action, next_state, reward)

        # Move to the next state
        state = next_state

        # Perform one step of the optimization (on the policy network)
        optimize_model()

        # Soft update of the target network's weights
        # θ′ ← τ θ + (1 −τ )θ′
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
        target_net.load_state_dict(target_net_state_dict)

        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break

print('Complete')
plot_durations(show_result=True)
plt.ioff()
plt.show()
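
######################################################################
# If you would like to keep the trained agent around, one option is to
# save the policy network's weights (the file name here is just an
# illustrative choice):

torch.save(policy_net.state_dict(), "dqn_cartpole_policy_net.pt")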

######################################################################
# Here is the diagram that illustrates the overall resulting data flow.
#
# .. figure:: /_static/img/reinforcement_learning_diagram.jpg
#
# Actions are chosen either randomly or based on a policy, getting the next
# step sample from the gym environment. We record the results in the
# replay memory and also run an optimization step on every iteration.
# Optimization picks a random batch from the replay memory to train the
# new policy. The "older" target_net is also used in optimization to compute the
# expected Q values. A soft update of its weights is performed at every step.
#
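
######################################################################
# As an optional, illustrative follow-up, you can roll out the trained
# policy greedily (no exploration) for a single episode and check how
# long it keeps the pole up - a well-trained agent should reach the
# 500-step cap of CartPole-v1:

eval_env = gym.make("CartPole-v1")
eval_state, _ = eval_env.reset()
eval_state = torch.tensor(eval_state, dtype=torch.float32, device=device).unsqueeze(0)
for eval_t in count():
    with torch.no_grad():
        # Always pick the action with the highest predicted Q value.
        eval_action = policy_net(eval_state).max(1).indices.view(1, 1)
    eval_obs, _, eval_terminated, eval_truncated, _ = eval_env.step(eval_action.item())
    if eval_terminated or eval_truncated:
        print(f"Greedy evaluation episode lasted {eval_t + 1} steps")
        break
    eval_state = torch.tensor(eval_obs, dtype=torch.float32, device=device).unsqueeze(0)
eval_env.close()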