Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
microsoft
GitHub Repository: microsoft/CyberBattleSim
Path: blob/main/cyberbattle/agents/baseline/agent_wrapper.py
597 views
1
# Copyright (c) Microsoft Corporation.
2
# Licensed under the MIT License.
3
4
"""Agent wrapper for CyberBattle environments exposing additional
5
features extracted from the environment observations"""
6
7
import enum
import logging
import zlib
from abc import abstractmethod
from typing import Optional, List, Tuple, overload

import numpy as np
from gymnasium import spaces, Wrapper
from numpy import ndarray

import cyberbattle._env.cyberbattle_env as cyberbattle_env
from cyberbattle._env.cyberbattle_env import EnvironmentBounds
16
17
18
class StateAugmentation:
    """Baseline agent state: the latest gym observation and nothing else.

    Subclasses may override `on_step` / `on_reset` to keep additional
    information derived from the interaction with the environment.
    """

    def __init__(self, observation: cyberbattle_env.Observation):
        # most recent observation handed over by the environment
        self.observation = observation

    def on_step(
        self,
        action: cyberbattle_env.Action,
        reward: float,
        truncated: bool,
        done: bool,
        observation: cyberbattle_env.Observation,
    ):
        """Called after each environment step; only the new observation is retained here."""
        self.observation = observation

    def on_reset(self, observation: cyberbattle_env.Observation):
        """Called after an environment reset with the initial observation."""
        self.observation = observation
37
38
39
# Abstract class for a feature (either global or node-specific)
class Feature(spaces.MultiDiscrete):
    """
    Feature consisting of multiple discrete dimensions.
    Parameters:
        env_properties: bounds of the environment the feature is computed for
        nvec: is a vector defining the number of possible values
            for each discrete space.
    """

    def __init__(self, env_properties: EnvironmentBounds, nvec):
        self.env_properties = env_properties
        super().__init__(nvec)

    def flat_size(self):
        """Return the number of distinct values of the feature (product of `nvec`)"""
        return np.prod(self.nvec, dtype=int)

    def name(self):
        """Return the name of the feature.

        For classes named `Feature_<name>` this is `<name>`; any other class
        name is returned unchanged. (Previously a throwaway `Feature`
        instance was constructed just to compute the prefix length, and
        the first 8 characters were stripped from every class name.)
        """
        prefix = Feature.__name__ + "_"
        class_name = type(self).__name__
        if class_name.startswith(prefix):
            return class_name[len(prefix):]
        return class_name

    def pretty_print(self, v):
        return v

    @abstractmethod
    def get(self, a: StateAugmentation, node: Optional[int] = None) -> np.ndarray:
        """Compute the current value of a feature value at
        the current observation and specific node"""
        raise NotImplementedError
68
69
70
class NodeFeature(Feature):
    """
    Feature consisting of multiple discrete dimensions evaluated at a
    specific node. Concrete subclasses implement `get_at`.
    """

    @abstractmethod
    def get_at(self, a: StateAugmentation, node: int) -> np.ndarray:
        """Compute the value of the feature for `node` at the current observation"""
        raise NotImplementedError

    def get(self, a: StateAugmentation, node: Optional[int] = None) -> np.ndarray:
        # node-level features cannot be evaluated without a node index
        assert node is not None, "feature only valid in the context of a node"
        value = self.get_at(a, node)
        return value
84
85
86
class GlobalFeature(Feature):
    """
    Feature consisting of multiple discrete dimensions at the global level.
    Concrete subclasses implement `get_global`.
    """

    @abstractmethod
    def get_global(self, a: StateAugmentation) -> np.ndarray:
        """Compute the current value of a feature value at
        the current observation"""
        raise NotImplementedError

    def get(self, a: StateAugmentation, node: Optional[int] = None) -> np.ndarray:
        """Return the global feature value; global features take no node context."""
        # The previous message ("only valid in the context of a node") stated
        # the opposite of what this assertion enforces.
        assert node is None, "global feature not valid in the context of a node"
        return self.get_global(a)
