"""
Title: Deep Deterministic Policy Gradient (DDPG)
Author: [amifunny](https://github.com/amifunny)
Date created: 2020/06/04
Last modified: 2024/03/23
Description: Implementing the DDPG algorithm on the Inverted Pendulum problem.
Accelerator: None
"""

"""
## Introduction

**Deep Deterministic Policy Gradient (DDPG)** is a model-free off-policy algorithm for
learning continuous actions.

It combines ideas from DPG (Deterministic Policy Gradient) and DQN (Deep Q-Network).
It uses Experience Replay and slow-learning target networks from DQN, and it is based on
DPG, which can operate over continuous action spaces.

This tutorial closely follows this paper:
[Continuous control with deep reinforcement learning](https://arxiv.org/abs/1509.02971)

## Problem

We are trying to solve the classic **Inverted Pendulum** control problem.
In this setting, we can only apply a torque that swings the pendulum left or right.

What makes this problem challenging for Q-learning algorithms is that the actions
are **continuous** instead of **discrete**. That is, instead of choosing between two
discrete actions like `-1` or `+1`, we have to select from infinitely many actions
ranging from `-2` to `+2`.

## Quick theory

Just like the Actor-Critic method, we have two networks:

1. Actor - It proposes an action given a state.
2. Critic - It predicts if the action is good (positive value) or bad (negative value)
given a state and an action.

DDPG uses two more techniques not present in the original DQN:

**First, it uses two Target networks.**

**Why?** Because this adds stability to training. In short, we are learning from estimated
targets, and the Target networks are updated slowly, which keeps our estimated targets
stable.

Conceptually, this is like saying, "I have an idea of how to play this well,
I'm going to try it out for a bit until I find something better",
as opposed to saying "I'm going to re-learn how to play this entire game after every
move".
See this [StackOverflow answer](https://stackoverflow.com/a/54238556/13475679).

**Second, it uses Experience Replay.**

We store a list of tuples `(state, action, reward, next_state)`, and instead of
learning only from recent experience, we learn by sampling from all of the experience
accumulated so far.

Now, let's see how it is implemented.
"""
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
from keras import layers

import tensorflow as tf
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt

"""
76
We use [Gymnasium](https://gymnasium.farama.org/) to create the environment.
77
We will use the `upper_bound` parameter to scale our actions later.
78
"""

# Specify the `render_mode` parameter to show the attempts of the agent in a pop up window.
env = gym.make("Pendulum-v1")  # , render_mode="human")

num_states = env.observation_space.shape[0]
print("Size of State Space -> {}".format(num_states))
num_actions = env.action_space.shape[0]
print("Size of Action Space -> {}".format(num_actions))

upper_bound = env.action_space.high[0]
lower_bound = env.action_space.low[0]

print("Max Value of Action -> {}".format(upper_bound))
print("Min Value of Action -> {}".format(lower_bound))

"""
To implement better exploration by the Actor network, we add noisy perturbations to its
actions, specifically using an **Ornstein-Uhlenbeck process** to generate the noise, as
described in the paper. It samples noise from a correlated normal distribution.
"""


class OUActionNoise:
    def __init__(self, mean, std_deviation, theta=0.15, dt=1e-2, x_initial=None):
        self.theta = theta
        self.mean = mean
        self.std_dev = std_deviation
        self.dt = dt
        self.x_initial = x_initial
        self.reset()

    def __call__(self):
        # Formula taken from https://www.wikipedia.org/wiki/Ornstein-Uhlenbeck_process
        x = (
            self.x_prev
            + self.theta * (self.mean - self.x_prev) * self.dt
            + self.std_dev * np.sqrt(self.dt) * np.random.normal(size=self.mean.shape)
        )
        # Store x into x_prev
        # Makes next noise dependent on current one
        self.x_prev = x
        return x

    def reset(self):
        if self.x_initial is not None:
            self.x_prev = self.x_initial
        else:
            self.x_prev = np.zeros_like(self.mean)


"""
The `Buffer` class implements Experience Replay.

---
![Algorithm](https://i.imgur.com/mS6iGyJ.jpg)
---

**Critic loss** - Mean Squared Error of `y - Q(s, a)`,
where `y` is the expected return as seen by the Target network,
and `Q(s, a)` is the action value predicted by the Critic network. `y` is a moving target
that the critic model tries to achieve; we make this target
stable by updating the Target model slowly.

**Actor loss** - This is computed as the mean of the values given by the Critic network
for the actions taken by the Actor network. We seek to maximize this quantity.

Hence we update the Actor network so that it produces actions that get
the maximum predicted value, as seen by the Critic, for a given state.
"""


class Buffer:
    def __init__(self, buffer_capacity=100000, batch_size=64):
        # Number of "experiences" to store at max
        self.buffer_capacity = buffer_capacity
        # Num of tuples to train on.
        self.batch_size = batch_size

        # This tells us the number of times record() was called.
        self.buffer_counter = 0

        # Instead of a list of tuples as in the classic Experience Replay formulation,
        # we use separate np.arrays for each tuple element.
        self.state_buffer = np.zeros((self.buffer_capacity, num_states))
        self.action_buffer = np.zeros((self.buffer_capacity, num_actions))
        self.reward_buffer = np.zeros((self.buffer_capacity, 1))
        self.next_state_buffer = np.zeros((self.buffer_capacity, num_states))

    # Takes (s,a,r,s') observation tuple as input
    def record(self, obs_tuple):
        # Set index to zero if buffer_capacity is exceeded,
        # replacing old records
        index = self.buffer_counter % self.buffer_capacity

        self.state_buffer[index] = obs_tuple[0]
        self.action_buffer[index] = obs_tuple[1]
        self.reward_buffer[index] = obs_tuple[2]
        self.next_state_buffer[index] = obs_tuple[3]

        self.buffer_counter += 1

    # Eager execution is turned on by default in TensorFlow 2. Decorating with tf.function allows
    # TensorFlow to build a static graph out of the logic and computations in our function.
    # This provides a large speed up for blocks of code that contain many small TensorFlow operations such as this one.
    @tf.function
    def update(
        self,
        state_batch,
        action_batch,
        reward_batch,
        next_state_batch,
    ):
        # Training and updating Actor & Critic networks.
        # See Pseudo Code.
        with tf.GradientTape() as tape:
            target_actions = target_actor(next_state_batch, training=True)
            y = reward_batch + gamma * target_critic(
                [next_state_batch, target_actions], training=True
            )
            critic_value = critic_model([state_batch, action_batch], training=True)
            critic_loss = keras.ops.mean(keras.ops.square(y - critic_value))

        critic_grad = tape.gradient(critic_loss, critic_model.trainable_variables)
        critic_optimizer.apply_gradients(
            zip(critic_grad, critic_model.trainable_variables)
        )

        with tf.GradientTape() as tape:
            actions = actor_model(state_batch, training=True)
            critic_value = critic_model([state_batch, actions], training=True)
            # Used `-value` as we want to maximize the value given
            # by the critic for our actions
            actor_loss = -keras.ops.mean(critic_value)

        actor_grad = tape.gradient(actor_loss, actor_model.trainable_variables)
        actor_optimizer.apply_gradients(
            zip(actor_grad, actor_model.trainable_variables)
        )

    # We compute the loss and update parameters
    def learn(self):
        # Get sampling range
        record_range = min(self.buffer_counter, self.buffer_capacity)
        # Randomly sample indices
        batch_indices = np.random.choice(record_range, self.batch_size)

        # Convert to tensors
        state_batch = keras.ops.convert_to_tensor(self.state_buffer[batch_indices])
        action_batch = keras.ops.convert_to_tensor(self.action_buffer[batch_indices])
        reward_batch = keras.ops.convert_to_tensor(self.reward_buffer[batch_indices])
        reward_batch = keras.ops.cast(reward_batch, dtype="float32")
        next_state_batch = keras.ops.convert_to_tensor(
            self.next_state_buffer[batch_indices]
        )

        self.update(state_batch, action_batch, reward_batch, next_state_batch)


# This updates the target network parameters slowly,
# based on the rate `tau`, which is much less than one.
def update_target(target, original, tau):
    target_weights = target.get_weights()
    original_weights = original.get_weights()

    for i in range(len(target_weights)):
        target_weights[i] = original_weights[i] * tau + target_weights[i] * (1 - tau)

    target.set_weights(target_weights)


"""
Here we define the Actor and Critic networks. These are basic Dense models
with `ReLU` activation.

Note: We initialize the weights of the last layer of the Actor between
`-0.003` and `0.003`, as this prevents us from getting `1` or `-1` output values in
the initial stages, which would squash our gradients to zero,
since we use the `tanh` activation.
"""


def get_actor():
    # Initialize weights between -3e-3 and 3e-3
    last_init = keras.initializers.RandomUniform(minval=-0.003, maxval=0.003)

    inputs = layers.Input(shape=(num_states,))
    out = layers.Dense(256, activation="relu")(inputs)
    out = layers.Dense(256, activation="relu")(out)
    outputs = layers.Dense(1, activation="tanh", kernel_initializer=last_init)(out)

    # Our upper bound is 2.0 for Pendulum.
    outputs = outputs * upper_bound
    model = keras.Model(inputs, outputs)
    return model


def get_critic():
    # State as input
    state_input = layers.Input(shape=(num_states,))
    state_out = layers.Dense(16, activation="relu")(state_input)
    state_out = layers.Dense(32, activation="relu")(state_out)

    # Action as input
    action_input = layers.Input(shape=(num_actions,))
    action_out = layers.Dense(32, activation="relu")(action_input)

    # Both are passed through separate layers before concatenating
    concat = layers.Concatenate()([state_out, action_out])

    out = layers.Dense(256, activation="relu")(concat)
    out = layers.Dense(256, activation="relu")(out)
    outputs = layers.Dense(1)(out)

    # Outputs a single value for a given state-action pair
    model = keras.Model([state_input, action_input], outputs)

    return model


"""
`policy()` returns an action sampled from our Actor network plus some noise for
exploration.
"""


def policy(state, noise_object):
    sampled_actions = keras.ops.squeeze(actor_model(state))
    noise = noise_object()
    # Adding noise to action
    sampled_actions = sampled_actions.numpy() + noise

    # We make sure action is within bounds
    legal_action = np.clip(sampled_actions, lower_bound, upper_bound)

    return [np.squeeze(legal_action)]


"""
## Training hyperparameters
"""

std_dev = 0.2
ou_noise = OUActionNoise(mean=np.zeros(1), std_deviation=float(std_dev) * np.ones(1))

actor_model = get_actor()
critic_model = get_critic()

target_actor = get_actor()
target_critic = get_critic()

# Making the weights equal initially
target_actor.set_weights(actor_model.get_weights())
target_critic.set_weights(critic_model.get_weights())

# Learning rates for the actor and critic models
critic_lr = 0.002
actor_lr = 0.001

critic_optimizer = keras.optimizers.Adam(critic_lr)
actor_optimizer = keras.optimizers.Adam(actor_lr)

total_episodes = 100
# Discount factor for future rewards
gamma = 0.99
# Used to update target networks
tau = 0.005

buffer = Buffer(50000, 64)

"""
Now we implement our main training loop, and iterate over episodes.
We sample actions using `policy()` and train with `learn()` at each time step,
along with updating the Target networks at a rate `tau`.
"""

# To store reward history of each episode
ep_reward_list = []
# To store average reward history of last few episodes
avg_reward_list = []

# Takes about 4 min to train
for ep in range(total_episodes):
    prev_state, _ = env.reset()
    episodic_reward = 0

    while True:
        tf_prev_state = keras.ops.expand_dims(
            keras.ops.convert_to_tensor(prev_state), 0
        )

        action = policy(tf_prev_state, ou_noise)
        # Receive state and reward from environment.
        state, reward, done, truncated, _ = env.step(action)

        buffer.record((prev_state, action, reward, state))
        episodic_reward += reward

        buffer.learn()

        update_target(target_actor, actor_model, tau)
        update_target(target_critic, critic_model, tau)

        # End this episode when `done` or `truncated` is True
        if done or truncated:
            break

        prev_state = state

    ep_reward_list.append(episodic_reward)

    # Mean of last 40 episodes
    avg_reward = np.mean(ep_reward_list[-40:])
    print("Episode * {} * Avg Reward is ==> {}".format(ep, avg_reward))
    avg_reward_list.append(avg_reward)

# Plotting graph
# Episodes versus Avg. Rewards
plt.plot(avg_reward_list)
plt.xlabel("Episode")
plt.ylabel("Avg. Episodic Reward")
plt.show()
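
"""
As an optional sanity check, we can also roll out the current Actor *without* exploration
noise to see how it behaves deterministically. This is a minimal sketch; the helper
`evaluate_policy` and the episode count are illustrative choices, not part of the DDPG
algorithm itself.
"""


def evaluate_policy(num_eval_episodes=5):
    # Run the actor greedily (no OU noise) and return the average episodic reward.
    total_reward = 0.0
    for _ in range(num_eval_episodes):
        state, _ = env.reset()
        done = truncated = False
        while not (done or truncated):
            tf_state = keras.ops.expand_dims(keras.ops.convert_to_tensor(state), 0)
            action = keras.ops.squeeze(actor_model(tf_state)).numpy()
            # Clip to the legal action range, just like `policy()` does.
            action = np.clip(action, lower_bound, upper_bound)
            state, reward, done, truncated, _ = env.step([action])
            total_reward += reward
    return total_reward / num_eval_episodes


print("Average evaluation reward (no exploration noise): {}".format(evaluate_policy()))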

"""
If training proceeds correctly, the average episodic reward will increase with time.

Feel free to try different learning rates, `tau` values, and architectures for the
Actor and Critic networks.

The Inverted Pendulum problem has low complexity, but DDPG works great on many other
problems.

Another great environment to try this on is the continuous version of `LunarLander-v2`,
but it will take more episodes to obtain good results; a minimal setup sketch follows
below.
"""

# Save the weights
actor_model.save_weights("pendulum_actor.weights.h5")
critic_model.save_weights("pendulum_critic.weights.h5")

target_actor.save_weights("pendulum_target_actor.weights.h5")
target_critic.save_weights("pendulum_target_critic.weights.h5")
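
"""
If you want to reuse the trained policy later, rebuild the architectures and load the saved
weights back in. This is a minimal sketch; it assumes the `.weights.h5` files written above
are present in the working directory.
"""

reloaded_actor = get_actor()
reloaded_actor.load_weights("pendulum_actor.weights.h5")

reloaded_critic = get_critic()
reloaded_critic.load_weights("pendulum_critic.weights.h5")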

"""
Before Training:

![before_img](https://i.imgur.com/ox6b9rC.gif)
"""

"""
After 100 episodes:

![after_img](https://i.imgur.com/eEH8Cz6.gif)
"""