GitHub Repository: pytorch/tutorials
Path: blob/main/intermediate_source/reinforcement_q_learning.py
# -*- coding: utf-8 -*-
"""
Reinforcement Learning (DQN) Tutorial
=====================================
**Author**: `Adam Paszke <https://github.com/apaszke>`_
            `Mark Towers <https://github.com/pseudo-rnd-thoughts>`_


This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent
on the CartPole-v1 task from `Gymnasium <https://gymnasium.farama.org>`__.

You might find it helpful to read the original `Deep Q Learning (DQN) <https://arxiv.org/abs/1312.5602>`__ paper.

**Task**

The agent has to decide between two actions - moving the cart left or
right - so that the pole attached to it stays upright. You can find more
information about the environment and other more challenging environments at
`Gymnasium's website <https://gymnasium.farama.org/environments/classic_control/cart_pole/>`__.

.. figure:: /_static/img/cartpole.gif
   :alt: CartPole

   CartPole

As the agent observes the current state of the environment and chooses
an action, the environment *transitions* to a new state, and also
returns a reward that indicates the consequences of the action. In this
task, rewards are +1 for every incremental timestep, and the environment
terminates if the pole falls over too far or the cart moves more than 2.4
units away from center. This means better performing scenarios will run
for a longer duration, accumulating a larger return.

The CartPole task is designed so that the inputs to the agent are 4 real
values representing the environment state (position, velocity, etc.).
We take these 4 inputs without any scaling and pass them through a
small fully-connected network with 2 outputs, one for each action.
The network is trained to predict the expected value for each action,
given the input state. The action with the highest expected value is
then chosen.


**Packages**


First, let's import the needed packages. We need
`gymnasium <https://gymnasium.farama.org/>`__ for the environment,
installed via ``pip``. Gymnasium is a fork of the original OpenAI
Gym project and has been maintained by the same team since Gym v0.19.
If you are running this in Google Colab, run:

.. code-block:: bash

   %%bash
   pip3 install gymnasium[classic_control]

We'll also use the following from PyTorch:

-  neural networks (``torch.nn``)
-  optimization (``torch.optim``)
-  automatic differentiation (``torch.autograd``)

"""

import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

env = gym.make("CartPole-v1")

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

# if GPU is to be used
device = torch.device(
    "cuda" if torch.cuda.is_available() else
    "mps" if torch.backends.mps.is_available() else
    "cpu"
)

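######################################################################
# As a quick, purely illustrative check of the Gymnasium API (not needed
# for the rest of the tutorial), ``env.reset()`` returns the initial
# observation together with an info dict, and ``env.step(action)`` returns
# the next observation, the reward, and the ``terminated`` / ``truncated``
# flags that the training loop below relies on:
#
# .. code-block:: python
#
#    obs, info = env.reset()
#    print(obs.shape)        # (4,) for CartPole-v1
#    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
#    print(reward, terminated, truncated)   # typically 1.0 False False
#
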
######################################################################
# Replay Memory
# -------------
#
# We'll be using experience replay memory for training our DQN. It stores
# the transitions that the agent observes, allowing us to reuse this data
# later. By sampling from it randomly, the transitions that build up a
# batch are decorrelated. It has been shown that this greatly stabilizes
# and improves the DQN training procedure.
#
# For this, we're going to need two classes:
#
# -  ``Transition`` - a named tuple representing a single transition in
#    our environment. It essentially maps (state, action) pairs
#    to their (next_state, reward) result, with the state being the
#    observation returned by the environment, as described later on.
# -  ``ReplayMemory`` - a cyclic buffer of bounded size that holds the
#    transitions observed recently. It also implements a ``.sample()``
#    method for selecting a random batch of transitions for training.
#

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

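######################################################################
# As a small, hypothetical illustration (not used by the training code
# below), pushing a few toy transitions and sampling a mini-batch from the
# buffer looks like this:
#
# .. code-block:: python
#
#    demo_memory = ReplayMemory(100)
#    for _ in range(5):
#        s = torch.zeros(1, 4)           # toy "state"
#        a = torch.tensor([[0]])         # toy "action"
#        demo_memory.push(s, a, s, torch.tensor([1.0]))
#    batch = demo_memory.sample(3)       # list of 3 Transition namedtuples
#
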
######################################################################
# Now, let's define our model. But first, let's quickly recap what a DQN is.
#
# DQN algorithm
# -------------
#
# Our environment is deterministic, so all equations presented here are
# also formulated deterministically for the sake of simplicity. In the
# reinforcement learning literature, they would also contain expectations
# over stochastic transitions in the environment.
#
# Our aim will be to train a policy that tries to maximize the discounted,
# cumulative reward
# :math:`R_{t_0} = \sum_{t=t_0}^{\infty} \gamma^{t - t_0} r_t`, where
# :math:`R_{t_0}` is also known as the *return*. The discount,
# :math:`\gamma`, should be a constant between :math:`0` and :math:`1`
# that ensures the sum converges. A lower :math:`\gamma` makes
# rewards from the uncertain far future less important for our agent
# than the ones in the near future that it can be fairly confident
# about. It also encourages agents to collect reward closer in time
# than equivalent rewards that are temporally far away in the future.
#
# The main idea behind Q-learning is that if we had a function
# :math:`Q^*: State \times Action \rightarrow \mathbb{R}`, that could tell
# us what our return would be, if we were to take an action in a given
# state, then we could easily construct a policy that maximizes our
# rewards:
#
# .. math:: \pi^*(s) = \arg\!\max_a \ Q^*(s, a)
#
# However, we don't know everything about the world, so we don't have
# access to :math:`Q^*`. But, since neural networks are universal function
# approximators, we can simply create one and train it to resemble
# :math:`Q^*`.
#
# For our training update rule, we'll use the fact that every :math:`Q`
# function for some policy obeys the Bellman equation:
#
# .. math:: Q^{\pi}(s, a) = r + \gamma Q^{\pi}(s', \pi(s'))
#
# The difference between the two sides of the equality is known as the
# temporal difference error, :math:`\delta`:
#
# .. math:: \delta = Q(s, a) - (r + \gamma \max_{a'} Q(s', a'))
#
# To minimize this error, we will use the `Huber
# loss <https://en.wikipedia.org/wiki/Huber_loss>`__. The Huber loss acts
# like the mean squared error when the error is small, but like the mean
# absolute error when the error is large - this makes it more robust to
# outliers when the estimates of :math:`Q` are very noisy. We calculate
# this over a batch of transitions, :math:`B`, sampled from the replay
# memory:
#
# .. math::
#
#    \mathcal{L} = \frac{1}{|B|}\sum_{(s, a, s', r) \ \in \ B} \mathcal{L}(\delta)
#
# .. math::
#
#    \text{where} \quad \mathcal{L}(\delta) = \begin{cases}
#      \frac{1}{2}{\delta^2}  & \text{for } |\delta| \le 1, \\
#      |\delta| - \frac{1}{2} & \text{otherwise.}
#    \end{cases}
#
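# As a quick worked example of the piecewise definition above, an error of
# :math:`\delta = 0.5` gives :math:`\mathcal{L}(\delta) = \tfrac{1}{2}(0.5)^2 = 0.125`
# (quadratic regime), while :math:`\delta = 3` gives
# :math:`\mathcal{L}(\delta) = 3 - \tfrac{1}{2} = 2.5` (linear regime), so large,
# noisy errors contribute far less than they would under a squared loss. In
# PyTorch this loss is available as ``nn.SmoothL1Loss`` (with its default
# ``beta=1.0``), which is what ``optimize_model`` uses below.
#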
# Q-network
# ^^^^^^^^^
#
# Our model will be a feed-forward neural network that takes in the
# current environment state (the 4 observed values). It has two
# outputs, representing :math:`Q(s, \mathrm{left})` and
# :math:`Q(s, \mathrm{right})` (where :math:`s` is the input to the
# network). In effect, the network is trying to predict the *expected return* of
# taking each action given the current input.
#

class DQN(nn.Module):

    def __init__(self, n_observations, n_actions):
        super(DQN, self).__init__()
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    # Called with either one element to determine next action, or a batch
    # during optimization. Returns tensor([[left0exp,right0exp]...]).
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)


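######################################################################
# As a purely illustrative shape check (the names here are hypothetical and
# not used below), a freshly constructed network maps a batch of states to
# one Q-value per action:
#
# .. code-block:: python
#
#    demo_net = DQN(n_observations=4, n_actions=2)
#    demo_net(torch.zeros(1, 4)).shape   # torch.Size([1, 2])
#
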
######################################################################
# Training
# --------
#
# Hyperparameters and utilities
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# This cell instantiates our model and its optimizer, and defines some
# utilities:
#
# -  ``select_action`` - will select an action according to an
#    epsilon-greedy policy. Simply put, we'll sometimes use our model for
#    choosing the action, and sometimes we'll just sample one uniformly.
#    The probability of choosing a random action will start at
#    ``EPS_START`` and will decay exponentially towards ``EPS_END``.
#    ``EPS_DECAY`` controls the rate of the decay.
# -  ``plot_durations`` - a helper for plotting the duration of episodes,
#    along with an average over the last 100 episodes (the measure used in
#    the official evaluations). The plot will be underneath the cell
#    containing the main training loop, and will update after every
#    episode.
#

# BATCH_SIZE is the number of transitions sampled from the replay buffer
# GAMMA is the discount factor as mentioned in the previous section
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# TAU is the update rate of the target network
# LR is the learning rate of the ``AdamW`` optimizer
BATCH_SIZE = 128
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4

# Get number of actions from gym action space
n_actions = env.action_space.n
# Get the number of state observations
state, info = env.reset()
n_observations = len(state)

policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)


steps_done = 0


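######################################################################
# For CartPole-v1 this gives ``n_actions = 2`` and ``n_observations = 4``,
# so ``policy_net`` and ``target_net`` are small 4 -> 128 -> 128 -> 2
# networks with roughly 17k parameters each
# (4*128 + 128 + 128*128 + 128 + 128*2 + 2 = 17,410).
#
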
def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            # t.max(1) will return the largest column value of each row, and
            # .indices gives the index of where that max element was found,
            # so we pick the action with the larger expected reward.
            return policy_net(state).max(1).indices.view(1, 1)
    else:
        return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)


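######################################################################
# With the hyperparameters above, the exploration threshold
# :math:`\varepsilon = \mathrm{EPS\_END} + (\mathrm{EPS\_START} - \mathrm{EPS\_END}) \, e^{-\mathrm{steps\_done}/\mathrm{EPS\_DECAY}}`
# starts at 0.9 and decays to roughly 0.36 after 1,000 steps, about 0.17
# after 2,000 steps, and about 0.06 after 5,000 steps, approaching
# ``EPS_END``.
#
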
episode_durations = []


def plot_durations(show_result=False):
    plt.figure(1)
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    if show_result:
        plt.title('Result')
    else:
        plt.clf()
        plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    # Take 100 episode averages and plot them too
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated
    if is_ipython:
        if not show_result:
            display.display(plt.gcf())
            display.clear_output(wait=True)
        else:
            display.display(plt.gcf())


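######################################################################
# The 100-episode moving average above relies on ``Tensor.unfold``: for
# ``N`` recorded episodes (``N >= 100``), ``unfold(0, 100, 1)`` produces
# ``N - 99`` overlapping windows of length 100, whose means are left-padded
# with 99 zeros so the curve lines up with the episode index. For example:
#
# .. code-block:: python
#
#    # 150 recorded episodes -> 51 overlapping windows of length 100
#    torch.arange(150.).unfold(0, 100, 1).shape   # torch.Size([51, 100])
#
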
######################################################################
# Training loop
# ^^^^^^^^^^^^^
#
# Finally, the code for training our model.
#
# Here, you can find an ``optimize_model`` function that performs a
# single step of the optimization. It first samples a batch, concatenates
# all the tensors into a single one, computes :math:`Q(s_t, a_t)` and
# :math:`V(s_{t+1}) = \max_a Q(s_{t+1}, a)`, and combines them into our
# loss. By definition we set :math:`V(s) = 0` if :math:`s` is a terminal
# state. We also use a target network to compute :math:`V(s_{t+1})` for
# added stability. The target network is updated at every step with a
# `soft update <https://arxiv.org/pdf/1509.02971.pdf>`__ controlled by
# the hyperparameter ``TAU``, which was previously defined.
#
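# Concretely, the soft update keeps an exponential moving average of the
# policy weights, :math:`\theta' \leftarrow \tau \theta + (1 - \tau)\theta'`.
# With ``TAU = 0.005``, each policy weight leaks into the target network at
# a rate of 0.5% per optimization step, so the target network trails the
# policy network by roughly a couple of hundred steps - this lag is what
# provides the added stability.
#
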
def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                            batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state
                                       if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1).values
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # In-place gradient clipping
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()


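######################################################################
# To make the tensor bookkeeping in ``optimize_model`` concrete, the shapes
# below assume the default ``BATCH_SIZE = 128`` and CartPole's 4-dimensional
# state (illustrative only):
#
# .. code-block:: python
#
#    # state_batch:                        (128, 4)
#    # policy_net(state_batch):            (128, 2)
#    # state_action_values (after gather): (128, 1)
#    # next_state_values:                  (128,)
#    # expected_state_action_values:       (128,)  -> unsqueeze(1) -> (128, 1)
#
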
######################################################################
#
# Below, you can find the main training loop. At the beginning we reset
# the environment and obtain the initial ``state`` Tensor. Then, we sample
# an action, execute it, observe the next state and the reward (always
# 1), and optimize our model once. When the episode ends (our model
# fails), we restart the loop.
#
# Below, ``num_episodes`` is set to 600 if a GPU is available, otherwise 50
# episodes are scheduled so training does not take too long. However, 50
# episodes is insufficient to observe good performance on CartPole.
# You should see the model consistently achieve 500 steps within 600 training
# episodes. Training RL agents can be a noisy process, so restarting training
# can produce better results if convergence is not observed.
#

if torch.cuda.is_available() or torch.backends.mps.is_available():
    num_episodes = 600
else:
    num_episodes = 50

for i_episode in range(num_episodes):
    # Initialize the environment and get its state
    state, info = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    for t in count():
        action = select_action(state)
        observation, reward, terminated, truncated, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        done = terminated or truncated

        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

        # Store the transition in memory
        memory.push(state, action, next_state, reward)

        # Move to the next state
        state = next_state

        # Perform one step of the optimization (on the policy network)
        optimize_model()

        # Soft update of the target network's weights
        # θ′ ← τ θ + (1 − τ)θ′
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
        target_net.load_state_dict(target_net_state_dict)

        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break

print('Complete')
plot_durations(show_result=True)
plt.ioff()
plt.show()

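######################################################################
# As an optional follow-up (not part of the tutorial itself), you can watch
# the trained policy by creating a rendered environment and acting greedily
# with ``policy_net``. The ``eval_env`` name below is just an illustration:
#
# .. code-block:: python
#
#    eval_env = gym.make("CartPole-v1", render_mode="human")
#    obs, info = eval_env.reset()
#    done = False
#    while not done:
#        s = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
#        with torch.no_grad():
#            a = policy_net(s).max(1).indices.item()
#        obs, reward, terminated, truncated, info = eval_env.step(a)
#        done = terminated or truncated
#    eval_env.close()
#
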
######################################################################
# Here is a diagram that illustrates the overall resulting data flow.
#
# .. figure:: /_static/img/reinforcement_learning_diagram.jpg
#
# Actions are chosen either randomly or based on a policy, getting the next
# step sample from the gym environment. We record the results in the
# replay memory and also run an optimization step on every iteration.
# Optimization picks a random batch from the replay memory to train the
# new policy. The "older" target_net is also used in optimization to compute the
# expected Q values. A soft update of its weights is performed at every step.
#