112
113
114
class Feature_active_node_properties(NodeFeature):
    """Bitmask of all properties set for the active node"""

    def __init__(self, p: EnvironmentBounds):
        # one binary dimension per property
        super().__init__(p, [2] * p.property_count)

    def get_at(self, a: StateAugmentation, node) -> ndarray:
        assert node is not None, "feature only valid in the context of a node"

        all_properties = a.observation["discovered_nodes_properties"]
        assert node < len(all_properties), f"invalid node index {node} (not discovered yet)"

        # Cells hold 0/1, or 2 for an unknown value; modulo 2 maps
        # 0->0, 1->1, 2->0 so unknown properties read as unset.
        return np.array(all_properties[node] % 2, dtype=np.int_)
132
133
134
class Feature_active_node_age(NodeFeature):
    """How recently was this node discovered?
    (measured by reverse position in the list of discovered nodes)"""

    def __init__(self, p: EnvironmentBounds):
        super().__init__(p, [p.maximum_node_count])

    def get_at(self, a: StateAugmentation, node) -> ndarray:
        assert node is not None, "feature only valid in the context of a node"

        count = a.observation["discovered_node_count"]
        assert node < count, f"invalid node index {node} (not discovered yet)"

        # the last discovered node gets age 0, the first one gets count - 1
        age = count - node - 1
        return np.array([age], dtype=np.int_)
149
150
151
class Feature_active_node_id(NodeFeature):
    """Return the node id itself"""

    def __init__(self, p: EnvironmentBounds):
        # a single dimension ranging over all possible node indices
        super().__init__(p, [p.maximum_node_count])

    def get_at(self, a: StateAugmentation, node) -> ndarray:
        node_id_vector = np.array([node], dtype=np.int_)
        return node_id_vector
159
160
161
class Feature_discovered_nodeproperties_sliding(GlobalFeature):
    """Bitmask indicating node properties seen in last few cache entries"""

    # number of most recently discovered nodes taken into account
    window_size = 3

    def __init__(self, p: EnvironmentBounds):
        super().__init__(p, [2] * p.property_count)

    def get_global(self, a: StateAugmentation) -> ndarray:
        count = a.observation["discovered_node_count"]
        properties = a.observation["discovered_nodes_properties"][:count]

        # restrict to the trailing window of discovered nodes
        recent = properties[-self.window_size :, :]

        # the unknown value 2 becomes 0, leaving only properties known to be set
        recent_known = np.int32(recent % 2)

        return (np.sum(recent_known, axis=0) > 0) * 1
183
184
185
class Feature_discovered_ports(GlobalFeature):
    """Bitmask vector indicating each port seen so far in discovered credentials"""

    def __init__(self, p: EnvironmentBounds):
        super().__init__(p, [2] * p.port_count)

    def get_global(self, a: StateAugmentation):
        cache_length = a.observation["credential_cache_length"]
        port_bitmask = np.zeros(self.env_properties.port_count, dtype=np.int32)
        if cache_length <= 0:
            return port_bitmask
        cache = np.array(a.observation["credential_cache_matrix"])[:cache_length]
        # column 1 of the credential cache holds the port index
        port_bitmask[np.int32(cache[:, 1])] = 1
        return port_bitmask
198
199
200
class Feature_discovered_ports_sliding(GlobalFeature):
    """Bitmask indicating port seen in last few cache entries"""

    # number of trailing credential cache entries considered
    window_size = 3

    def __init__(self, p: EnvironmentBounds):
        super().__init__(p, [2] * p.port_count)

    def get_global(self, a: StateAugmentation):
        port_bitmask = np.zeros(self.env_properties.port_count, dtype=np.int32)
        cache_length = a.observation["credential_cache_length"]
        if cache_length <= 0:
            return port_bitmask
        cache = np.array(a.observation["credential_cache_matrix"])[:cache_length]
        recent_ports = cache[-self.window_size :, 1]
        port_bitmask[np.int32(recent_ports)] = 1
        return port_bitmask
