GitHub Repository: microsoft/CyberBattleSim
Path: blob/main/notebooks/notebook_benchmark.py

# ---
# jupyter:
#   jupytext:
#     cell_metadata_filter: tags,-all
#     cell_metadata_json: true
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.16.4
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---

# %% {"tags": []}
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Benchmark all the baseline agents
on a given CyberBattleSim environment and compare
them to the dumb 'random agent' baseline.

NOTE: You can run this `.py`-notebook directly from VSCode.
You can also generate a traditional Jupyter Notebook
using the VSCode command `Export Current Python File As Jupyter Notebook`.
"""

# pylint: disable=invalid-name
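
# From the command line, the same export can be done with jupytext
# (declared in the header above), e.g.:
#
#   jupytext --to ipynb notebook_benchmark.py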

# %% {"tags": []}
import sys
import os
import logging
import gymnasium as gym
import cyberbattle.agents.baseline.learner as learner
import cyberbattle.agents.baseline.plotting as p
import cyberbattle.agents.baseline.agent_wrapper as w
import cyberbattle.agents.baseline.agent_randomcredlookup as rca
import cyberbattle.agents.baseline.agent_tabularqlearning as tqa
import cyberbattle.agents.baseline.agent_dql as dqla
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
from cyberbattle._env.cyberbattle_env import CyberBattleEnv

logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")
# %% {"tags": []}
# %matplotlib inline
# %% {"tags": ["parameters"]}
# Papermill notebook parameters
gymid = "CyberBattleChain-v0"
env_size = 10
iteration_count = 9000
training_episode_count = 50
eval_episode_count = 5
maximum_node_count = 22
maximum_total_credentials = 22
plots_dir = "output/plots"
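# These values follow the papermill "parameters"-cell convention (see the cell
# tag above). A sketch of overriding them on an exported .ipynb copy of this
# notebook (output file name is illustrative):
#
#   papermill notebook_benchmark.ipynb output/benchmark_run.ipynb \
#       -p gymid CyberBattleChain-v0 -p env_size 4 -p iteration_count 1000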

# %% {"tags": []}
os.makedirs(plots_dir, exist_ok=True)

# Load the Gym environment
if env_size:
    _gym_env = gym.make(gymid, size=env_size)
else:
    _gym_env = gym.make(gymid)

from typing import cast

gym_env = cast(CyberBattleEnv, _gym_env.unwrapped)
assert isinstance(gym_env, CyberBattleEnv), f"Expected CyberBattleEnv, got {type(gym_env)}"

ep = w.EnvironmentBounds.of_identifiers(maximum_node_count=maximum_node_count, maximum_total_credentials=maximum_total_credentials, identifiers=gym_env.identifiers)
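# EnvironmentBounds caps the node and credential counts so that feature
# encodings and the action space keep a fixed shape across environments
# (summary; see cyberbattle.agents.baseline.agent_wrapper for details).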

# %% {"tags": []}
# Interactive sanity checks of the environment, its spaces and the feature
# encodings (enable by setting debugging = True)
debugging = False
if debugging:
    print(f"port_count = {ep.port_count}, property_count = {ep.property_count}")

    gym_env.environment
    # training_env.environment.plot_environment_graph()
    gym_env.environment.network.nodes
    gym_env.action_space
    gym_env.action_space.sample()
    gym_env.observation_space.sample()
    o0, _ = gym_env.reset()
    o_test, r, d, t, i = gym_env.step(gym_env.sample_valid_action())
    o0, _ = gym_env.reset()

    o0.keys()

    fe_example = w.RavelEncoding(ep, [w.Feature_active_node_properties(ep), w.Feature_discovered_node_count(ep)])
    a = w.StateAugmentation(o0)
    w.Feature_discovered_ports(ep).get(a)
    fe_example.encode_at(a, 0)

# %% {"tags": []}
# Evaluate a random agent that opportunistically exploits
# credentials gathered in its local cache
credlookup_run = learner.epsilon_greedy_search(
    gym_env,
    ep,
    learner=rca.CredentialCacheExploiter(),
    episode_count=10,
    iteration_count=iteration_count,
    epsilon=0.90,
    render=False,
    epsilon_exponential_decay=10000,
    epsilon_minimum=0.10,
    verbosity=Verbosity.Quiet,
    title="Credential lookups (ϵ-greedy)",
)
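# Exploration note: epsilon starts at 0.90 and decays toward epsilon_minimum as
# steps accumulate. The schedule below is an assumed sketch; see
# learner.epsilon_greedy_search for the exact formula:
#   epsilon_t ~ epsilon_minimum + (epsilon - epsilon_minimum) * exp(-t / epsilon_exponential_decay)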

# %% {"tags": []}
# Evaluate a Tabular Q-learning agent
tabularq_run = learner.epsilon_greedy_search(
    gym_env,
    ep,
    learner=tqa.QTabularLearner(ep, gamma=0.015, learning_rate=0.01, exploit_percentile=100),
    episode_count=training_episode_count,
    iteration_count=iteration_count,
    epsilon=0.90,
    epsilon_exponential_decay=5000,
    epsilon_minimum=0.01,
    verbosity=Verbosity.Quiet,
    render=False,
    plot_episodes_length=False,
    title="Tabular Q-learning",
)

# %% {"tags": []}
# Evaluate an agent that exploits the Q-table learnt above
tabularq_exploit_run = learner.epsilon_greedy_search(
    gym_env,
    ep,
    learner=tqa.QTabularLearner(ep, trained=tabularq_run["learner"], gamma=0.0, learning_rate=0.0, exploit_percentile=90),
    episode_count=eval_episode_count,
    iteration_count=iteration_count,
    epsilon=0.0,
    render=False,
    verbosity=Verbosity.Quiet,
    title="Exploiting Q-matrix",
)
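# With gamma=0.0, learning_rate=0.0 and epsilon=0.0 the Q-table is frozen, so this
# run measures pure exploitation of the policy learnt in the training cell above.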

# %% {"tags": []}
# Evaluate the Deep Q-learning agent
dql_run = learner.epsilon_greedy_search(
    cyberbattle_gym_env=gym_env,
    environment_properties=ep,
    learner=dqla.DeepQLearnerPolicy(
        ep=ep,
        gamma=0.015,
        replay_memory_size=10000,
        target_update=10,
        batch_size=512,
        # torch default learning rate is 1e-2
        # a large value helps converge in fewer episodes
        learning_rate=0.01,
    ),
    episode_count=training_episode_count,
    iteration_count=iteration_count,
    epsilon=0.90,
    epsilon_exponential_decay=5000,
    epsilon_minimum=0.10,
    verbosity=Verbosity.Quiet,
    render=False,
    plot_episodes_length=False,
    title="DQL",
)
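# replay_memory_size, target_update and batch_size are the usual DQN knobs
# (experience-replay buffer size, how often the target network is refreshed,
# and the SGD batch size); this reading is a summary, see
# cyberbattle.agents.baseline.agent_dql for the authoritative definitions.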

# %% {"tags": []}
# Evaluate an agent that exploits the Q-function learnt above
dql_exploit_run = learner.epsilon_greedy_search(
    gym_env,
    ep,
    learner=dql_run["learner"],
    episode_count=eval_episode_count,
    iteration_count=iteration_count,
    epsilon=0.0,
    epsilon_minimum=0.00,
    render=False,
    plot_episodes_length=False,
    verbosity=Verbosity.Quiet,
    title="Exploiting DQL",
)
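# epsilon=0.0 disables exploration entirely, so every action is chosen greedily
# from the Q-function trained in the previous cell.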


# %% {"tags": []}
# Evaluate the random agent
random_run = learner.epsilon_greedy_search(
    gym_env,
    ep,
    learner=learner.RandomPolicy(),
    episode_count=eval_episode_count,
    iteration_count=iteration_count,
    epsilon=1.0,  # purely random
    render=False,
    verbosity=Verbosity.Quiet,
    plot_episodes_length=False,
    title="Random search",
)

# %% {"tags": []}
# Compare and plot results for all the agents
all_runs = [random_run, credlookup_run, tabularq_run, tabularq_exploit_run, dql_run, dql_exploit_run]

# Plot averaged cumulative rewards for all the agents
themodel = dqla.CyberBattleStateActionModel(ep)
p.plot_averaged_cummulative_rewards(
    all_runs=all_runs,
    title=f"Benchmark -- max_nodes={ep.maximum_node_count}, episodes={eval_episode_count},\n"
    f"State: {[f.name() for f in themodel.state_space.feature_selection]} "
    f"({len(themodel.state_space.feature_selection)})\n"
    f"Action: abstract_action ({themodel.action_space.flat_size()})",
    save_at=os.path.join(plots_dir, f"benchmark-{gymid}-cumrewards.png"),
)

# %% {"tags": []}
contenders = [credlookup_run, tabularq_run, dql_run, dql_exploit_run]
p.plot_episodes_length(contenders)
p.plot_averaged_cummulative_rewards(
    title=f"Agent Benchmark top contenders\nmax_nodes:{ep.maximum_node_count}\n",
    all_runs=contenders,
    save_at=os.path.join(plots_dir, f"benchmark-{gymid}-cumreward_contenders.png"),
)


# %% {"tags": []}
# Plot cumulative rewards for all episodes
for r in contenders:
    p.plot_all_episodes(r)