GitHub Repository: microsoft/CyberBattleSim
Path: blob/main/cyberbattle/agents/baseline/agent_tabularqlearning.py

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Q-learning agent applied to chain network (notebook)

This notebook can be run directly from VSCode. To generate a traditional
Jupyter Notebook that you can open in your browser, run the VSCode command
`Export Current Python File As Jupyter Notebook`.
"""

# pylint: disable=invalid-name

from typing import NamedTuple, Optional, Tuple
import numpy as np
import logging

from cyberbattle._env import cyberbattle_env
from .agent_wrapper import EnvironmentBounds
from .agent_randomcredlookup import CredentialCacheExploiter
import cyberbattle.agents.baseline.agent_wrapper as w
import cyberbattle.agents.baseline.learner as learner


def random_argmax(array):
    """Just like `argmax` but if there are multiple elements with the max,
    return a random index to break ties instead of returning the first one."""
    max_value = np.max(array)
    max_index = np.where(array == max_value)[0]

    if max_index.shape[0] > 1:
        max_index = int(np.random.choice(max_index))
    else:
        max_index = int(max_index[0])

    return max_value, max_index


def random_argtop_percentile(array: np.ndarray, percentile: float):
    """Return the given percentile value of `array` together with the index of a
    randomly chosen element whose value is at or above that percentile
    (ties are broken at random instead of always returning the first index)."""
    top_percentile = np.percentile(array, percentile)
    indices = np.where(array >= top_percentile)[0]

    if len(indices) == 0:
        return random_argmax(array)
    elif indices.shape[0] > 1:
        max_index = int(np.random.choice(indices))
    else:
        max_index = int(indices[0])

    return top_percentile, max_index
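

# Illustrative usage of the two tie-breaking helpers above (not part of the
# original module): with duplicated maxima, `random_argmax` picks uniformly at
# random among the tying indices, and `random_argtop_percentile` samples among
# the entries at or above the requested percentile (100 reduces to argmax):
#
#   >>> row = np.array([0.1, 0.9, 0.9, 0.3])
#   >>> value, index = random_argmax(row)                      # value == 0.9, index in {1, 2}
#   >>> threshold, index = random_argtop_percentile(row, 75)   # threshold == 0.9, index in {1, 2}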


class QMatrix:
    """Q-learning matrix for a given state and action space

    state_space  - features defining the state space
    action_space - features defining the action space
    qm           - optional initialization values for the Q matrix
    """

    # The Quality matrix
    qm: np.ndarray

    def __init__(
        self,
        name,
        state_space: w.Feature,
        action_space: w.Feature,
        qm: Optional[np.ndarray] = None,
    ):
        """Initialize the Q-matrix"""

        self.name = name
        self.state_space = state_space
        self.action_space = action_space
        self.statedim = state_space.flat_size()
        self.actiondim = action_space.flat_size()
        self.qm = self.clear() if qm is None else qm

        # error calculated for the last update to the Q-matrix
        self.last_error = 0

    def shape(self):
        return (self.statedim, self.actiondim)

    def clear(self):
        """Re-initialize the Q-matrix to 0"""
        self.qm = np.zeros(shape=self.shape())
        # self.qm = np.random.rand(*self.shape()) / 100
        return self.qm

    def print(self):
        print(f"[{self.name}]\n" f"state: {self.state_space}\n" f"action: {self.action_space}\n" f"shape = {self.shape()}")

    def update(
        self,
        current_state: int,
        action: int,
        next_state: int,
        reward,
        gamma,
        learning_rate,
    ):
        """Update the Q matrix after taking `action` in state `current_state`
        and obtaining reward=R[current_state, action]"""

        maxq_atnext, max_index = random_argmax(self.qm[next_state,])

        # Bellman equation for Q-learning
        temporal_difference = reward + gamma * maxq_atnext - self.qm[current_state, action]
        self.qm[current_state, action] += learning_rate * temporal_difference

        # The loss is the squared difference between the
        # target Q-value and the predicted Q-value
        square_error = temporal_difference * temporal_difference
        self.last_error = square_error

        return self.qm[current_state, action]
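
    # Worked example of a single `update` call (illustrative numbers, not project
    # defaults): with gamma=0.5, learning_rate=0.1, Q[s, a] = 2.0,
    # max_a' Q[s', a'] = 4.0 and reward = 1.0, the temporal difference is
    # 1.0 + 0.5 * 4.0 - 2.0 = 1.0, so Q[s, a] becomes 2.0 + 0.1 * 1.0 = 2.1
    # and `last_error` is 1.0 ** 2 = 1.0.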

    def exploit(self, features, percentile) -> Tuple[int, float]:
        """exploit: leverage the Q-matrix.
        Returns the chosen action and its expected Q value."""
        expected_q, action = random_argtop_percentile(self.qm[features, :], percentile)
        return int(action), float(expected_q)


class QLearnAttackSource(QMatrix):
    """Top-level Q matrix used to pick the source node of the attack

    State space:  global state information
    Action space: feature encodings of the candidate source nodes
    """

    def __init__(self, ep: EnvironmentBounds, qm: Optional[np.ndarray] = None):
        self.ep = ep

        self.state_space = w.HashEncoding(
            ep,
            [
                # Feature_discovered_node_count(),
                # Feature_discovered_credential_count(),
                w.Feature_discovered_ports_sliding(ep),
                w.Feature_discovered_nodeproperties_sliding(ep),
                w.Feature_discovered_notowned_node_count(ep, 3),
            ],
            5000,
        )  # the hash size should not be too small; pick something big to avoid collisions

        self.action_space = w.RavelEncoding(ep, [w.Feature_active_node_properties(ep)])

        super().__init__("attack_source", self.state_space, self.action_space, qm)


class QLearnBestAttackAtSource(QMatrix):
    """Second-level Q matrix to pick the attack to run from a pre-chosen source node

    State space:  feature encoding of the selected source node's state
    Action space: an abstract action (w.AbstractAction)
    """

    def __init__(self, ep: EnvironmentBounds, qm: Optional[np.ndarray] = None) -> None:
        self.state_space = w.HashEncoding(
            ep,
            [
                w.Feature_active_node_properties(ep),
                w.Feature_active_node_age(ep),
                # w.Feature_actions_tried_at_node(ep)
            ],
            7000,
        )

        # NOTE: For debugging purposes it's convenient to instead use a
        # Ravel encoding of the node properties
        self.state_space_debugging = w.RavelEncoding(
            ep,
            [
                w.HashEncoding(
                    ep,
                    [
                        # Feature_discovered_node_count(),
                        # Feature_discovered_credential_count(),
                        w.Feature_discovered_ports_sliding(ep),
                        w.Feature_discovered_nodeproperties_sliding(ep),
                        w.Feature_discovered_notowned_node_count(ep, 3),
                    ],
                    100,
                ),
                w.Feature_active_node_properties(ep),
            ],
        )

        self.action_space = w.AbstractAction(ep)

        super().__init__("attack_at_source", self.state_space, self.action_space, qm)
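

# The two matrices above implement a two-level decomposition of the attack choice:
# `QLearnAttackSource` scores candidate source nodes from a hashed encoding of the
# global observation (hash space of size 5000), then `QLearnBestAttackAtSource`
# scores abstract actions from a hashed encoding of the chosen node's state (hash
# space of size 7000). The chosen abstract action is later specialized into a
# concrete gym action via a random process in `QTabularLearner.exploit` below.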


# TODO: We should try scipy sparse matrices for the Q matrices, and an optimized BLAS
# backend for numpy (OpenBLAS, or Intel MKL which is typically faster than OpenBLAS).


# %%
class LossEval:
    """Loss evaluation for a Q-learner

    qmatrix -- the Q matrix whose squared TD error is tracked at each iteration
    """

    def __init__(self, qmatrix: QMatrix):
        self.qmatrix = qmatrix
        self.this_episode = []
        self.all_episodes = []

    def new_episode(self):
        self.this_episode = []

    def end_of_iteration(self, t, done):
        self.this_episode.append(self.qmatrix.last_error)

    def current_episode_loss(self):
        return np.average(self.this_episode)

    def end_of_episode(self, i_episode, t):
        """Average out the overall loss for this episode"""
        self.all_episodes.append(self.current_episode_loss())


class ChosenActionMetadata(NamedTuple):
    """Additional information associated with the action chosen by the agent"""

    Q_source_state: int
    Q_source_expectedq: float
    Q_attack_expectedq: float
    source_node: int
    source_node_encoding: int
    abstract_action: np.int32
    Q_attack_state: int


class QTabularLearner(learner.Learner):
    """Tabular Q-learning

    Parameters
    ==========
    gamma -- discount factor

    learning_rate -- learning rate

    ep -- environment global properties

    trained -- another QTabularLearner that is already trained, whose Q matrices
        are used to initialize this agent (referenced, not copied)

    exploit_percentile -- (experimental) randomly pick actions above this percentile
        of the Q-matrix row. Setting 100 gives the argmax, as in standard Q-learning.

        The idea is that a value below 100 helps compensate for the approximation
        introduced when updating the Q-matrix by the abstraction of the action space
        (attack parameters are abstracted away in the Q-matrix, and when an abstract
        action is picked it gets specialized via a random process).
        When running in non-learning mode (learning_rate=0), setting this value too
        close to 100 may cause the agent to get stuck; being more permissive
        (e.g. in the 80-90 range) typically gives better results.
    """
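
    # For intuition on `exploit_percentile` (illustrative, not part of the original
    # module): on a Q row of [0.1, 0.9, 0.8, 0.3], percentile 100 always returns the
    # argmax (index 1), whereas percentile 50 samples uniformly among the indices whose
    # value is at or above the row's median of 0.55, here {1, 2}
    # (see `random_argtop_percentile` above).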

    def __init__(
        self,
        ep: EnvironmentBounds,
        gamma: float,
        learning_rate: float,
        exploit_percentile: float,
        trained=None,  # : Optional[QTabularLearner]
    ):
        if trained:
            self.qsource = trained.qsource
            self.qattack = trained.qattack
        else:
            self.qsource = QLearnAttackSource(ep)
            self.qattack = QLearnBestAttackAtSource(ep)

        self.loss_qsource = LossEval(self.qsource)
        self.loss_qattack = LossEval(self.qattack)
        self.gamma = gamma
        self.learning_rate = learning_rate
        self.exploit_percentile = exploit_percentile
        self.credcache_policy = CredentialCacheExploiter()
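
    # Evaluation-mode sketch (illustrative names and values): reuse the Q matrices of an
    # already trained agent and disable learning, as suggested by the class docstring
    # (learning_rate=0 with a permissive exploit_percentile):
    #
    #   eval_agent = QTabularLearner(ep, trained=trained_agent, gamma=0.0,
    #                                learning_rate=0.0, exploit_percentile=90)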

    def on_step(
        self,
        wrapped_env: w.AgentWrapper,
        observation,
        reward,
        done,
        truncated,
        info,
        action_metadata: ChosenActionMetadata,
    ):
        agent_state = wrapped_env.state

        # Update the top-level Q matrix for the state of the selected source node
        after_toplevel_state = self.qsource.state_space.encode(agent_state)
        self.qsource.update(
            action_metadata.Q_source_state,
            action_metadata.source_node_encoding,
            after_toplevel_state,
            reward,
            self.gamma,
            self.learning_rate,
        )

        # Update the second Q matrix for the abstract action chosen
        qattack_state_after = self.qattack.state_space.encode_at(agent_state, action_metadata.source_node)
        self.qattack.update(
            action_metadata.Q_attack_state,
            int(action_metadata.abstract_action),
            qattack_state_after,
            reward,
            self.gamma,
            self.learning_rate,
        )

    def end_of_iteration(self, t, done):
        self.loss_qsource.end_of_iteration(t, done)
        self.loss_qattack.end_of_iteration(t, done)

    def end_of_episode(self, i_episode, t):
        self.loss_qsource.end_of_episode(i_episode, t)
        self.loss_qattack.end_of_episode(i_episode, t)

    def loss_as_string(self):
        return f"[loss_source={self.loss_qsource.current_episode_loss():0.3f}" f" loss_attack={self.loss_qattack.current_episode_loss():0.3f}]"

    def new_episode(self):
        self.loss_qsource.new_episode()
        self.loss_qattack.new_episode()

    def exploit(self, wrapped_env: w.AgentWrapper, observation):
        agent_state = wrapped_env.state

        qsource_state = self.qsource.state_space.encode(agent_state)

        #############
        # First, attempt to exploit the credential cache
        # using the credcache_policy
        action_style, gym_action, _ = self.credcache_policy.exploit(wrapped_env, observation)
        if gym_action:
            source_node = cyberbattle_env.sourcenode_of_action(gym_action)
            return (
                action_style,
                gym_action,
                ChosenActionMetadata(
                    Q_source_state=qsource_state,
                    Q_source_expectedq=-1,
                    Q_attack_expectedq=-1,
                    source_node=source_node,
                    source_node_encoding=self.qsource.action_space.encode_at(agent_state, source_node),
                    abstract_action=np.int32(self.qattack.action_space.abstract_from_gymaction(gym_action)),
                    Q_attack_state=self.qattack.state_space.encode_at(agent_state, source_node),
                ),
            )
        #############

        # Pick action: pick a random source state among the ones with the maximum Q-value
        action_style = "exploit"
        source_node_encoding, qsource_expectedq = self.qsource.exploit(qsource_state, percentile=100)

        # Pick a source node at random (owned and with the desired feature encoding)
        potential_source_nodes = [from_node for from_node in w.owned_nodes(observation) if source_node_encoding == self.qsource.action_space.encode_at(agent_state, from_node)]

        if len(potential_source_nodes) == 0:
            logging.debug(f"No node with encoding {source_node_encoding}, fallback on explore")
            # NOTE: We should make sure that this does not happen too often;
            # the penalty should be much smaller than typical rewards: a small nudge,
            # not a new feedback signal.

            # Learn the lack of node availability
            self.qsource.update(
                qsource_state,
                source_node_encoding,
                qsource_state,
                reward=0,
                gamma=self.gamma,
                learning_rate=self.learning_rate,
            )

            return "exploit-1->explore", None, None
        else:
            source_node = np.random.choice(potential_source_nodes)

            qattack_state = self.qattack.state_space.encode_at(agent_state, source_node)

            abstract_action, qattack_expectedq = self.qattack.exploit(qattack_state, percentile=self.exploit_percentile)

            gym_action = self.qattack.action_space.specialize_to_gymaction(source_node, observation, np.int32(abstract_action))

            assert int(abstract_action) < self.qattack.action_space.flat_size(), f"abstract_action={abstract_action} gym_action={gym_action}"

            if gym_action and wrapped_env.env.is_action_valid(gym_action, observation["action_mask"]):
                logging.debug(f"  exploit gym_action={gym_action} source_node_encoding={source_node_encoding}")
                return (
                    action_style,
                    gym_action,
                    ChosenActionMetadata(
                        Q_source_state=qsource_state,
                        Q_source_expectedq=qsource_expectedq,
                        Q_attack_expectedq=qattack_expectedq,
                        source_node=source_node,
                        source_node_encoding=source_node_encoding,
                        abstract_action=np.int32(abstract_action),
                        Q_attack_state=qattack_state,
                    ),
                )
            else:
                # NOTE: We should make the penalty reward smaller than
                # the average/typical non-zero reward of the env (e.g. 1/1000 of it).
                # The idea of weighting the learning_rate when taking a chance is
                # related to "inverse propensity weighting".

                # Learn the non-validity of the action
                self.qsource.update(
                    qsource_state,
                    source_node_encoding,
                    qsource_state,
                    reward=0,
                    gamma=self.gamma,
                    learning_rate=self.learning_rate,
                )

                self.qattack.update(
                    qattack_state,
                    int(abstract_action),
                    qattack_state,
                    reward=0,
                    gamma=self.gamma,
                    learning_rate=self.learning_rate,
                )

                # Fall back on random exploration
                return (
                    ("exploit[invalid]->explore" if gym_action else "exploit[undefined]->explore"),
                    None,
                    None,
                )
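
    # Summary of the decision flow in `exploit` above: (1) try the credential-cache
    # exploiter first; (2) otherwise argmax the top-level Q matrix to pick a source-node
    # encoding, pick a matching owned node at random, then pick an abstract action above
    # `exploit_percentile` in the second-level Q matrix and specialize it to a concrete
    # gym action; (3) if no matching node exists, or the specialized action is undefined
    # or invalid, register a zero-reward update and fall back to random exploration.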

    def explore(self, wrapped_env: w.AgentWrapper):
        agent_state = wrapped_env.state
        gym_action = wrapped_env.env.sample_valid_action(kinds=[0, 1, 2])
        abstract_action = self.qattack.action_space.abstract_from_gymaction(gym_action)

        assert int(abstract_action) < self.qattack.action_space.flat_size(), f"Q_attack_action={abstract_action} gym_action={gym_action}"

        source_node = cyberbattle_env.sourcenode_of_action(gym_action)

        return (
            "explore",
            gym_action,
            ChosenActionMetadata(
                Q_source_state=self.qsource.state_space.encode(agent_state),
                Q_source_expectedq=-1,
                Q_attack_expectedq=-1,
                source_node=source_node,
                source_node_encoding=self.qsource.action_space.encode_at(agent_state, source_node),
                abstract_action=abstract_action,
                Q_attack_state=self.qattack.state_space.encode_at(agent_state, source_node),
            ),
        )

    def stateaction_as_string(self, action_metadata) -> str:
        return (
            f"Qsource[state={action_metadata.Q_source_state} err={self.qsource.last_error:0.2f} "
            f"Q={action_metadata.Q_source_expectedq:.2f}] "
            f"Qattack[state={action_metadata.Q_attack_state} err={self.qattack.last_error:0.2f} "
            f"Q={action_metadata.Q_attack_expectedq:.2f}] "
        )

    def parameters_as_string(self) -> str:
        return f"γ={self.gamma}," f"learning_rate={self.learning_rate}," f"Q%={self.exploit_percentile}"

    def all_parameters_as_string(self) -> str:
        return (
            f" dimension={self.qsource.state_space.flat_size()}x{self.qsource.action_space.flat_size()},"
            f"{self.qattack.state_space.flat_size()}x{self.qattack.action_space.flat_size()}\n"
            f"Q1={[f.name() for f in self.qsource.state_space.feature_selection]}"
            f" -> {[f.name() for f in self.qsource.action_space.feature_selection]}\n"
            f"Q2={[f.name() for f in self.qattack.state_space.feature_selection]} -> 'action'"
        )
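

# How this agent is typically driven (a sketch only: the names below come from outside
# this file and are assumed rather than guaranteed, in particular the gym environment id,
# `EnvironmentBounds.of_identifiers`, and the exact signature of
# `learner.epsilon_greedy_search`; treat this as illustrative, not canonical):
#
#   import gym
#   gym_env = gym.make("CyberBattleChain-v0", size=10)
#   ep = EnvironmentBounds.of_identifiers(
#       maximum_node_count=12, maximum_total_credentials=10, identifiers=gym_env.identifiers)
#   agent = QTabularLearner(ep, gamma=0.9, learning_rate=0.1, exploit_percentile=90)
#   learner.epsilon_greedy_search(
#       cyberbattle_gym_env=gym_env, environment_properties=ep, learner=agent,
#       episode_count=5, iteration_count=100, epsilon=0.9, render=False, title="tabular Q")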