215
216
217
class Feature_discovered_ports_counts(GlobalFeature):
    """Count of each port seen so far in discovered credentials"""

    def __init__(self, p: EnvironmentBounds):
        # each count is bounded by the total number of credentials
        super().__init__(p, [p.maximum_total_credentials + 1] * p.port_count)

    def get_global(self, a: StateAugmentation):
        """Return, for every port index, how many cached credentials target it."""
        n = a.observation["credential_cache_length"]
        if n > 0:
            ccm = np.array(a.observation["credential_cache_matrix"])[:n]
            # column 1 of the credential cache holds the port index
            ports = np.int32(ccm[:, 1])
        else:
            # np.bincount only accepts integer input: an empty *float* array
            # (the former np.zeros(0)) raises TypeError on an empty cache.
            ports = np.zeros(0, dtype=np.int32)
        return np.bincount(ports, minlength=self.env_properties.port_count)
231
232
233
class Feature_discovered_credential_count(GlobalFeature):
    """number of credentials discovered so far"""

    def __init__(self, p: EnvironmentBounds):
        # values range over 0..maximum_total_credentials inclusive
        super().__init__(p, [p.maximum_total_credentials + 1])

    def get_global(self, a: StateAugmentation):
        return np.array([a.observation["credential_cache_length"]], dtype=np.int_)
242
243
244
class Feature_discovered_node_count(GlobalFeature):
    """number of nodes discovered so far"""

    def __init__(self, p: EnvironmentBounds):
        # values range over 0..maximum_node_count inclusive
        super().__init__(p, [p.maximum_node_count + 1])

    def get_global(self, a: StateAugmentation):
        discovered = a.observation["discovered_node_count"]
        return np.array([discovered], dtype=np.int_)
252
253
254
class Feature_discovered_notowned_node_count(GlobalFeature):
    """number of nodes discovered that are not owned yet (optionally clipped)

    Parameters:
        p: environment bounds
        clip: upper bound on the reported count; None means p.maximum_node_count
    """

    def __init__(self, p: EnvironmentBounds, clip: Optional[int] = None):
        # `is None` rather than `or`, so that an explicit clip of 0 is honoured
        self.clip = np.int32(p.maximum_node_count if clip is None else clip)
        super().__init__(p, [self.clip + 1])

    def get_global(self, a: StateAugmentation):
        discovered = a.observation["discovered_node_count"]
        if discovered == 0:
            # nothing discovered yet; also avoids np.all(axis=1) on a 1-D empty slice
            return np.array([0], dtype=np.int32)
        node_props = np.array(a.observation["discovered_nodes_properties"][:discovered])
        # here we assume that a node is owned just if all its properties are known
        # (i.e. none of them holds the "unknown" value 2)
        owned = np.count_nonzero(np.all(node_props != 2, axis=1))
        diff = np.int32(discovered - owned)
        return np.array([min(diff, self.clip)], dtype=np.int32)
268
269
270
class Feature_owned_node_count(GlobalFeature):
    """number of owned nodes so far"""

    def __init__(self, p: EnvironmentBounds):
        super().__init__(p, [p.maximum_node_count + 1])

    def get_global(self, a: StateAugmentation):
        # a node counts as owned when its privilege level is strictly positive
        owned_count = np.count_nonzero(a.observation["nodes_privilegelevel"] > 0)
        return np.array([owned_count], dtype=np.int_)
