GitHub Repository: microsoft/CyberBattleSim
Path: blob/main/notebooks/notebook_tabularq.py

# ---
# jupyter:
#   jupytext:
#     cell_metadata_filter: title,-all
#     formats: py:percent,ipynb
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.16.4
#   kernelspec:
#     display_name: cybersim
#     language: python
#     name: cybersim
# ---

# %% [markdown]
# pyright: reportUnusedExpression=false

# %%
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Tabular Q-learning agent (notebook)

This notebook can be run directly from VSCode. To generate a
traditional Jupyter notebook that you can open in your browser,
run the VSCode command `Export Current Python File As Jupyter Notebook`.
"""

# pylint: disable=invalid-name
# %%
import sys
import os
import logging
from typing import cast
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt  # type: ignore
from cyberbattle.agents.baseline.learner import TrainedLearner
import cyberbattle.agents.baseline.plotting as p
import cyberbattle.agents.baseline.agent_wrapper as w
import cyberbattle.agents.baseline.agent_tabularqlearning as a
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
import cyberbattle.agents.baseline.learner as learner
import cyberbattle._env.cyberbattle_env as cyberbattle_env
from cyberbattle._env.cyberbattle_env import AttackerGoal

logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")

# %%
# %matplotlib inline

# %%
# Benchmark parameters:
# Parameters from the DeepDoubleQ paper:
# - learning_rate = 0.00025
# - linear epsilon decay
# - gamma = 0.99
#
# Eliminated gamma values:
#   0.0,
#   0.0015,  # too small
#   0.15,    # too big
#   0.25,    # too big
#   0.35,    # too big
#
# NOTE: Given the relatively low number of training episodes (50),
# a high learning rate of 0.99 gives better results
# than a lower learning rate of 0.25 (i.e. maximal rewards are reached faster on average).
# Ideally we should decay the learning rate just like gamma and train over a
# much larger number of episodes.

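# %% [markdown]
# Illustrative note (an assumption based on the parameter names, not taken from the library code):
# the tabular rule that `QTabularLearner` is parameterized by is the standard Q-learning update,
#
#     Q[s, a] <- (1 - learning_rate) * Q[s, a] + learning_rate * (reward + gamma * max_a' Q[s', a'])
#
# so `gamma` weights future rewards while `learning_rate` controls how quickly new estimates
# overwrite old ones. The epsilon-greedy schedule used in `qlearning_run` below decays
# multiplicatively, roughly epsilon_e = max(epsilon_minimum, epsilon * epsilon_multdecay**e),
# which motivates the `epsilon_multdecay=0.75` / `epsilon_minimum=0.01` settings.

# %%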
cyberbattlechain_10 = gym.make("CyberBattleChain-v0", size=10, attacker_goal=AttackerGoal(own_atleast_percent=1.0)).unwrapped

assert isinstance(cyberbattlechain_10, cyberbattle_env.CyberBattleEnv)

ep = w.EnvironmentBounds.of_identifiers(maximum_node_count=12, maximum_total_credentials=12, identifiers=cyberbattlechain_10.identifiers)
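# Note: the same `ep` bounds are reused below for the size-4 chain and for the
# transfer-learning evaluation, so the feature-based state/action spaces (and hence,
# presumably, the Q-matrix dimensions) are shared across the two environments.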

# %% {"tags": ["parameters"]}
iteration_count = 9000
training_episode_count = 5
eval_episode_count = 5
gamma_sweep = [
    0.015,  # about right
]
plots_dir = 'output/plots'

# %%
os.makedirs(plots_dir, exist_ok=True)


# %%
def qlearning_run(gamma, gym_env):
    """Execute one run of the q-learning algorithm for the
    specified gamma value"""
    return learner.epsilon_greedy_search(
        gym_env,
        ep,
        a.QTabularLearner(ep, gamma=gamma, learning_rate=0.90, exploit_percentile=100),
        episode_count=training_episode_count,
        iteration_count=iteration_count,
        epsilon=0.90,
        render=False,
        epsilon_multdecay=0.75,  # 0.999,
        epsilon_minimum=0.01,
        verbosity=Verbosity.Quiet,
        title="Q-learning",
    )


# %%
# Run Q-learning with gamma-sweep
qlearning_results = [qlearning_run(gamma, cyberbattlechain_10) for gamma in gamma_sweep]

# %%
qlearning_bestrun_10 = qlearning_results[0]
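# With the single-entry gamma_sweep above there is exactly one run, so it is taken as the best run.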

# %%
p.new_plot_loss()
for results in qlearning_results:
    p.plot_all_episodes_loss(cast(a.QTabularLearner, results["learner"]).loss_qsource.all_episodes, "Q_source", results["title"])
    p.plot_all_episodes_loss(cast(a.QTabularLearner, results["learner"]).loss_qattack.all_episodes, "Q_attack", results["title"])
plt.legend(loc="upper right")
plt.show()

# %% Plot episode length
p.plot_episodes_length(qlearning_results)

# %% [markdown]

# %%
nolearning_results = learner.epsilon_greedy_search(
    cyberbattlechain_10,
    ep,
    learner=a.QTabularLearner(ep, trained=qlearning_bestrun_10["learner"], gamma=0.0, learning_rate=0.0, exploit_percentile=100),
    episode_count=eval_episode_count,
    iteration_count=iteration_count,
    epsilon=0.30,  # 0.35,
    render=False,
    title="Exploiting Q-matrix",
    verbosity=Verbosity.Quiet,
)

# %%
randomlearning_results = learner.epsilon_greedy_search(
    cyberbattlechain_10,
    ep,
    learner=a.QTabularLearner(ep, trained=qlearning_bestrun_10["learner"], gamma=0.0, learning_rate=0.0, exploit_percentile=100),
    episode_count=eval_episode_count,
    iteration_count=iteration_count,
    epsilon=1.0,  # purely random
    render=False,
    verbosity=Verbosity.Quiet,
    title="Random search",
)

# %%
# Plot averaged cumulative rewards for Q-learning vs Random vs Q-Exploit
all_runs = [*qlearning_results, randomlearning_results, nolearning_results]

Q_source_10 = cast(a.QTabularLearner, qlearning_bestrun_10["learner"]).qsource
Q_attack_10 = cast(a.QTabularLearner, qlearning_bestrun_10["learner"]).qattack

p.plot_averaged_cummulative_rewards(
    all_runs=all_runs,
    title=f"Benchmark -- max_nodes={ep.maximum_node_count}, episodes={eval_episode_count},\n"
    f"dimension={Q_source_10.state_space.flat_size()}x{Q_source_10.action_space.flat_size()}, "
    f"{Q_attack_10.state_space.flat_size()}x{Q_attack_10.action_space.flat_size()}\n"
    f"Q1={[f.name() for f in Q_source_10.state_space.feature_selection]} "
    f"-> {[f.name() for f in Q_source_10.action_space.feature_selection]})\n"
    f"Q2={[f.name() for f in Q_attack_10.state_space.feature_selection]} -> 'action'",
    save_at=os.path.join(plots_dir, "benchmark-tabularq-cumrewards.png"),
)


# %%
# Plot cumulative rewards for all episodes
p.plot_all_episodes(qlearning_results[0])


# %%
# Plot the Q-matrices

# %%
# Print the non-zero coordinates of the Q_source matrix
i = np.where(Q_source_10.qm)
q = Q_source_10.qm[i]
list(zip(np.array([Q_source_10.state_space.pretty_print(i) for i in i[0]]), np.array([Q_source_10.action_space.pretty_print(i) for i in i[1]]), q))
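
# %%
# (Optional sketch, not part of the original notebook): rank the non-zero Q_source entries
# by value to see the highest-rated state/action feature pairs. Assumes, as above, that
# Q_source_10.qm is a 2-D numpy array indexed by (state, action).
top_entries = sorted(zip(q, i[0], i[1]), key=lambda t: t[0], reverse=True)[:10]
for value, si, ai in top_entries:
    print(f"{value:8.2f}  {Q_source_10.state_space.pretty_print(si)} -> {Q_source_10.action_space.pretty_print(ai)}")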

# %%
# Print the non-zero coordinates of the Q_attack matrix
i2 = np.where(Q_attack_10.qm)
q2 = Q_attack_10.qm[i2]
list(zip([Q_attack_10.state_space.pretty_print(i) for i in i2[0]], [Q_attack_10.action_space.pretty_print(i) for i in i2[1]], q2))


##################################################

# %% [markdown]
# ## Transfer learning from size 4 to size 10
# Exploiting the Q-matrix learned on a different network.
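# (Presumably this works because the Q-matrices are indexed by local state/action
# features rather than by concrete node identities, so a matrix trained on the
# size-4 chain can be looked up unchanged on the size-10 chain, and vice versa.)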

# %%
# Train Q-matrix on CyberBattle network of size 4
cyberbattlechain_4 = gym.make("CyberBattleChain-v0", size=4, attacker_goal=AttackerGoal(own_atleast_percent=1.0)).unwrapped
assert isinstance(cyberbattlechain_4, cyberbattle_env.CyberBattleEnv)

qlearning_bestrun_4 = qlearning_run(0.015, gym_env=cyberbattlechain_4)


def stop_learning(trained_learner):
    return TrainedLearner(
        learner=a.QTabularLearner(ep, gamma=0.0, learning_rate=0.0, exploit_percentile=0, trained=trained_learner["learner"]),
        title=trained_learner["title"],
        trained_on=trained_learner["trained_on"],
        all_episodes_rewards=trained_learner["all_episodes_rewards"],
        all_episodes_availability=trained_learner["all_episodes_availability"],
    )


learner.transfer_learning_evaluation(
    environment_properties=ep,
    trained_learner=stop_learning(qlearning_bestrun_4),
    eval_env=cyberbattlechain_10,
    eval_epsilon=0.5,  # alternate with exploration to help generalization to bigger network
    eval_episode_count=eval_episode_count,
    iteration_count=iteration_count,
)

learner.transfer_learning_evaluation(
    environment_properties=ep,
    trained_learner=stop_learning(qlearning_bestrun_10),
    eval_env=cyberbattlechain_4,
    eval_epsilon=0.5,
    eval_episode_count=eval_episode_count,
    iteration_count=iteration_count,
)