GitHub Repository: microsoft/CyberBattleSim
Path: blob/main/cyberbattle/agents/baseline/learner.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Learner helpers and epsilon greedy search"""

import math
import sys

from .plotting import PlotTraining, plot_averaged_cummulative_rewards
from .agent_wrapper import (
    AgentWrapper,
    EnvironmentBounds,
    Verbosity,
    ActionTrackingStateAugmentation,
)
import logging
import numpy as np
from cyberbattle._env import cyberbattle_env
from typing import Tuple, Optional, TypedDict, List
import progressbar
import abc


class Learner(abc.ABC):
    """Interface to be implemented by an epsilon-greedy learner"""

    def new_episode(self) -> None:
        return None

    def end_of_episode(self, i_episode, t) -> None:
        return None

    def end_of_iteration(self, t, done) -> None:
        return None

    @abc.abstractmethod
    def explore(self, wrapped_env: AgentWrapper) -> Tuple[str, cyberbattle_env.Action, object]:
        """Exploration function.
        Returns (action_type, gym_action, action_metadata) where
        action_metadata is a custom object that gets passed to the on_step callback function"""
        raise NotImplementedError

    @abc.abstractmethod
    def exploit(self, wrapped_env: AgentWrapper, observation) -> Tuple[str, Optional[cyberbattle_env.Action], object]:
        """Exploit function.
        Returns (action_type, gym_action, action_metadata) where
        action_metadata is a custom object that gets passed to the on_step callback function"""
        raise NotImplementedError

    @abc.abstractmethod
    def on_step(
        self,
        wrapped_env: AgentWrapper,
        observation,
        reward,
        done,
        truncated,
        info,
        action_metadata,
    ) -> None:
        raise NotImplementedError

    def parameters_as_string(self) -> str:
        return ""

    def all_parameters_as_string(self) -> str:
        return ""

    def loss_as_string(self) -> str:
        return ""

    def stateaction_as_string(self, action_metadata) -> str:
        return ""

class RandomPolicy(Learner):
    """A policy that does not learn and only explores"""

    def explore(self, wrapped_env: AgentWrapper) -> Tuple[str, cyberbattle_env.Action, object]:
        gym_action = wrapped_env.env.sample_valid_action()
        return "explore", gym_action, None

    def exploit(self, wrapped_env: AgentWrapper, observation) -> Tuple[str, Optional[cyberbattle_env.Action], object]:
        raise NotImplementedError

    def on_step(
        self,
        wrapped_env: AgentWrapper,
        observation,
        reward,
        done,
        truncated,
        info,
        action_metadata,
    ):
        return None

Breakdown = TypedDict("Breakdown", {"local": int, "remote": int, "connect": int})

Outcomes = TypedDict("Outcomes", {"reward": Breakdown, "noreward": Breakdown})

Stats = TypedDict(
    "Stats",
    {"exploit": Outcomes, "explore": Outcomes, "exploit_deflected_to_explore": int},
)

TrainedLearner = TypedDict(
    "TrainedLearner",
    {
        "all_episodes_rewards": List[List[float]],
        "all_episodes_availability": List[List[float]],
        "learner": Learner,
        "trained_on": str,
        "title": str,
    },
)
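# How these records are used by `epsilon_greedy_search` below: each step increments one
# counter in a Stats record, e.g. stats["explore"]["reward"]["local"] counts exploration
# steps that ran a local vulnerability and yielded a positive reward, while
# stats["exploit_deflected_to_explore"] counts exploitation attempts that returned no
# action and fell back to exploration. The returned TrainedLearner bundles the
# per-episode reward and availability series with the learner and the environment name.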
def print_stats(stats):
    """Print learning statistics"""

    def print_breakdown(stats, actiontype: str):
        def ratio(kind: str) -> str:
            x, y = (
                stats[actiontype]["reward"][kind],
                stats[actiontype]["noreward"][kind],
            )
            total = x + y
            if total == 0:
                return "NaN"
            else:
                return f"{(x / total):.2f}"

        def print_kind(kind: str):
            print(f" {actiontype}-{kind}: {stats[actiontype]['reward'][kind]}/{stats[actiontype]['noreward'][kind]} " f"({ratio(kind)})")

        print_kind("local")
        print_kind("remote")
        print_kind("connect")

    print(" Breakdown [Reward/NoReward (Success rate)]")
    print_breakdown(stats, "explore")
    print_breakdown(stats, "exploit")
    print(f" exploit deflected to exploration: {stats['exploit_deflected_to_explore']}")

def epsilon_greedy_search(
    cyberbattle_gym_env: cyberbattle_env.CyberBattleEnv,
    environment_properties: EnvironmentBounds,
    learner: Learner,
    title: str,
    episode_count: int,
    iteration_count: int,
    epsilon: float,
    epsilon_minimum=0.0,
    epsilon_multdecay: Optional[float] = None,
    epsilon_exponential_decay: Optional[int] = None,
    render=True,
    render_last_episode_rewards_to: Optional[str] = None,
    verbosity: Verbosity = Verbosity.Normal,
    plot_episodes_length=True,
) -> TrainedLearner:
    """Epsilon greedy search for CyberBattle gym environments

    Parameters
    ==========

    - cyberbattle_gym_env -- the CyberBattle environment to train on

    - environment_properties -- bounds on the environment size used by the agent wrapper

    - learner -- the policy learner/exploiter

    - title -- title used for logging and plots

    - episode_count -- Number of training episodes

    - iteration_count -- Maximum number of iterations in each episode

    - epsilon -- explore vs exploit
        - 0.0 to exploit the learnt policy only, without exploration
        - 1.0 to explore purely randomly

    - epsilon_minimum -- epsilon decay is clipped at this value.
      Setting this value too close to 0 may lead the search to get stuck.

    - epsilon_multdecay -- epsilon gets multiplied by this value after each episode

    - epsilon_exponential_decay -- if set, use exponential decay. The larger the value,
      the longer it takes to get from the initial `epsilon` to `epsilon_minimum`.

    - verbosity -- verbosity of the `print` logging

    - render -- render the environment interactively after each episode

    - render_last_episode_rewards_to -- for the last episode only, render the environment
      to the specified file path, with an index appended to it each time there is a
      positive reward

    - plot_episodes_length -- Plot the graph showing the total number of steps per episode
      at the end of the search.

    Note on convergence
    ===================

    Setting `epsilon_minimum` to 0 with an exponential decay < 1
    makes the learning converge quickly (the loss function reaches 0),
    but this is a forced convergence: as epsilon approaches 0, only the
    q-values that were explored so far get updated, so only that subset
    of cells of the Q-matrix converges.
    """
    print(
        f"###### {title}\n"
        f"Learning with: episode_count={episode_count},"
        f"iteration_count={iteration_count},"
        f"ϵ={epsilon},"
        f"ϵ_min={epsilon_minimum}, "
        + (f"ϵ_multdecay={epsilon_multdecay}," if epsilon_multdecay else "")
        + (f"ϵ_expdecay={epsilon_exponential_decay}," if epsilon_exponential_decay else "")
        + f"{learner.parameters_as_string()}"
    )

    initial_epsilon = epsilon

    all_episodes_rewards = []
    all_episodes_availability = []

    o, _ = cyberbattle_gym_env.reset()
    wrapped_env = AgentWrapper(
        cyberbattle_gym_env,
        ActionTrackingStateAugmentation(environment_properties, o),
    )
    steps_done = 0
    plot_title = (
        f"{title} (epochs={episode_count}, ϵ={initial_epsilon}, ϵ_min={epsilon_minimum},"
        + (f"ϵ_multdecay={epsilon_multdecay}," if epsilon_multdecay else "")
        + (f"ϵ_expdecay={epsilon_exponential_decay}," if epsilon_exponential_decay else "")
        + learner.parameters_as_string()
    )
    plottraining = PlotTraining(title=plot_title, render_each_episode=render)

    render_file_index = 1

    for i_episode in range(1, episode_count + 1):
        print(f" ## Episode: {i_episode}/{episode_count} '{title}' " f"ϵ={epsilon:.4f}, " f"{learner.parameters_as_string()}")

        observation, _ = wrapped_env.reset()
        total_reward = 0.0
        all_rewards = []
        all_availability = []
        learner.new_episode()

        stats = Stats(
            exploit=Outcomes(
                reward=Breakdown(local=0, remote=0, connect=0),
                noreward=Breakdown(local=0, remote=0, connect=0),
            ),
            explore=Outcomes(
                reward=Breakdown(local=0, remote=0, connect=0),
                noreward=Breakdown(local=0, remote=0, connect=0),
            ),
            exploit_deflected_to_explore=0,
        )

        episode_ended_at = None
        sys.stdout.flush()

        bar = progressbar.ProgressBar(
            widgets=[
                "Episode ",
                f"{i_episode}",
                "|Iteration ",
                progressbar.Counter(),
                "|",
                progressbar.Variable(name="reward", width=6, precision=10),
                "|",
                progressbar.Variable(name="last_reward_at", width=4),
                "|",
                progressbar.Timer(),
                progressbar.Bar(),
            ],
            redirect_stdout=False,
        )

        for t in bar(range(1, 1 + iteration_count)):
            if epsilon_exponential_decay:
                epsilon = epsilon_minimum + math.exp(-1.0 * steps_done / epsilon_exponential_decay) * (initial_epsilon - epsilon_minimum)

            steps_done += 1
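            # Epsilon-greedy choice: with probability epsilon take a random exploration
            # action, otherwise ask the learner to exploit its current policy. If the
            # learner cannot produce a valid exploit action, fall back to exploration
            # and record it in the `exploit_deflected_to_explore` counter.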
            x = np.random.rand()
            if x <= epsilon:
                action_style, gym_action, action_metadata = learner.explore(wrapped_env)
            else:
                action_style, gym_action, action_metadata = learner.exploit(wrapped_env, observation)
                if not gym_action:
                    stats["exploit_deflected_to_explore"] += 1
                    _, gym_action, action_metadata = learner.explore(wrapped_env)

            # Take the step
            logging.debug(f"gym_action={gym_action}, action_metadata={action_metadata}")
            observation, reward, done, truncated, info = wrapped_env.step(gym_action)
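            # Bucket the outcome for the post-episode summary: by action style
            # (explore/exploit), by whether the step produced a positive reward, and by
            # the action kind, inferred from which key the gym action carries
            # (local_vulnerability, remote_vulnerability, or connect).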
            action_type = "exploit" if action_style == "exploit" else "explore"
            outcome = "reward" if reward > 0 else "noreward"
            if "local_vulnerability" in gym_action:
                stats[action_type][outcome]["local"] += 1
            elif "remote_vulnerability" in gym_action:
                stats[action_type][outcome]["remote"] += 1
            else:
                stats[action_type][outcome]["connect"] += 1

            learner.on_step(wrapped_env, observation, reward, done, truncated, info, action_metadata)
            assert np.shape(reward) == ()

            all_rewards.append(reward)
            all_availability.append(info["network_availability"])
            total_reward += reward
            bar.update(t, reward=total_reward)
            if reward > 0:
                bar.update(t, last_reward_at=t)

            if verbosity == Verbosity.Verbose or (verbosity == Verbosity.Normal and reward > 0):
                sign = ["-", "+"][reward > 0]

                print(
                    f" {sign} t={t} {action_style} r={reward} cum_reward:{total_reward} "
                    f"a={action_metadata}-{gym_action} "
                    f"creds={len(observation['credential_cache_matrix'])} "
                    f" {learner.stateaction_as_string(action_metadata)}"
                )

            if i_episode == episode_count and render_last_episode_rewards_to is not None and reward > 0:
                fig = cyberbattle_gym_env.render_as_fig()
                fig.write_image(f"{render_last_episode_rewards_to}-e{i_episode}-{render_file_index}.png")
                render_file_index += 1

            learner.end_of_iteration(t, done)

            if done:
                episode_ended_at = t
                bar.finish(dirty=True)
                break

        sys.stdout.flush()

        loss_string = learner.loss_as_string()
        if loss_string:
            loss_string = f"loss={loss_string}"

        if episode_ended_at:
            print(f" Episode {i_episode} ended at t={episode_ended_at} {loss_string}")
        else:
            print(f" Episode {i_episode} stopped at t={iteration_count} {loss_string}")

        print_stats(stats)

        all_episodes_rewards.append(all_rewards)
        all_episodes_availability.append(all_availability)

        length = episode_ended_at if episode_ended_at else iteration_count
        learner.end_of_episode(i_episode=i_episode, t=length)
        if plot_episodes_length:
            plottraining.episode_done(length)
        if render:
            wrapped_env.render()
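        # Optional per-episode multiplicative decay: for example (illustrative values),
        # epsilon_multdecay=0.75 shrinks epsilon by 25% after each episode until it is
        # clipped at epsilon_minimum.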
        if epsilon_multdecay:
            epsilon = max(epsilon_minimum, epsilon * epsilon_multdecay)

    wrapped_env.close()
    print("simulation ended")
    if plot_episodes_length:
        plottraining.plot_end()

    return TrainedLearner(
        all_episodes_rewards=all_episodes_rewards,
        all_episodes_availability=all_episodes_availability,
        learner=learner,
        trained_on=cyberbattle_gym_env.name,
        title=plot_title,
    )


def transfer_learning_evaluation(
    environment_properties: EnvironmentBounds,
    trained_learner: TrainedLearner,
    eval_env: cyberbattle_env.CyberBattleEnv,
    eval_epsilon: float,
    eval_episode_count: int,
    iteration_count: int,
    benchmark_policy: Learner = RandomPolicy(),
    benchmark_training_args=dict(title="Benchmark", epsilon=1.0),
):
    """Evaluate a trained agent on another environment of a different size"""

    eval_oneshot_all = epsilon_greedy_search(
        eval_env,
        environment_properties,
        learner=trained_learner["learner"],
        episode_count=eval_episode_count,  # one shot from the learnt Q matrix
        iteration_count=iteration_count,
        epsilon=eval_epsilon,
        render=False,
        verbosity=Verbosity.Quiet,
        title=f"One shot on {eval_env.name} - Trained on {trained_learner['trained_on']}",
    )

    eval_random = epsilon_greedy_search(
        eval_env,
        environment_properties,
        learner=benchmark_policy,
        episode_count=eval_episode_count,
        iteration_count=iteration_count,
        render=False,
        verbosity=Verbosity.Quiet,
        **benchmark_training_args,
    )

    plot_averaged_cummulative_rewards(
        all_runs=[eval_oneshot_all, eval_random],
        title=f"Transfer learning {trained_learner['trained_on']}->{eval_env.name} "
        f"-- max_nodes={environment_properties.maximum_node_count}, "
        f"episodes={eval_episode_count},\n"
        f"{trained_learner['learner'].all_parameters_as_string()}",
    )
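# Usage sketch (illustrative; assumes a CyberBattleEnv instance `gym_env` and a matching
# EnvironmentBounds instance `ep` constructed elsewhere, e.g. via
# EnvironmentBounds.of_identifiers):
#
#     result = epsilon_greedy_search(
#         gym_env,
#         ep,
#         learner=RandomPolicy(),
#         title="Random baseline",
#         episode_count=5,
#         iteration_count=100,
#         epsilon=1.0,
#         render=False,
#     )
#     # result["all_episodes_rewards"] holds one reward series per episode; the returned
#     # TrainedLearner can then be passed to transfer_learning_evaluation to compare it
#     # against a random benchmark on another environment.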