280
281
282
class ConcatFeatures(Feature):
    """Concatenate a list of features into a single feature
    Parameters:
        feature_selection - a selection of features to combine
    """

    def __init__(
        self,
        p: EnvironmentBounds,
        feature_selection: List[Feature],
    ):
        self.feature_selection = feature_selection
        # per-dimension cardinalities of all selected features, in order
        self.dim_sizes = np.concatenate([f.nvec for f in feature_selection])
        # Pass the flat vector (not `[self.dim_sizes]`, which produced a 2-D
        # nvec of shape (1, k)) so the space shape matches the 1-D vector
        # returned by `get` and can itself be concatenated into other features.
        super().__init__(p, self.dim_sizes)

    def pretty_print(self, v):
        return v

    def get(self, a: StateAugmentation, node: Optional[int] = None) -> np.ndarray:
        """Return the feature vector: the selected features' values concatenated"""
        return np.concatenate([f.get(a, node) for f in self.feature_selection])
305
306
307
class FeatureEncoder(Feature):
    """Encode a list of features as a unique index"""

    # features whose values are concatenated before encoding (set by subclasses)
    feature_selection: List[Feature]

    def vector_to_index(self, feature_vector: np.ndarray) -> int:
        """Map a concatenated feature vector to an index; provided by subclasses"""
        raise NotImplementedError

    def feature_vector_of_observation_at(self, a: StateAugmentation, node: Optional[int]) -> np.ndarray:
        """Return the current feature vector"""
        values = [feature.get(a, node) for feature in self.feature_selection]
        return np.concatenate(values)

    def feature_vector_of_observation(self, a: StateAugmentation):
        """Return the current feature vector without node context"""
        return self.feature_vector_of_observation_at(a, None)

    def encode(self, a: StateAugmentation, node=None) -> int:
        """Return the index encoding of the feature"""
        return self.vector_to_index(self.feature_vector_of_observation_at(a, node))

    def encode_at(self, a: StateAugmentation, node: int) -> int:
        """Return the current feature vector encoding with a node context"""
        return self.vector_to_index(self.feature_vector_of_observation_at(a, node))

    def name(self):
        """Return a name for the feature encoding"""
        joined = ", ".join(feature.name() for feature in self.feature_selection)
        return "[" + joined + "]"
338
339
340
class HashEncoding(FeatureEncoder):
    """Feature defined as a hash of another feature
    Parameters:
        feature_selection: a selection of features to combine
        hash_size: number of buckets; feature vectors are hashed into [0, hash_size)
    """

    def __init__(
        self,
        p: EnvironmentBounds,
        feature_selection: List[Feature],
        hash_size: int,
    ):
        self.feature_selection = feature_selection
        self.hash_size = hash_size
        super().__init__(p, [hash_size])

    def flat_size(self):
        return self.hash_size

    def vector_to_index(self, feature_vector) -> int:
        """Hash the state vector into [0, hash_size).

        Uses CRC32 over the vector's element values. The former
        `hash(str(feature_vector))` was salted per process (PYTHONHASHSEED),
        so indices were not reproducible across runs. Also, `str` of a
        large numpy array is abbreviated with "...", so distinct long
        vectors could collide.
        """
        values = np.asarray(feature_vector).ravel().tolist()
        return zlib.crc32(repr(values).encode("ascii")) % self.hash_size

    def pretty_print(self, v):
        return f"#{v}"
