GitHub Repository: keras-team/keras-io
Path: blob/master/examples/rl/deep_q_network_breakout.py
"""
Title: Deep Q-Learning for Atari Breakout
Author: [Jacob Chapman](https://twitter.com/jacoblchapman) and [Mathias Lechner](https://twitter.com/MLech20)
Date created: 2020/05/23
Last modified: 2024/03/17
Description: Play Atari Breakout with a Deep Q-Network.
Accelerator: None
"""

"""
## Introduction

This script shows an implementation of Deep Q-Learning on the
`BreakoutNoFrameskip-v4` environment.

### Deep Q-Learning

As an agent takes actions and moves through an environment, it learns to map
the observed state of the environment to an action. An agent chooses an action
in a given state based on its "Q-value", an estimate of the total discounted
reward the agent expects to collect by taking that action and acting optimally
afterwards. A Q-Learning agent learns to perform its task so that the
recommended action maximizes the potential future rewards. This method is
considered "off-policy": the Q-values are updated assuming that the best action
was chosen, even if the best action was not the one actually taken.
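
For reference, here is a minimal tabular sketch of the off-policy Q-Learning
update that this example approximates with a neural network. The names below
(`q_table`, `lr`, and so on) are illustrative only and are not used elsewhere
in this script:

```python
import numpy as np

# Toy tabular Q-Learning update; the deep variant below replaces the table
# with a convolutional network over stacked frames.
n_states, n_actions, gamma, lr = 5, 4, 0.99, 0.1
q_table = np.zeros((n_states, n_actions))
state, action, reward, next_state = 0, 2, 1.0, 1

# Off-policy Bellman target: assume the best action is taken in the next state,
# regardless of which action the behaviour policy would actually pick.
td_target = reward + gamma * q_table[next_state].max()
q_table[state, action] += lr * (td_target - q_table[state, action])
```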

### Atari Breakout

In this environment, a paddle moves along the bottom of the screen, returning a
ball that destroys blocks at the top of the screen. The aim of the game is to
remove all blocks and break out of the level. The agent must learn to control
the paddle by moving left and right, returning the ball and removing all the
blocks without letting the ball pass the paddle.

### Note

The Deepmind paper trained for "a total of 50 million frames (that is, around
38 days of game experience in total)". However, this script will give good
results at around 10 million frames, which can be processed in less than 24
hours on a modern machine.

You can control the number of episodes by setting the `max_episodes` variable
to a value greater than 0.

### References

- [Q-Learning](https://link.springer.com/content/pdf/10.1007/BF00992698.pdf)
- [Deep Q-Learning](https://www.semanticscholar.org/paper/Human-level-control-through-deep-reinforcement-Mnih-Kavukcuoglu/340f48901f72278f6bf78a04ee5b01df208cc508)
"""
"""
## Setup
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
from keras import layers

import gymnasium as gym
from gymnasium.wrappers import AtariPreprocessing, FrameStack
import numpy as np
import tensorflow as tf

# Configuration parameters for the whole setup
seed = 42
gamma = 0.99  # Discount factor for past rewards
epsilon = 1.0  # Epsilon greedy parameter
epsilon_min = 0.1  # Minimum epsilon greedy parameter
epsilon_max = 1.0  # Maximum epsilon greedy parameter
epsilon_interval = (
    epsilon_max - epsilon_min
)  # Rate at which to reduce chance of random action being taken
batch_size = 32  # Size of batch taken from replay buffer
max_steps_per_episode = 10000
max_episodes = 10  # Limit training episodes, will run until solved if smaller than 1

# Use the Atari environment
# Specify the `render_mode` parameter to show the attempts of the agent in a pop up window.
env = gym.make("BreakoutNoFrameskip-v4")  # , render_mode="human")
# Environment preprocessing
env = AtariPreprocessing(env)
# Stack four frames
env = FrameStack(env, 4)
# Gymnasium environments are seeded through `reset()` rather than a `seed()` method
env.reset(seed=seed)
env.action_space.seed(seed)
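# Optional sanity check (illustrative, not part of the original example):
# `AtariPreprocessing` resizes frames to 84x84 grayscale with a frame skip of 4,
# and `FrameStack` keeps the last four frames, so each observation should have
# shape (4, 84, 84).
# print(env.observation_space.shape)  # expected: (4, 84, 84)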

"""
## Implement the Deep Q-Network

This network learns an approximation of the Q-table, which is a mapping from
the states of the environment to the value of each action the agent can take.
For every state there are four possible actions. The environment provides the
state, and the action is chosen by selecting the largest of the four Q-values
predicted in the output layer.
"""

num_actions = 4


def create_q_model():
    # Network defined by the Deepmind paper
    return keras.Sequential(
        [
            # Observations arrive as stacks of 4 frames in channels-first order;
            # transpose them to channels-last (84, 84, 4) for the convolutions.
            layers.Lambda(
                lambda tensor: keras.ops.transpose(tensor, [0, 2, 3, 1]),
                output_shape=(84, 84, 4),
                input_shape=(4, 84, 84),
            ),
            # Convolutions on the frames on the screen
            layers.Conv2D(32, 8, strides=4, activation="relu"),
            layers.Conv2D(64, 4, strides=2, activation="relu"),
            layers.Conv2D(64, 3, strides=1, activation="relu"),
            layers.Flatten(),
            layers.Dense(512, activation="relu"),
            layers.Dense(num_actions, activation="linear"),
        ]
    )


# The first model makes the predictions for Q-values, which are used to
# take an action.
model = create_q_model()
# Build a target model for the prediction of future rewards.
# The weights of the target model only get updated every 10000 steps, so that
# the target Q-values stay stable while the loss is calculated.
model_target = create_q_model()
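# Optional sanity check (illustrative, not part of the original example): both
# networks map a batch of four stacked 84x84 frames to one Q-value per action.
# dummy_state = keras.ops.zeros((1, 4, 84, 84))
# print(model(dummy_state).shape)  # expected: (1, num_actions)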


"""
## Train
"""
# In the Deepmind paper they use RMSProp, however the Adam optimizer
# improves training time
optimizer = keras.optimizers.Adam(learning_rate=0.00025, clipnorm=1.0)

# Experience replay buffers
action_history = []
state_history = []
state_next_history = []
rewards_history = []
done_history = []
episode_reward_history = []
running_reward = 0
episode_count = 0
frame_count = 0
# Number of frames to take random action and observe output
epsilon_random_frames = 50000
# Number of frames for exploration
epsilon_greedy_frames = 1000000.0
# Maximum replay length
# Note: The Deepmind paper suggests 1000000 however this causes memory issues
max_memory_length = 100000
# Train the model after 4 actions
update_after_actions = 4
# How often to update the target network
update_target_network = 10000
# Using huber loss for stability
loss_function = keras.losses.Huber()

while True:
    observation, _ = env.reset()
    state = np.array(observation)
    episode_reward = 0

    for timestep in range(1, max_steps_per_episode):
        frame_count += 1

        # Use epsilon-greedy for exploration
        if frame_count < epsilon_random_frames or epsilon > np.random.rand(1)[0]:
            # Take random action
            action = np.random.choice(num_actions)
        else:
            # Predict action Q-values
            # From environment state
            state_tensor = keras.ops.convert_to_tensor(state)
            state_tensor = keras.ops.expand_dims(state_tensor, 0)
            action_probs = model(state_tensor, training=False)
            # Take best action
            action = keras.ops.argmax(action_probs[0]).numpy()

        # Decay probability of taking random action
        epsilon -= epsilon_interval / epsilon_greedy_frames
        epsilon = max(epsilon, epsilon_min)

        # Apply the sampled action in our environment
        state_next, reward, done, _, _ = env.step(action)
        state_next = np.array(state_next)

        episode_reward += reward

        # Save actions and states in replay buffer
        action_history.append(action)
        state_history.append(state)
        state_next_history.append(state_next)
        done_history.append(done)
        rewards_history.append(reward)
        state = state_next

        # Update every fourth frame, once the replay history is larger than the batch size
        if frame_count % update_after_actions == 0 and len(done_history) > batch_size:
            # Get indices of samples for replay buffers
            indices = np.random.choice(range(len(done_history)), size=batch_size)

            # Using list comprehension to sample from replay buffer
            state_sample = np.array([state_history[i] for i in indices])
            state_next_sample = np.array([state_next_history[i] for i in indices])
            rewards_sample = [rewards_history[i] for i in indices]
            action_sample = [action_history[i] for i in indices]
            done_sample = keras.ops.convert_to_tensor(
                [float(done_history[i]) for i in indices]
            )

            # Build the updated Q-values for the sampled future states
            # Use the target model for stability
            future_rewards = model_target.predict(state_next_sample)
            # Q value = reward + discount factor * expected future reward
            updated_q_values = rewards_sample + gamma * keras.ops.amax(
                future_rewards, axis=1
            )

            # If the episode ended at this frame, set the target value to -1
            updated_q_values = updated_q_values * (1 - done_sample) - done_sample

            # Create a mask so we only calculate loss on the updated Q-values
            masks = keras.ops.one_hot(action_sample, num_actions)

            with tf.GradientTape() as tape:
                # Train the model on the states and updated Q-values
                q_values = model(state_sample)

                # Apply the masks to the Q-values to get the Q-value for action taken
                q_action = keras.ops.sum(keras.ops.multiply(q_values, masks), axis=1)
                # Calculate loss between new Q-value and old Q-value
                loss = loss_function(updated_q_values, q_action)

            # Backpropagation
            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

        if frame_count % update_target_network == 0:
            # Update the target network with new weights
            model_target.set_weights(model.get_weights())
            # Log details
            template = "running reward: {:.2f} at episode {}, frame count {}"
            print(template.format(running_reward, episode_count, frame_count))

        # Limit the state and reward history
        if len(rewards_history) > max_memory_length:
            del rewards_history[:1]
            del state_history[:1]
            del state_next_history[:1]
            del action_history[:1]
            del done_history[:1]

        if done:
            break

    # Update running reward to check condition for solving
    episode_reward_history.append(episode_reward)
    if len(episode_reward_history) > 100:
        del episode_reward_history[:1]
    running_reward = np.mean(episode_reward_history)

    episode_count += 1

    if running_reward > 40:  # Condition to consider the task solved
        print("Solved at episode {}!".format(episode_count))
        break

    if (
        max_episodes > 0 and episode_count >= max_episodes
    ):  # Maximum number of episodes reached
        print("Stopped at episode {}!".format(episode_count))
        break

"""
## Visualizations

Before any training:
![Imgur](https://i.imgur.com/rRxXF4H.gif)

In early stages of training:
![Imgur](https://i.imgur.com/X8ghdpL.gif)

In later stages of training:
![Imgur](https://i.imgur.com/Z1K6qBQ.gif)
"""