GitHub Repository: microsoft/CyberBattleSim
Path: blob/main/notebooks/notebook_dql_transfer.py
# ---
# jupyter:
#   jupytext:
#     cell_metadata_filter: title,-all
#     formats: py:percent,ipynb
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.16.4
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---

# %%
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# -*- coding: utf-8 -*-
# %%
"""Notebook demonstrating the transfer learning capability of the
Deep Q-learning agent trained and evaluated on chain
environments of various sizes.

NOTE: You can run this `.py`-notebook directly from VSCode.
You can also generate a traditional Jupyter Notebook
using the VSCode command `Export Current Python File As Jupyter Notebook`.
"""

# %%
import os
import sys
import logging
import gymnasium as gym
import torch

import cyberbattle.agents.baseline.learner as learner
import cyberbattle.agents.baseline.plotting as p
import cyberbattle.agents.baseline.agent_wrapper as w
import cyberbattle.agents.baseline.agent_dql as dqla
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
import cyberbattle.agents.baseline.agent_randomcredlookup as rca
import importlib
import cyberbattle._env.cyberbattle_env as cyberbattle_env
import cyberbattle._env.cyberbattle_chain as cyberbattle_chain

importlib.reload(learner)
importlib.reload(cyberbattle_env)
importlib.reload(cyberbattle_chain)

logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")

# %matplotlib inline

# %%
torch.cuda.is_available()
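
# %%
# The call above only reports whether CUDA is visible; it does not select a
# device. A common PyTorch pattern for explicit device selection is sketched
# below for reference only; the `device` variable is illustrative and is not
# used elsewhere in this notebook.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Selected device: {device}")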

# %%
# To run once
# import plotly.io as pio
# pio.orca.config.use_xvfb = True
# pio.orca.config.save()
# %%
cyberbattlechain_4 = gym.make("CyberBattleChain-v0", size=4, attacker_goal=cyberbattle_env.AttackerGoal(own_atleast_percent=1.0)).unwrapped
cyberbattlechain_10 = gym.make("CyberBattleChain-v0", size=10, attacker_goal=cyberbattle_env.AttackerGoal(own_atleast_percent=1.0)).unwrapped
cyberbattlechain_20 = gym.make("CyberBattleChain-v0", size=20, attacker_goal=cyberbattle_env.AttackerGoal(own_atleast_percent=1.0)).unwrapped

assert isinstance(cyberbattlechain_4, cyberbattle_env.CyberBattleEnv)
assert isinstance(cyberbattlechain_10, cyberbattle_env.CyberBattleEnv)
assert isinstance(cyberbattlechain_20, cyberbattle_env.CyberBattleEnv)

ep = w.EnvironmentBounds.of_identifiers(maximum_total_credentials=22, maximum_node_count=22, identifiers=cyberbattlechain_10.identifiers)

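# %%
# Optional inspection cell (illustrative): the same bounds `ep` are reused for
# every chain size below, so they are dimensioned for the largest environment
# (the size-20 chain). Peek at the shared node bound and at the Gymnasium
# spaces of the size-10 chain.
print(f"maximum_node_count = {ep.maximum_node_count}")
print(cyberbattlechain_10.action_space)
print(cyberbattlechain_10.observation_space)
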
# %% {"tags": ["parameters"]}
77
iteration_count = 9000
78
training_episode_count = 50
79
eval_episode_count = 10
80
plots_dir = "output/images"
81
82
# %%
83
os.makedirs(plots_dir, exist_ok=True)
84
85
# %%
86
# Run Deep Q-learning
87
# 0.015
88
best_dqn_learning_run_10 = learner.epsilon_greedy_search(
89
cyberbattle_gym_env=cyberbattlechain_10,
90
environment_properties=ep,
91
learner=dqla.DeepQLearnerPolicy(ep=ep, gamma=0.015, replay_memory_size=10000, target_update=10, batch_size=512, learning_rate=0.01), # torch default is 1e-2
92
episode_count=training_episode_count,
93
iteration_count=iteration_count,
94
epsilon=0.90,
95
render=False,
96
# epsilon_multdecay=0.75, # 0.999,
97
epsilon_exponential_decay=5000, # 10000
98
epsilon_minimum=0.10,
99
verbosity=Verbosity.Quiet,
100
title="DQL",
101
)
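
# The result of `epsilon_greedy_search` is reused below both for plotting the
# training statistics and, via its "learner" entry, as the trained policy for
# exploitation-only evaluation and for the transfer-learning comparison.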

# %%
# %% Plot episode length
p.plot_episodes_length([best_dqn_learning_run_10])

# %% [markdown]

# %%
if not os.path.exists("images"):
    os.mkdir("images")

# %%
dql_exploit_run = learner.epsilon_greedy_search(
    cyberbattlechain_10,
    ep,
    learner=best_dqn_learning_run_10["learner"],
    episode_count=eval_episode_count,
    iteration_count=iteration_count,
    epsilon=0.0,  # 0.35,
    render=False,
    render_last_episode_rewards_to=os.path.join(plots_dir, "dql_transfer-chain10"),
    title="Exploiting DQL",
    verbosity=Verbosity.Quiet,
)


# %%
random_run = learner.epsilon_greedy_search(
    cyberbattlechain_10,
    ep,
    learner=learner.RandomPolicy(),
    episode_count=eval_episode_count,
    iteration_count=iteration_count,
    epsilon=1.0,  # purely random
    render=False,
    verbosity=Verbosity.Quiet,
    title="Random search",
)

# %%
# Plot averaged cumulative rewards for DQL vs Random vs DQL-Exploit
themodel = dqla.CyberBattleStateActionModel(ep)
p.plot_averaged_cummulative_rewards(
    all_runs=[best_dqn_learning_run_10, random_run, dql_exploit_run],
    title=f"Benchmark -- max_nodes={ep.maximum_node_count}, episodes={eval_episode_count},\n"
    f"State: {[f.name() for f in themodel.state_space.feature_selection]} "
    f"({len(themodel.state_space.feature_selection)})\n"
    f"Action: abstract_action ({themodel.action_space.flat_size()})",
)


# %%
# plot cumulative rewards for all episodes
p.plot_all_episodes(best_dqn_learning_run_10)


##################################################
# %%
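# Train a second DQL agent, this time on the smaller size-4 chain; it is used
# below to evaluate how a policy learned on a small chain transfers to the
# size-10 and size-20 chains.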
best_dqn_4 = learner.epsilon_greedy_search(
    cyberbattle_gym_env=cyberbattlechain_4,
    environment_properties=ep,
    learner=dqla.DeepQLearnerPolicy(ep=ep, gamma=0.15, replay_memory_size=10000, target_update=5, batch_size=256, learning_rate=0.01),
    episode_count=training_episode_count,
    iteration_count=iteration_count,
    epsilon=0.90,
    render=False,
    epsilon_exponential_decay=5000,
    epsilon_minimum=0.10,
    verbosity=Verbosity.Quiet,
    title="DQL",
)

# %%
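# Transfer-learning evaluation: replay a policy trained on one chain size
# against a different chain size, benchmarked against the ϵ-greedy
# credential-cache exploiter policy.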
learner.transfer_learning_evaluation(
    environment_properties=ep,
    trained_learner=best_dqn_learning_run_10,
    eval_env=cyberbattlechain_20,
    eval_epsilon=0.0,  # alternate with exploration to help generalization to bigger network
    eval_episode_count=eval_episode_count,
    iteration_count=iteration_count,
    benchmark_policy=rca.CredentialCacheExploiter(),
    benchmark_training_args={"epsilon": 0.90, "epsilon_exponential_decay": 10000, "epsilon_minimum": 0.10, "title": "Credential lookups (ϵ-greedy)"},
)
learner.transfer_learning_evaluation(
    environment_properties=ep,
    trained_learner=best_dqn_4,
    eval_env=cyberbattlechain_10,
    eval_epsilon=0.0,  # exploit Q-matrix only
    eval_episode_count=eval_episode_count,
    iteration_count=iteration_count,
    benchmark_policy=rca.CredentialCacheExploiter(),
    benchmark_training_args={"epsilon": 0.90, "epsilon_exponential_decay": 10000, "epsilon_minimum": 0.10, "title": "Credential lookups (ϵ-greedy)"},
)

# %%
learner.transfer_learning_evaluation(
    environment_properties=ep,
    trained_learner=best_dqn_4,
    eval_env=cyberbattlechain_20,
    eval_epsilon=0.0,  # exploit Q-matrix only
    eval_episode_count=eval_episode_count,
    iteration_count=iteration_count,
    benchmark_policy=rca.CredentialCacheExploiter(),
    benchmark_training_args={"epsilon": 0.90, "epsilon_exponential_decay": 10000, "epsilon_minimum": 0.10, "title": "Credential lookups (ϵ-greedy)"},
)