366
367
368
class RavelEncoding(FeatureEncoder):
    """Combine a set of features into a single feature with a unique index
    (calculated by raveling the original indices)
    Parameters:
        feature_selection - a selection of features to combine
    """

    def __init__(
        self,
        p: EnvironmentBounds,
        feature_selection: List[Feature],
    ):
        self.feature_selection = feature_selection
        # cardinality of every dimension of the concatenated feature vector
        self.dim_sizes = np.concatenate([f.nvec for f in feature_selection])
        # total number of distinct encoded values
        self.ravelled_size: np.int64 = np.prod(self.dim_sizes)
        assert np.shape(self.ravelled_size) == (), f"! {np.shape(self.ravelled_size)}"
        super().__init__(p, [self.ravelled_size])

    def vector_to_index(self, feature_vector) -> int:
        """Return the flat index of `feature_vector` within the grid `dim_sizes`
        (np.ravel_multi_index with its default C/row-major order)."""
        assert len(self.dim_sizes) == len(feature_vector), f"feature vector of size {len(feature_vector)}, " f"expecting {len(self.dim_sizes)}: {feature_vector} -- {self.dim_sizes}"
        index_intp = np.ravel_multi_index(list(feature_vector), list(self.dim_sizes))
        index = index_intp.item()
        # message previously ended with an unbalanced ")"
        assert index < self.ravelled_size, f"feature vector out of bound ({feature_vector}, dim={self.dim_sizes}) " f"-> index={index}, max_index={self.ravelled_size - 1}"
        return index

    def unravel_index(self, index) -> Tuple:
        """Inverse of `vector_to_index`: per-dimension indices of a flat index"""
        return np.unravel_index(index, self.dim_sizes)

    def pretty_print(self, v):
        return self.unravel_index(v)
398
399
400
def owned_nodes(observation):
    """Return the indices of nodes whose privilege level is non-zero (owned nodes)"""
    privilege_levels = observation["nodes_privilegelevel"]
    nonzero_positions = np.nonzero(privilege_levels)
    return nonzero_positions[0]
403
404
405
def discovered_nodes_notowned(observation):
    """Return the list of discovered nodes that are not owned yet.

    Only the first `discovered_node_count` entries of `nodes_privilegelevel`
    are considered. Without this restriction, any trailing entries beyond the
    discovered count (e.g. padding) would be reported as discovered, not-owned
    nodes.
    """
    discovered_count = observation["discovered_node_count"]
    levels = np.asarray(observation["nodes_privilegelevel"])[:discovered_count]
    return np.nonzero(levels == 0)[0]
408
409
410
class AbstractAction(Feature):
    """An abstraction of the gym action space that reduces
    the space dimension for learning use to just
    - local_attack(vulnid) (source_node provided)
    - remote_attack(vulnid) (source_node provided, target_node forgotten)
    - connect(port) (source_node provided, target_node forgotten, credentials inferred from cache)

    Abstract action indices are laid out as: local attacks first
    [0, n_local_actions), then remote attacks, then one connect action per port.
    """

    def __init__(self, p: EnvironmentBounds):
        self.n_local_actions = p.local_attacks_count
        self.n_remote_actions = p.remote_attacks_count
        self.n_connect_actions = p.port_count
        self.n_actions = self.n_local_actions + self.n_remote_actions + self.n_connect_actions
        super().__init__(p, [self.n_actions])

    def specialize_to_gymaction(self, source_node: np.int32, observation, abstract_action_index: np.int32) -> Optional[cyberbattle_env.Action]:
        """Specialize an abstract "q"-action into a gym action.

        Arguments that the abstract action forgets (remote target node,
        connect credential) are picked at random among candidates taken
        from the observation.

        Returns:
            the gym action dict, or None when no valid gym action can be
            built (remote attack with no other discovered node, or connect
            with no usable credential in the cache).
        """

        abstract_action_index_int = int(abstract_action_index)

        discovered_nodes_count = observation["discovered_node_count"]

        if abstract_action_index_int < self.n_local_actions:
            vuln = abstract_action_index_int
            return {"local_vulnerability": np.array([source_node, vuln])}

        abstract_action_index_int -= self.n_local_actions
        if abstract_action_index_int < self.n_remote_actions:
            vuln = abstract_action_index_int

            if discovered_nodes_count <= 1:
                # no discovered node other than the source to target
                return None

            # NOTE: We can do better here than random pick: ultimately this
            # should be learnt from target node properties

            # Pick any node from the discovered ones excluding the source node
            # itself: offsetting by 1 + k with k in [0, count - 2] modulo count
            # never lands back on source_node (assuming source_node < count).
            target = (source_node + 1 + np.random.choice(discovered_nodes_count - 1)) % discovered_nodes_count

            return {"remote_vulnerability": np.array([source_node, target, vuln])}

        abstract_action_index_int -= self.n_remote_actions
        port = np.int32(abstract_action_index_int)

        n_discovered_creds = observation["credential_cache_length"]
        if n_discovered_creds <= 0:
            # no credential available in the cache: cannot produce a valid connect action
            return None
        # credential cache rows: column 0 is the target node, column 1 the port
        discovered_credentials = np.array(observation["credential_cache_matrix"])[:n_discovered_creds]

        nodes_not_owned = discovered_nodes_notowned(observation)

        # Pick a matching cred from the discovered_cred matrix
        # (at random if more than one exist for this target port)
        match_port = discovered_credentials[:, 1] == port
        match_port_indices = np.where(match_port)[0]

        credential_indices_choices = [c for c in match_port_indices if discovered_credentials[c, 0] in nodes_not_owned]

        if credential_indices_choices:
            logging.debug("found matching cred in the credential cache")
        else:
            logging.debug("no cred matching requested port, trying instead creds used to access other ports")
            # fall back to any credential whose target node is not owned yet
            credential_indices_choices = [i for (i, n) in enumerate(discovered_credentials[:, 0]) if n in nodes_not_owned]

            if credential_indices_choices:
                logging.debug("found cred in the credential cache without matching port name")
            else:
                logging.debug("no cred to use from the credential cache")
                return None

        cred = np.int32(np.random.choice(credential_indices_choices))
        target = np.int32(discovered_credentials[cred, 0])
        return {"connect": np.array([source_node, target, port, cred], dtype=np.int32)}

    def abstract_from_gymaction(self, gym_action: cyberbattle_env.Action) -> np.int32:
        """Abstract a gym action into an action to be index in the Q-matrix.

        Every branch now returns np.int32 as annotated (the local and remote
        branches previously returned whatever integer type the action array held).
        """
        if "local_vulnerability" in gym_action:
            # [source_node, vuln]
            return np.int32(gym_action["local_vulnerability"][1])
        elif "remote_vulnerability" in gym_action:
            # [source_node, target_node, vuln]
            r = gym_action["remote_vulnerability"]
            return np.int32(self.n_local_actions + r[2])

        assert "connect" in gym_action
        # [source_node, target_node, port, credential]
        c = gym_action["connect"]

        a = self.n_local_actions + self.n_remote_actions + c[2]
        assert a < self.n_actions
        return np.int32(a)
502
503
504
class ActionTrackingStateAugmentation(StateAugmentation):
    """An agent state augmentation consisting of
    the environment observation augmented with the following dynamic information:
       - success_action_count: count of action taken and succeeded at the current node
       - failed_action_count: count of action taken and failed at the current node
    Both counters are indexed by [source node, abstract action index].
    """

    def __init__(self, p: EnvironmentBounds, observation: cyberbattle_env.Observation):
        self.aa = AbstractAction(p)
        self.env_properties = p
        self._reset_action_counts()
        super().__init__(observation)

    def _reset_action_counts(self):
        """Allocate zeroed success/failure counters of shape (max nodes, abstract actions)."""
        p = self.env_properties
        shape = (p.maximum_node_count, self.aa.n_actions)
        self.success_action_count = np.zeros(shape=shape, dtype=np.int32)
        self.failed_action_count = np.zeros(shape=shape, dtype=np.int32)

    def on_step(
        self,
        action: cyberbattle_env.Action,
        reward: float,
        truncated: bool,
        done: bool,
        observation: cyberbattle_env.Observation,
    ):
        node = cyberbattle_env.sourcenode_of_action(action)
        abstract_action = self.aa.abstract_from_gymaction(action)
        # an action counts as successful only if it yields a strictly positive reward
        if reward > 0:
            self.success_action_count[node, abstract_action] += 1
        else:
            self.failed_action_count[node, abstract_action] += 1
        # Forward in the order StateAugmentation.on_step declares
        # (truncated, done); the arguments were previously swapped.
        super().on_step(action, reward, truncated, done, observation)

    def on_reset(self, observation: cyberbattle_env.Observation):
        self._reset_action_counts()
        super().on_reset(observation)
539
540
541
class Feature_actions_tried_at_node(NodeFeature):
    """A bit mask indicating which actions were already tried
    at the current node: 0 not tried, 1 tried"""

    def __init__(self, p: EnvironmentBounds):
        action_count = AbstractAction(p).n_actions
        super().__init__(p, [2] * action_count)

    @overload
    def get_at(self, a: ActionTrackingStateAugmentation, node: int): ...

    @overload
    def get_at(self, a: StateAugmentation, node: int): ...

    def get_at(self, a: StateAugmentation, node: int):
        assert node is not None, "feature only valid in the context of a node"
        assert isinstance(a, ActionTrackingStateAugmentation), "invalid state type"
        # an action was tried if it either failed or succeeded at least once
        attempts = a.failed_action_count[node, :] + a.success_action_count[node, :]
        return np.array((attempts != 0) * 1, dtype=np.int_)
561
562
563
class Feature_success_actions_at_node(NodeFeature):
    """number of time each action succeeded at a given node"""

    # number of possible values per dimension; counts saturate at max_action_count - 1
    max_action_count = 100

    def __init__(self, p: EnvironmentBounds):
        action_count = AbstractAction(p).n_actions
        super().__init__(p, [self.max_action_count] * action_count)

    @overload
    def get_at(self, a: ActionTrackingStateAugmentation, node: int): ...

    @overload
    def get_at(self, a: StateAugmentation, node: int): ...

    def get_at(self, a: StateAugmentation, node: int):
        assert isinstance(a, ActionTrackingStateAugmentation), "invalid state type"
        ceiling = self.max_action_count - 1
        successes = a.success_action_count[node, :]
        return np.minimum(successes, ceiling)
581
582
583
class Feature_failed_actions_at_node(NodeFeature):
    """number of time each action failed at a given node"""

    # number of possible values per dimension; counts saturate at max_action_count - 1
    max_action_count = 100

    def __init__(self, p: EnvironmentBounds):
        action_count = AbstractAction(p).n_actions
        super().__init__(p, [self.max_action_count] * action_count)

    def get_at(self, a: StateAugmentation, node: int):
        assert isinstance(a, ActionTrackingStateAugmentation), "invalid state type"
        ceiling = self.max_action_count - 1
        failures = a.failed_action_count[node, :]
        return np.minimum(failures, ceiling)
594
595
596
class Verbosity(enum.Enum):
    """Verbosity level of the learning function (Quiet < Normal < Verbose by value)"""

    Quiet = 0
    Normal = 1
    Verbose = 2
602
603
604
class AgentWrapper(Wrapper):
    """Gym wrapper to update the agent state on every step"""

    def __init__(self, env: cyberbattle_env.CyberBattleEnv, state: StateAugmentation):
        super().__init__(env)
        self.env = env
        # agent state notified after every step and reset
        self.state = state

    def step(self, action: cyberbattle_env.Action):  # type: ignore
        # gymnasium step returns (observation, reward, terminated, truncated, info);
        # `done` holds the `terminated` flag
        observation, reward, done, truncated, info = self.env.step(action)
        # StateAugmentation.on_step declares (action, reward, truncated, done, observation);
        # done and truncated were previously passed in swapped positions.
        self.state.on_step(action, reward, truncated, done, observation)
        return observation, reward, done, truncated, info

    def reset(self, **kwargs):
        observation, info = self.env.reset(**kwargs)
        self.state.on_reset(observation)
        return observation, info
621
622