Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
microsoft
GitHub Repository: microsoft/CyberBattleSim
Path: blob/main/cyberbattle/_env/cyberbattle_env.py
597 views
1
# Copyright (c) Microsoft Corporation.
2
# Licensed under the MIT License.
3
4
"""Anatares OpenGym Environment"""
5
6
import time
7
import copy
8
import logging
9
import networkx
10
from networkx import convert_matrix
11
from typing import NamedTuple, Optional, Tuple, List, Dict, TypeVar, TypedDict, cast
12
13
from gymnasium import spaces, Env
14
from gymnasium.utils import seeding
15
16
import numpy
17
18
from plotly.graph_objects import Scatter # type: ignore
19
from plotly.subplots import make_subplots # type: ignore
20
21
from cyberbattle._env.defender import DefenderAgent
22
from cyberbattle.simulation.model import PortName, PrivilegeLevel
23
from ..simulation import commandcontrol, model, actions
24
from .discriminatedunion import DiscriminatedUnion
25
import numpy as np
26
27
LOGGER = logging.getLogger(__name__)

# Number of extra values to allocate in a discrete space so that a field can
# also encode 'Not Applicable'. The NA value itself is 0 by convention, which
# is why spaces using it are sized `NA + <number of regular values>`.
NA = 1

# Value defining an unused space slot
UNUSED_SLOT = numpy.int32(0)
# Value defining a used space slot
USED_SLOT = numpy.int32(1)
37
38
39
# The type of a sample from the Action space.
# A sample is expected to carry exactly one of the keys below, hence `total=False`.
class Action(TypedDict, total=False):
    # (source_node, vulnerability_index)
    local_vulnerability: numpy.ndarray
    # adding the generic type causes runtime
    # TypeError `'type' object is not subscriptable'`
    # (source_node, target_node, vulnerability_index)
    remote_vulnerability: numpy.ndarray
    # (source_node, target_node, port_index, credential_cache_index)
    connect: numpy.ndarray
51
52
# Type of a sample from the ActionMask space: one bitmask array per action kind
class ActionMask(TypedDict):
    local_vulnerability: numpy.ndarray
    remote_vulnerability: numpy.ndarray
    connect: numpy.ndarray
61
62
# Type of a sample from the Observation space
class Observation(TypedDict):
    # ---------------------------------------------------------
    # Outcome of the action just executed
    # ---------------------------------------------------------
    # number of new nodes discovered
    newly_discovered_nodes_count: numpy.int32
    # whether a lateral move was just performed
    lateral_move: numpy.int32
    # whether customer data were just discovered
    customer_data_found: numpy.int32
    # 0 if there was no probing attempt
    # 1 if an attempted probing failed
    # 2 if an attempted probing succeeded
    probe_result: numpy.int32
    # whether an escalation was completed and to which level
    escalation: numpy.int32
    # credentials that were just discovered after executing an action
    leaked_credentials: Tuple[numpy.ndarray, ...]  # type: ignore
    # bitmask indicating which actions are valid in the current state
    action_mask: ActionMask
    # ---------------------------------------------------------
    # State information aggregated over all actions executed so far
    # ---------------------------------------------------------
    # size of the credential stack (number of tuples in `credential_cache_matrix` that are not zeros)
    credential_cache_length: int
    # total nodes discovered so far
    discovered_node_count: int
    # Matrix of properties for all the discovered nodes
    discovered_nodes_properties: numpy.ndarray
    # Node privilege level on every discovered node (e.g., 0 if not owned, 1 owned, 2 admin, 3 for system)
    nodes_privilegelevel: numpy.ndarray
    # Tuple encoding of the credential cache matrix.
    # It consists of `bounds.maximum_total_credentials` tuples
    # of numpy arrays of shape (2)
    # where only the first `credential_cache_length` tuples are populated.
    #
    # Each tuple represents a discovered credential;
    # the credential index is given by its tuple index (i.e., order of discovery).
    # Each tuple is of the form: (target_node_discover_index, port_index)
    credential_cache_matrix: Tuple[numpy.ndarray, ...]
    # ---------------------------------------------------------
    # Raw information fields coming from the simulation environment
    # that are not encoded as gym spaces (were previously in the 'info' field)
    # ---------------------------------------------------------
    # Mapping node index to internal IDs of all nodes discovered so far.
    # The external node index used by the agent to refer to a node
    # is defined as the index of the node in this array
    _discovered_nodes: List[model.NodeID]
    # The subgraph of nodes discovered so far with annotated edges
    # representing interactions that took place during the simulation. (See
    # actions.EdgeAnnotation)
    _explored_network: networkx.DiGraph
119
120
121
# Information returned to gym by the step function
class StepInfo(TypedDict):
    description: str
    duration_in_ms: float
    step_count: int
    network_availability: float
    # internal IDs of the credentials in the cache
    credential_cache: List[model.CachedCredential]
133
134
135
class OutOfBoundIndexError(Exception):
    """Raised when the agent refers to an entity (a node or a vulnerability) through an invalid index."""
137
138
139
Key = TypeVar("Key")
Value = TypeVar("Value")


def inverse_dict(self: Dict[Key, Value]) -> Dict[Value, Key]:
    """Return a new dictionary mapping each value of `self` back to its key.

    When several keys share a value, the key seen last in iteration order wins.
    """
    inverted: Dict[Value, Key] = {}
    for key, value in self.items():
        inverted[value] = key
    return inverted
146
147
148
class DummySpace(spaces.Space):
    """Placeholder space holding a fixed sample.

    Values of a `gymnasium.spaces.Dict` space must derive from `gymnasium.Space`;
    this class lets raw (non-encoded) Python objects be declared as observation
    fields. `contains` accepts any value and `sample` always returns the object
    given at construction.
    """

    def __init__(self, sample: object):
        # Initialise the gymnasium base class so that its internal attributes
        # (shape, dtype, random generator) exist; without this call, accessing
        # e.g. `.shape` or `.np_random` on this space raises AttributeError.
        super().__init__()
        self._sample = sample

    def contains(self, x: object) -> bool:
        return True

    def sample(self, mask=None, probability=None) -> object:
        # `mask` and `probability` are accepted for compatibility with the
        # `Space.sample` signature of recent gymnasium versions, and ignored.
        return self._sample
159
160
161
def sourcenode_of_action(x: "Action") -> int:
    """Return the index of the node from which the action `x` is launched.

    Whatever the action kind, the source node is the first component of the
    action vector. `connect` is assumed when neither vulnerability key is present.
    """
    if "local_vulnerability" in x:
        vector = x["local_vulnerability"]
    elif "remote_vulnerability" in x:
        vector = x["remote_vulnerability"]
    else:
        assert "connect" in x
        vector = x["connect"]
    return vector[0]
170
171
172
class EnvironmentBounds(NamedTuple):
    """Define global bounds possibly shared by a set of CyberBattle gym environments

    maximum_node_count - Maximum number of nodes in a given network
    maximum_total_credentials - Maximum number of credentials in a given network
    maximum_discoverable_credentials_per_action - Maximum number of credentials
        that can be returned at a time by any action

    port_count - Unique protocol ports
    property_count - Unique node property names
    local_attacks_count - Unique local vulnerabilities
    remote_attacks_count - Unique remote vulnerabilities
    """

    maximum_total_credentials: np.int32
    maximum_node_count: np.int32
    maximum_discoverable_credentials_per_action: np.int32

    port_count: np.int32
    property_count: np.int32
    local_attacks_count: np.int32
    remote_attacks_count: np.int32

    @classmethod
    def of_identifiers(
        cls,
        identifiers: "model.Identifiers",
        maximum_total_credentials: int,
        maximum_node_count: int,
        maximum_discoverable_credentials_per_action: Optional[int] = None,
    ):
        """Build the bounds of an environment from its identifiers.

        Args:
            identifiers: identifiers whose `ports`, `properties`, `local_vulnerabilities`
                and `remote_vulnerabilities` lengths give the corresponding counts.
            maximum_total_credentials: positive bound on the credentials in a network.
            maximum_node_count: positive bound on the nodes in a network.
            maximum_discoverable_credentials_per_action: positive bound on the credentials
                returned by a single action; a falsy value (None or 0) falls back to
                `maximum_total_credentials`.

        Returns:
            An `EnvironmentBounds` whose fields are all `np.int32`.

        Raises:
            AssertionError: if a value is not an integer fitting in a signed 32-bit
                integer, or if a maximum is not positive.
        """
        int32_info = np.iinfo(np.int32)

        def fits_int32(value) -> bool:
            # Explicit range check. `np.can_cast` cannot be used for this: since
            # NumPy 2.0 it rejects Python ints (TypeError) and no longer inspects
            # the value of NumPy scalars.
            return isinstance(value, (int, np.integer)) and int32_info.min <= value <= int32_info.max

        maximum_discoverable_credentials_per_action = maximum_discoverable_credentials_per_action or maximum_total_credentials

        assert fits_int32(maximum_total_credentials), "maximum_total_credentials must be a 32-bit integer"
        assert fits_int32(maximum_node_count), "maximum_node_count must be a 32-bit integer"
        assert maximum_total_credentials > 0, "maximum_total_credentials must be positive"
        assert maximum_node_count > 0, "maximum_node_count must be positive"
        assert fits_int32(len(identifiers.ports)), "port_count must be a 32-bit integer"
        assert fits_int32(len(identifiers.properties)), "property_count must be a 32-bit integer"
        assert fits_int32(len(identifiers.local_vulnerabilities)), "local_attacks_count must be a 32-bit integer"
        assert fits_int32(len(identifiers.remote_vulnerabilities)), "remote_attacks_count must be a 32-bit integer"
        assert fits_int32(maximum_discoverable_credentials_per_action), "maximum_discoverable_credentials_per_action must be a 32-bit integer"
        assert maximum_discoverable_credentials_per_action > 0, "maximum_discoverable_credentials_per_action must be positive"

        return EnvironmentBounds(
            maximum_total_credentials=np.int32(maximum_total_credentials),
            maximum_node_count=np.int32(maximum_node_count),
            maximum_discoverable_credentials_per_action=np.int32(maximum_discoverable_credentials_per_action),
            port_count=np.int32(len(identifiers.ports)),
            property_count=np.int32(len(identifiers.properties)),
            local_attacks_count=np.int32(len(identifiers.local_vulnerabilities)),
            remote_attacks_count=np.int32(len(identifiers.remote_vulnerabilities)),
        )
225
226
227
class AttackerGoal(NamedTuple):
    """Define conditions to be simultaneously met for the attacker to win.

    If field values are not specified the default is to target full ownership
    of the network nodes.
    """

    # Include goal to reach at least the specified cumulative total reward
    reward: float = 0.0
    # Include goal to bring the availability lower than the specified SLA value
    low_availability: float = 1.0
    # Include goal to own at least the specified number of nodes.
    own_atleast: int = 0
    # Include goal to own at least the specified percentage of the network nodes.
    # Set to 1.0 to define goal as the ownership of all network nodes.
    own_atleast_percent: float = 1.0
242
243
244
class DefenderGoal(NamedTuple):
    """Define conditions to be simultaneously met for the defender to win."""

    # Met if attacker is evicted from all the network nodes
    eviction: bool
249
250
251
class DefenderConstraint(NamedTuple):
    """Constraints that the defender must uphold at every moment of the simulation."""

    # Service level (SLA) the defender must maintain
    maintain_sla: float
255
256
257
class ObservationSpaceType(spaces.Dict):
    """Gym space describing the `Observation` dictionary, sized from the environment bounds."""

    def __init__(self, bounds: EnvironmentBounds):
        # Number of distinct privilege levels (same value as `CyberBattleEnv.privilege_levels`).
        # Computed locally so that this space does not depend on a class defined later in the module.
        privilege_levels = model.PrivilegeLevel.MAXIMUM + 1

        # A `Discrete(n)` space holds the values 0..n-1, but these counts range over
        # 0..maximum inclusive (the cache can be full, every node can be discovered),
        # hence the `+ 1`.
        credential_count_values = int(bounds.maximum_total_credentials) + 1
        node_count_values = int(bounds.maximum_node_count) + 1

        super().__init__(
            {
                # how many new nodes were discovered
                "newly_discovered_nodes_count": spaces.Discrete(NA + bounds.maximum_node_count),
                # successfully moved to the target node (1) or not (0)
                "lateral_move": spaces.Discrete(2),
                # boolean: 1 if customer secret data were discovered, 0 otherwise
                "customer_data_found": spaces.Discrete(2),
                # whether an attempted probing succeeded or not
                "probe_result": spaces.Discrete(3),
                # Escalation result
                "escalation": spaces.Discrete(privilege_levels),
                # Array of slots describing credentials that were leaked
                "leaked_credentials": spaces.Tuple(
                    # the 1st component indicates if the slot is used or not (USED_SLOT or UNUSED_SLOT)
                    # the 2nd component gives the credential unique index (external identifier exposed to the agent)
                    # the 3rd component identifies the target node (bounded by `maximum_node_count`)
                    # the 4th component identifies the port (bounded by `port_count`)
                    #
                    # The actual credential secret is not returned by the environment.
                    # To use the credential as a parameter to another action the agent should refer to it by its index
                    # e.g. (UNUSED_SLOT,_,_,_) encodes an empty slot
                    # (USED_SLOT,1,56,2) encodes a leaked credential identified by its index 1,
                    # targeting node 56 on port 2
                    [
                        spaces.MultiDiscrete(
                            np.array(
                                [
                                    NA + 1,
                                    bounds.maximum_total_credentials,
                                    bounds.maximum_node_count,
                                    bounds.port_count,
                                ],
                                dtype=np.int32,
                            )
                        )
                    ]
                    * bounds.maximum_discoverable_credentials_per_action
                ),
                # Boolean bitmasks defining the subset of valid actions in the current state.
                # (1 for valid, 0 for invalid). Note: a valid action is not necessarily guaranteed to succeed.
                # For instance it is a valid action to attempt to connect to a remote node with incorrect credentials,
                # even though such action would 'fail' and potentially yield a negative reward.
                "action_mask": spaces.Dict(
                    {
                        "local_vulnerability": spaces.MultiBinary(np.array([bounds.maximum_node_count, bounds.local_attacks_count])),
                        "remote_vulnerability": spaces.MultiBinary(np.array([bounds.maximum_node_count, bounds.maximum_node_count, bounds.remote_attacks_count])),
                        "connect": spaces.MultiBinary(
                            np.array(
                                [bounds.maximum_node_count, bounds.maximum_node_count, bounds.port_count, bounds.maximum_total_credentials],
                                dtype=np.int32,
                            )
                        ),
                    }
                ),
                # size of the credential stack (0..maximum_total_credentials)
                "credential_cache_length": spaces.Discrete(credential_count_values),
                # total nodes discovered so far (0..maximum_node_count)
                "discovered_node_count": spaces.Discrete(node_count_values),
                # Matrix of properties for all the discovered nodes
                # 3 values for each matrix cell: set, unset, unknown
                "discovered_nodes_properties": spaces.MultiDiscrete(np.full(shape=(bounds.maximum_node_count, bounds.property_count), fill_value=3)),
                # Privilege level on every discovered node (e.g., 0 if not owned, 1 owned, 2 admin, 3 for system)
                "nodes_privilegelevel": spaces.MultiDiscrete([privilege_levels] * bounds.maximum_node_count),
                # Encoding of the credential cache: a tuple of `maximum_total_credentials` slots,
                # of which only the first `credential_cache_length` are populated.
                #
                # Each slot represents a discovered credential,
                # the credential index is given by the slot index (i.e. order of discovery)
                # A slot is of the form: (target_node_discover_index, port_index)
                "credential_cache_matrix": spaces.Tuple(
                    [spaces.MultiDiscrete(np.array([bounds.maximum_node_count, bounds.port_count], dtype=np.int32))] * bounds.maximum_total_credentials
                ),
                # ---------------------------------------------------------
                # Fields that were previously in the 'info' dict:
                # ---------------------------------------------------------
                # internal IDs of nodes discovered so far
                "_discovered_nodes": DummySpace(sample=["node1", "node0", "node2"]),
                # The subgraph of nodes discovered so far with annotated edges
                # representing interactions that took place during the simulation. (See
                # actions.EdgeAnnotation)
                "_explored_network": DummySpace(sample=networkx.DiGraph()),
            }
        )
332
333
334
class CyberBattleSpaceKind(Env[Observation, Action]):
    """Gym environment base that narrows the declared types of its action and observation spaces."""

    # Union of the per-kind action spaces (local/remote vulnerability, connect)
    action_space: DiscriminatedUnion  # type: ignore
    observation_space: ObservationSpaceType  # type: ignore
337
338
339
class CyberBattleEnv(CyberBattleSpaceKind):
340
"""OpenAI Gym environment interface to the CyberBattle simulation.
341
342
# Actions
343
344
Run a local attack: `(source_node x local_vulnerability_to_exploit)`
345
Run a remote attack command: `(source_node x target_node x remote_vulnerability_to_exploit)`
346
Connect to a remote node: `(source_node x target_node x target_port x credential_index_from_cache)`
347
348
# Observation
349
350
See type `Observation` for a full description of the observation space.
351
It includes:
352
- How many new nodes were discovered
353
- Whether lateral move succeeded
354
- Whether customer data were found
355
- Whehter escalation attempt succeeded
356
- Matrix of all node properties discovered so far
357
- List of leaked credentials
358
359
# Information
360
- Action mask indicating the subset of valid actions at the current state
361
362
# Termination
363
364
The simulation ends if either the attacker reaches its goal (e.g. full network ownership),
365
the defender reaches its goal (e.g. full eviction of the attacker)
366
or if one of the defender's constraints is not met (e.g. SLA).
367
"""
368
369
metadata = {"render_modes": ["human"]}
370
371
@property
def environment(self) -> model.Environment:
    """Simulation environment of the current episode (a fresh copy is made on every reset)."""
    current: model.Environment = self.__environment
    return current
374
375
def __reset_environment(self) -> None:
    """Rebuild all per-episode state from a deep copy of the initial environment."""
    # Deep copy so that episodes never alter the initial environment object
    self.__environment: model.Environment = copy.deepcopy(self.__initial_environment)
    environment = self.__environment

    self.__discovered_nodes: List[model.NodeID] = []
    self.__owned_nodes_indices_cache: Optional[List[int]] = None
    self.__credential_cache: List[model.CachedCredential] = []
    self.__episode_rewards: List[float] = []

    # Actuators executing the attacker's and the defender's actions in the simulation
    self._actuator = actions.AgentActions(environment, throws_on_invalid_actions=self.__throws_on_invalid_actions)
    self._defender_actuator = actions.DefenderAgentActions(environment)

    self.__stepcount = 0
    self.__start_time = time.time()
    self.__done = False

    # Nodes where the agent is installed from the start are known from the outset
    self.__discovered_nodes.extend(node_id for node_id, node_data in environment.nodes() if node_data.agent_installed)
395
396
@property
def name(self) -> str:
    """Name of this gym environment."""
    env_name = "CyberBattleEnv"
    return env_name
399
400
@property
def identifiers(self) -> model.Identifiers:
    """Identifiers (ports, properties, vulnerabilities) of the current episode's environment."""
    current_environment = self.__environment
    return current_environment.identifiers
403
404
@property
def bounds(self) -> EnvironmentBounds:
    """Dimension bounds used to size the gym spaces, computed at construction time."""
    configured_bounds = self.__bounds
    return configured_bounds
407
408
def validate_environment(self, environment: model.Environment):
    """Validate that the size of the network and associated constants fit within
    the dimension bounds set for the CyberBattle gym environment.

    Raises:
        AssertionError: if the ports, properties, local or remote vulnerability
            identifier lists are empty.
        ValueError: if the network exceeds the configured bounds, or references a
            port, property or vulnerability not declared in `environment.identifiers`.
    """
    assert environment.identifiers.ports
    assert environment.identifiers.properties
    assert environment.identifiers.local_vulnerabilities
    assert environment.identifiers.remote_vulnerabilities

    node_count = len(environment.network.nodes.items())
    if node_count > self.__bounds.maximum_node_count:
        raise ValueError(f"Network node count ({node_count}) exceeds " f"the specified limit of {self.__bounds.maximum_node_count}.")

    # Maximum number of credentials that can possibly be returned by any action.
    # `default=0` handles networks where no vulnerability leaks credentials:
    # `max` of an empty sequence would otherwise raise ValueError.
    effective_maximum_credentials_per_action = max(
        (
            len(vulnerability.outcome.credentials)
            for _, node_info in environment.nodes()
            for vulnerability in node_info.vulnerabilities.values()
            if isinstance(vulnerability.outcome, model.LeakedCredentials)
        ),
        default=0,
    )

    if effective_maximum_credentials_per_action > self.__bounds.maximum_discoverable_credentials_per_action:
        raise ValueError(
            f"Some action in the environment returns {effective_maximum_credentials_per_action} "
            f"credentials which exceeds the maximum number of discoverable credentials "
            f"of {self.__bounds.maximum_discoverable_credentials_per_action}"
        )

    referenced_ports = model.collect_ports_from_environment(environment)
    undefined_ports = set(referenced_ports).difference(environment.identifiers.ports)
    if undefined_ports:
        raise ValueError(f"The network has references to undefined port names: {undefined_ports}")

    referenced_properties = model.collect_properties_from_nodes(model.iterate_network_nodes(environment.network))
    undefined_properties = set(referenced_properties).difference(environment.identifiers.properties)
    if undefined_properties:
        raise ValueError(f"The network has references to undefined property names: {undefined_properties}")

    local_vulnerabilities = model.collect_vulnerability_ids_from_nodes_bytype(
        environment.nodes(),
        environment.vulnerability_library,
        model.VulnerabilityType.LOCAL,
    )

    undefined_local_vuln = set(local_vulnerabilities).difference(environment.identifiers.local_vulnerabilities)
    if undefined_local_vuln:
        raise ValueError(f"The network has references to undefined local" f" vulnerability names: {undefined_local_vuln}")

    remote_vulnerabilities = model.collect_vulnerability_ids_from_nodes_bytype(
        environment.nodes(),
        environment.vulnerability_library,
        model.VulnerabilityType.REMOTE,
    )

    undefined_remote_vuln = set(remote_vulnerabilities).difference(environment.identifiers.remote_vulnerabilities)
    if undefined_remote_vuln:
        raise ValueError(f"The network has references to undefined remote" f" vulnerability names: {undefined_remote_vuln}")
466
467
# number of distinct privilege levels
privilege_levels = model.PrivilegeLevel.MAXIMUM + 1

def __init__(
    self,
    initial_environment: model.Environment,
    maximum_total_credentials: int = 1000,
    maximum_node_count: int = 100,
    maximum_discoverable_credentials_per_action: int = 5,
    defender_agent: Optional[DefenderAgent] = None,
    attacker_goal: Optional[AttackerGoal] = AttackerGoal(own_atleast_percent=1.0),
    defender_goal=DefenderGoal(eviction=True),
    defender_constraint=DefenderConstraint(maintain_sla=0.0),
    winning_reward=5000.0,
    losing_reward=0.0,
    renderer="",
    observation_padding=True,
    throws_on_invalid_actions=True,
):
    """Arguments
    ===========
    initial_environment - The CyberBattle network simulation environment; each episode
        works on a deep copy of it.
    maximum_total_credentials - Maximum total number of credentials used in a network
    maximum_node_count - Largest possible size of the network
    maximum_discoverable_credentials_per_action - Maximum number of credentials returned by a given action
    defender_agent - Optional `DefenderAgent` (may be None).
    attacker_goal - Target goal for the attacker to win and stop the simulation.
    defender_goal - Target goal for the defender to win and stop the simulation.
    defender_constraint - Constraint to be maintained by the defender to keep the simulation running.
    winning_reward - Reward granted to the attacker if the simulation ends because the attacker's goal is reached.
    losing_reward - Reward granted to the attacker if the simulation ends because the Defender's goal is reached.
    renderer - the renderer (e.g. 'png'). NOTE(review): historically documented as a matplotlib
        renderer while this module imports plotly — confirm.
    observation_padding - whether to pad all the observation fields to their maximum size. For instance this will pad the credential matrix
        to fit in `maximum_node_count` rows. Turn on this flag for gym agents that expect observations of fixed sizes.
        Must be set to True with gym >= 0.26
    throws_on_invalid_actions - whether to raise an exception if the step function attempts an invalid action (e.g., running an attack from a node that's not owned)
        if set to False a negative reward is returned instead.

    Raises:
        AssertionError / ValueError: if the bounds are invalid or the network does not fit in them.
    """

    # maximum number of entities in a given environment
    self.__bounds = EnvironmentBounds.of_identifiers(
        maximum_total_credentials=maximum_total_credentials,
        maximum_node_count=maximum_node_count,
        maximum_discoverable_credentials_per_action=maximum_discoverable_credentials_per_action,
        identifiers=initial_environment.identifiers,
    )

    self.validate_environment(initial_environment)
    self.__attacker_goal: Optional[AttackerGoal] = attacker_goal
    self.__defender_goal: DefenderGoal = defender_goal
    self.__defender_constraint: DefenderConstraint = defender_constraint
    self.__WINNING_REWARD = winning_reward
    self.__LOSING_REWARD = losing_reward
    self.__renderer = renderer
    self.__observation_padding = observation_padding
    # Must be set before the reset below: the attacker actuator is built with this flag.
    self.__throws_on_invalid_actions = throws_on_invalid_actions

    self.viewer = None

    self.__initial_environment: model.Environment = initial_environment

    self.__defender_agent = defender_agent

    self.__reset_environment()

    # number of entities in the environment network
    self.__node_count = len(initial_environment.network.nodes.items())

    # The Space object defining the valid actions of an attacker.
    # All dimensions are taken from the bounds so that the action space, the
    # action mask and the observation space agree with each other.
    local_vulnerabilities_count = self.__bounds.local_attacks_count
    remote_vulnerabilities_count = self.__bounds.remote_attacks_count
    maximum_node_count_int32 = self.__bounds.maximum_node_count
    port_count = self.__bounds.port_count
    credential_count = self.__bounds.maximum_total_credentials

    action_spaces = {
        "local_vulnerability": spaces.MultiDiscrete(
            # source_node_id, vulnerability_id
            np.array([maximum_node_count_int32, local_vulnerabilities_count], dtype=np.int32)
        ),
        "remote_vulnerability": spaces.MultiDiscrete(
            # source_node_id, target_node_id, vulnerability_id
            np.array([maximum_node_count_int32, maximum_node_count_int32, remote_vulnerabilities_count], dtype=np.int32)
        ),
        "connect": spaces.MultiDiscrete(
            # source_node_id, target_node_id, target_port, credential_id
            # (by index of discovery: 0 for initial node, 1 for first discovered node, ...)
            np.array(
                [
                    maximum_node_count_int32,
                    maximum_node_count_int32,
                    port_count,
                    credential_count,
                ],
                dtype=np.int32,
            )
        ),
    }

    self.action_space = DiscriminatedUnion[Action](cast(dict, action_spaces))  # type: ignore

    self.observation_space = ObservationSpaceType(self.__bounds)

    # reward_range: A tuple corresponding to the min and max possible rewards
    self.reward_range = (-float("inf"), float("inf"))
567
568
def __index_to_local_vulnerabilityid(self, vulnerability_index: int) -> model.VulnerabilityID:
    """Return the local vulnerability identifier from its internal encoding index

    Raises OutOfBoundIndexError if the index does not designate a declared local vulnerability.
    """
    return self.__lookup_by_index(
        self.__initial_environment.identifiers.local_vulnerabilities,
        vulnerability_index,
        "Local vulnerability",
    )

def __index_to_remote_vulnerabilityid(self, vulnerability_index: int) -> model.VulnerabilityID:
    """Return the remote vulnerability identifier from its internal encoding index

    Raises OutOfBoundIndexError if the index does not designate a declared remote vulnerability.
    """
    return self.__lookup_by_index(
        self.__initial_environment.identifiers.remote_vulnerabilities,
        vulnerability_index,
        "Remote vulnerability",
    )

def __index_to_port_name(self, port_index: int) -> model.PortName:
    """Return the port name identifier from its internal encoding index

    Raises OutOfBoundIndexError if the index does not designate a declared port.
    """
    return self.__lookup_by_index(self.__initial_environment.identifiers.ports, port_index, "Port")

def __portname_to_index(self, port_name: PortName) -> int:
    """Return the internal encoding index of a given port name (ValueError if the port is not declared)"""
    return self.__initial_environment.identifiers.ports.index(port_name)

@staticmethod
def __lookup_by_index(items: List[Value], index: int, kind: str) -> Value:
    """Return `items[index]`, raising OutOfBoundIndexError unless 0 <= index < len(items).

    Plain list indexing would silently wrap a negative index around the end of
    the list (selecting the wrong entity) and raise a bare IndexError past the
    end; this reports both cases the same way invalid node indices are reported.
    """
    if not 0 <= index < len(items):
        raise OutOfBoundIndexError(f"{kind} index ({index}) is invalid; expected a value between 0 and {len(items) - 1}.")
    return items[index]
583
584
def __internal_node_id_from_external_node_index(self, node_external_index: int) -> model.NodeID:
    """Return the internal environment node ID corresponding to the specified
    external node index that is exposed to the Gym agent

    0 -> ID of initial node
    1 -> ID of first discovered node
    ...

    Raises:
        OutOfBoundIndexError: if the index is negative or refers to a node not discovered yet.
    """
    # Ensures that the specified node is known by the agent
    if node_external_index < 0:
        raise OutOfBoundIndexError(f"Node index must be non-negative, given {node_external_index}")

    length = len(self.__discovered_nodes)
    if node_external_index >= length:
        raise OutOfBoundIndexError(f"Node index ({node_external_index}) is invalid; only {length} nodes discovered so far.")

    return self.__discovered_nodes[node_external_index]
602
603
def __find_external_index(self, node_id: model.NodeID) -> int:
    """Return the external index of `node_id`, i.e. its position among the discovered nodes.

    Raises ValueError (from `list.index`) if the node has not been discovered.
    """
    discovered = self.__discovered_nodes
    position = discovered.index(node_id)
    return position
606
607
def __agent_owns_node(self, node_id: model.NodeID) -> bool:
    """Tell whether the attacker agent is installed on the given node."""
    return self.__environment.get_node(node_id).agent_installed
611
612
def apply_mask(self, action: Action, mask: Optional[ActionMask] = None) -> bool:
    """Return True exactly when `action` is permitted by the action mask.

    When `mask` is omitted, the mask of the current state is computed.
    """
    effective_mask = self.compute_action_mask() if mask is None else mask
    kind = DiscriminatedUnion.kind(action)
    # The action vector doubles as coordinates into the bitmask of its kind
    coordinates = tuple(action[kind])  # type: ignore
    return bool(effective_mask[kind][coordinates])  # type: ignore
620
621
def __get_blank_action_mask(self) -> ActionMask:
    """Return an all-zero (int8) action mask whose dimensions follow the environment bounds."""
    node_slots = self.bounds.maximum_node_count
    limits = self.__bounds

    def zeros(*shape):
        return numpy.zeros(shape=shape, dtype=numpy.int8)

    return ActionMask(
        # (source node, local vulnerability)
        local_vulnerability=zeros(node_slots, limits.local_attacks_count),
        # (source node, target node, remote vulnerability)
        remote_vulnerability=zeros(node_slots, node_slots, limits.remote_attacks_count),
        # (source node, target node, port, credential)
        connect=zeros(node_slots, node_slots, limits.port_count, limits.maximum_total_credentials),
    )
642
643
def __update_action_mask(self, bitmask: ActionMask) -> None:
    """Update an action mask in place based on the current state.

    For every discovered node owned by the agent (the source), enable:
      - each local vulnerability found in the environment's vulnerability
        library or declared on the source node;
      - every remote vulnerability against every discovered node (the source included);
      - a connection to every discovered node on every port with any credential
        currently in the cache.
    A permitted action is not guaranteed to succeed.
    """
    local_vulnerabilities_count = self.__bounds.local_attacks_count
    remote_vulnerabilities_count = self.__bounds.remote_attacks_count
    port_count = self.__bounds.port_count
    credential_count = len(self.__credential_cache)
    discovered_count = len(self.__discovered_nodes)
    vulnerability_library = self.__environment.vulnerability_library

    # The position in `__discovered_nodes` is the external node index. Discovered node IDs
    # are assumed unique (NOTE(review): confirm), so enumerating replaces the O(n)
    # `list.index` lookup that used to run for every (source, target) pair, and
    # targets 0..discovered_count-1 can be set with a single slice.
    for source_index, source_node_id in enumerate(self.__discovered_nodes):
        # The agent may attempt exploiting vulnerabilities from any node that it owns
        if not self.__agent_owns_node(source_node_id):
            continue

        # Local: since the agent owns the node, all its local vulnerabilities are visible to it
        node_vulnerabilities = self.__environment.get_node(source_node_id).vulnerabilities
        for vulnerability_index in range(local_vulnerabilities_count):
            vulnerability_id = self.__index_to_local_vulnerabilityid(vulnerability_index)
            if vulnerability_id in vulnerability_library or vulnerability_id in node_vulnerabilities:
                bitmask["local_vulnerability"][source_index, vulnerability_index] = 1

        # Remote: any node discovered so far is a potential remote target
        bitmask["remote_vulnerability"][source_index, :discovered_count, :remote_vulnerabilities_count] = 1

        # the agent may attempt to connect to any port
        # and use any credential from its cache (though it's not guaranteed to succeed)
        bitmask["connect"][source_index, :discovered_count, :port_count, :credential_count] = 1
678
679
def compute_action_mask(self) -> ActionMask:
    """Build and return the mask of the actions permitted in the current state."""
    mask = self.__get_blank_action_mask()
    # Filled in place from the current state
    self.__update_action_mask(mask)
    return mask
684
685
def pretty_print_internal_action(self, action: Action) -> str:
    """Pretty print an action with internal node and vulnerability identifiers

    A connect action whose credential index does not refer to a cached credential
    (negative, or past the end of the cache) is printed as "connect(invalid)";
    such an index is likewise treated as invalid, rather than raising, when the
    action is executed.

    Raises:
        OutOfBoundIndexError: if a node index does not refer to a discovered node.
        ValueError: if `action` holds none of the known action kinds.
    """
    assert 1 == len(action.keys())
    assert DiscriminatedUnion.kind(action) != ""
    if "local_vulnerability" in action:
        source_node_index, vulnerability_index = action["local_vulnerability"]
        return f"local_vulnerability(`{self.__internal_node_id_from_external_node_index(source_node_index)}, {self.__index_to_local_vulnerabilityid(vulnerability_index)})"
    elif "remote_vulnerability" in action:
        source_node, target_node, vulnerability_index = action["remote_vulnerability"]
        source_node_id = self.__internal_node_id_from_external_node_index(source_node)
        target_node_id = self.__internal_node_id_from_external_node_index(target_node)
        return f"remote_vulnerability(`{source_node_id}, `{target_node_id}, {self.__index_to_remote_vulnerabilityid(vulnerability_index)})"
    elif "connect" in action:
        source_node, target_node, port_index, credential_cache_index = action["connect"]
        # A negative index would otherwise wrap around the cache list, so it is
        # rejected together with indices past the end.
        if not 0 <= credential_cache_index < len(self.__credential_cache):
            return "connect(invalid)"
        source_node_id = self.__internal_node_id_from_external_node_index(source_node)
        target_node_id = self.__internal_node_id_from_external_node_index(target_node)
        return f"connect(`{source_node_id}, `{target_node_id}, {self.__index_to_port_name(port_index)}, {self.__credential_cache[credential_cache_index].credential})"
    raise ValueError("Invalid discriminated union value: " + str(action))
706
707
def __execute_action(self, action: Action) -> actions.ActionResult:
    """Dispatch a single-kind action to the attacker actuator and return its result.

    External node indices are translated into internal node ids before the
    actuator is called. A `connect` action whose credential index falls outside
    the credential cache never reaches the actuator; it yields
    ActionResult(reward=-1, outcome=None) instead.

    Raises ValueError if the action kind is not recognized.
    """
    # The action must describe exactly one kind (local / remote / connect).
    assert 1 == len(action.keys())
    assert DiscriminatedUnion.kind(action) != ""

    to_node_id = self.__internal_node_id_from_external_node_index

    if "local_vulnerability" in action:
        source_index, vuln_index = action["local_vulnerability"]
        return self._actuator.exploit_local_vulnerability(
            to_node_id(source_index),
            self.__index_to_local_vulnerabilityid(vuln_index),
        )

    if "remote_vulnerability" in action:
        source_index, target_index, vuln_index = action["remote_vulnerability"]
        return self._actuator.exploit_remote_vulnerability(
            to_node_id(source_index),
            to_node_id(target_index),
            self.__index_to_remote_vulnerabilityid(vuln_index),
        )

    if "connect" in action:
        source_index, target_index, port_index, cred_index = action["connect"]
        if not 0 <= cred_index < len(self.__credential_cache):
            return actions.ActionResult(reward=-1, outcome=None)
        return self._actuator.connect_to_remote_machine(
            to_node_id(source_index),
            to_node_id(target_index),
            self.__index_to_port_name(port_index),
            self.__credential_cache[cred_index].credential,
        )

    raise ValueError("Invalid discriminated union value: " + str(action))
752
753
def __get_blank_observation(self) -> Observation:
    """Return an observation in which no action outcome is reported.

    Per-action fields are zero, every leaked-credential slot is marked
    UNUSED_SLOT, the action mask is blank, and every cell of
    `discovered_nodes_properties` is 2 ("unknown"). `discovered_node_count`,
    `_discovered_nodes` and `_explored_network` reflect the current state.
    """
    bounds = self.__bounds
    observation = Observation(
        newly_discovered_nodes_count=numpy.int32(0),
        # Build one independent array per slot: `[arr] * n` would alias a single
        # ndarray across all slots, so an in-place write to one slot would show in all.
        leaked_credentials=tuple(
            numpy.array([UNUSED_SLOT, 0, 0, 0], dtype=numpy.int32)
            for _ in range(bounds.maximum_discoverable_credentials_per_action)
        ),
        lateral_move=numpy.int32(0),
        customer_data_found=numpy.int32(0),
        escalation=numpy.int32(PrivilegeLevel.NoAccess),
        action_mask=self.__get_blank_action_mask(),
        probe_result=numpy.int32(0),
        credential_cache_matrix=tuple(
            numpy.zeros((2,), dtype=numpy.int64)
            for _ in range(bounds.maximum_total_credentials)
        ),
        credential_cache_length=0,
        discovered_node_count=len(self.__discovered_nodes),
        discovered_nodes_properties=numpy.full((bounds.maximum_node_count, bounds.property_count), 2, dtype=numpy.int32),
        nodes_privilegelevel=numpy.zeros((bounds.maximum_node_count,), dtype=numpy.int32),
        # raw data not actually encoded as a proper gym numeric space
        # (were previously returned in the 'info' dict)
        _discovered_nodes=self.__discovered_nodes,
        _explored_network=self.__get_explored_network(),
    )

    return observation
774
775
def __pad_array_if_requested(self, o, pad_value, desired_length) -> numpy.ndarray:
    """Right-pad array `o` with `pad_value` up to `desired_length` entries.

    Padding happens only when observation padding is enabled for this
    environment. Otherwise `o` is returned untouched.
    """
    if not self.__observation_padding:
        return o
    filler = numpy.full((desired_length - len(o)), pad_value, dtype=numpy.int32)
    return numpy.concatenate((o, filler))
783
784
def __pad_tuple_if_requested(self, o, row_shape, desired_length) -> Tuple[numpy.ndarray, ...]:
    """Return the rows of `o` as a tuple, padded with zero rows if padding is enabled.

    - o -- sequence of rows (list or tuple)
    - row_shape -- shape of each zero padding row (int32)
    - desired_length -- total number of rows after padding; if `o` already has
      at least that many rows, no padding is added

    Without observation padding, the rows of `o` are returned unchanged as a tuple.
    """
    rows = tuple(o)  # accept tuples as well as lists (`tuple + list` would raise)
    if not self.__observation_padding:
        return rows
    # Allocate a distinct zero row per slot. `[numpy.zeros(...)] * k` aliases one
    # array across all padding slots, so an in-place write to one would affect all.
    padding = tuple(numpy.zeros(row_shape, dtype=numpy.int32) for _ in range(desired_length - len(rows)))
    return rows + padding
792
793
def __property_vector(self, node_id: model.NodeID, node_info: model.NodeInfo) -> numpy.ndarray:
    """Property vector for specified node.

    Each cell is 1 if the property is set (discovered), 0 if it is known to be
    unset, and 2 if its status is unknown. A node owned by the agent
    (privilege >= LocalUser) has all its properties known, so undiscovered
    ones are 0. For a node not owned yet, undiscovered properties are 2.
    """
    properties_indices = list(self._actuator.get_discovered_properties(node_id))

    is_owned = self._actuator.get_node_privilegelevel(node_id) >= PrivilegeLevel.LocalUser

    # Previously both branches defaulted to 0, which reported unknown properties
    # of non-owned nodes as "unset". This contradicted the 0/1/2 encoding used by
    # the property matrix and by the blank observation (filled with 2).
    default_value = 0 if is_owned else 2
    vector = numpy.full((self.__bounds.property_count,), default_value, dtype=numpy.int32)

    vector[properties_indices] = 1
    return vector
810
811
def __get_property_matrix(self) -> numpy.ndarray:
    """Return the Node-Property matrix, one row per discovered node, where
    0 means the property is not set for that node
    1 means the property is set for that node
    2 means the property status is unknown

    e.g.: [[ 1 0 0 1 ]
           [ 2 2 2 2 ]
           [ 0 1 0 1 ]]
    1st row: set and unset properties for the 1st discovered and owned node
    2nd row: no known properties for the 2nd discovered node
    3rd row: properties of 3rd discovered and owned node

    With observation padding enabled, the matrix is padded with zero rows up to
    `maximum_node_count` rows. Otherwise it has one row per discovered node.
    """
    property_count = self.__bounds.property_count
    property_discovered = [self.__property_vector(node_id, node_info) for node_id, node_info in self._actuator.discovered_nodes()]
    rows = self.__pad_tuple_if_requested(
        property_discovered,
        property_count,
        self.__bounds.maximum_node_count,
    )
    # reshape keeps the result 2-D even with zero rows (numpy.array(()) is 1-D)
    as_numpy = numpy.array(rows, dtype=numpy.int32).reshape(-1, property_count)
    # Only the padded matrix has a fixed row count. Without padding, the previous
    # hard-coded `maximum_node_count` assertion failed on every call.
    expected_rows = self.__bounds.maximum_node_count if self.__observation_padding else len(property_discovered)
    assert as_numpy.shape == (expected_rows, property_count)
    return as_numpy
831
832
def __get__owned_nodes_indices(self) -> List[int]:
    """Return the external indices of the nodes the attacker owns (privilege >= LocalUser).

    The list is memoized in `__owned_nodes_indices_cache` and recomputed only
    when that cache is None (step() and reset() clear it).
    """
    cached = self.__owned_nodes_indices_cache
    if cached is None:
        owned_ids = self._actuator.get_nodes_with_atleast_privilegelevel(PrivilegeLevel.LocalUser)
        cached = [self.__find_external_index(node_id) for node_id in owned_ids]
        self.__owned_nodes_indices_cache = cached
    return cached
839
840
def __get_privilegelevel_array(self) -> numpy.ndarray:
    """Return the privilege level of every discovered node as an int32 array.

    Each entry is the integer value of the node's PrivilegeLevel (e.g. 0 when
    the node is not owned, 1 when owned, higher values for escalated levels
    defined by the network). When observation padding is enabled, the array
    is padded with PrivilegeLevel.NoAccess up to `maximum_node_count` entries.
    """
    levels = [int(self._actuator.get_node_privilegelevel(node_id)) for node_id in self.__discovered_nodes]
    return self.__pad_array_if_requested(
        numpy.array(levels, dtype=numpy.int32),
        PrivilegeLevel.NoAccess,
        self.__bounds.maximum_node_count,
    )
858
859
def __observation_reward_from_action_result(self, result: actions.ActionResult) -> Tuple[Observation, float]:
    """Translate an actuator ActionResult into an (observation, reward) pair.

    Environment state updated as a side effect:
      - nodes reported by a LeakedNodesId or LeakedCredentials outcome are
        appended to the discovered-node list if not already present;
      - credentials from a LeakedCredentials outcome are appended to the
        credential cache if not already present.

    The reward is `result.reward`, passed through unchanged.
    """
    observation = self.__get_blank_observation()
    outcome = result.outcome

    if isinstance(outcome, model.LeakedNodesId):
        new_node_count = 0
        for node_id in outcome.nodes:
            if node_id in self.__discovered_nodes:
                continue
            self.__discovered_nodes.append(node_id)
            new_node_count += 1
        observation["newly_discovered_nodes_count"] = numpy.int32(new_node_count)

    elif isinstance(outcome, model.LeakedCredentials):
        new_node_count = 0
        fresh_credentials: List[Tuple[int, model.CachedCredential]] = []
        for credential in outcome.credentials:
            if credential.node not in self.__discovered_nodes:
                self.__discovered_nodes.append(credential.node)
                new_node_count += 1
            if credential not in self.__credential_cache:
                self.__credential_cache.append(credential)
                # remember the cache index of the entry just appended
                fresh_credentials.append((len(self.__credential_cache) - 1, credential))
        observation["newly_discovered_nodes_count"] = numpy.int32(new_node_count)

        # Gym encoding, one row per newly cached credential:
        # [USED_SLOT, cache index, external node index, port index]
        leaked_rows = [
            numpy.array(
                [USED_SLOT, slot, self.__find_external_index(credential.node), self.__portname_to_index(credential.port)],
                numpy.int32,
            )
            for slot, credential in fresh_credentials
        ]
        observation["leaked_credentials"] = self.__pad_tuple_if_requested(
            leaked_rows,
            4,
            self.__bounds.maximum_discoverable_credentials_per_action,
        )

    elif isinstance(outcome, model.LateralMove):
        observation["lateral_move"] = numpy.int32(1)
    elif isinstance(outcome, model.CustomerData):
        observation["customer_data_found"] = numpy.int32(1)
    elif isinstance(outcome, model.ProbeSucceeded):
        observation["probe_result"] = numpy.int32(2)
    elif isinstance(outcome, model.ProbeFailed):
        observation["probe_result"] = numpy.int32(1)
    elif isinstance(outcome, model.PrivilegeEscalation):
        observation["escalation"] = numpy.int32(outcome.level)

    # Every cached credential as [external node index, port index]
    cache_rows = [numpy.array([self.__find_external_index(cred.node), self.__portname_to_index(cred.port)]) for cred in self.__credential_cache]
    observation["credential_cache_matrix"] = self.__pad_tuple_if_requested(cache_rows, 2, self.__bounds.maximum_total_credentials)

    # Refresh the statistics that depend on the (possibly updated) state.
    observation.update(
        {
            "credential_cache_length": len(self.__credential_cache),
            "discovered_node_count": len(self.__discovered_nodes),
            "discovered_nodes_properties": self.__get_property_matrix(),
            "nodes_privilegelevel": self.__get_privilegelevel_array(),
            "_discovered_nodes": self.__discovered_nodes,
            "_explored_network": self.__get_explored_network(),
        }
    )

    self.__update_action_mask(observation["action_mask"])
    return observation, result.reward
934
935
def sample_connect_action_in_expected_range(self) -> Action:
    """Draw a random 'connect' action whose parameters each lie in their
    expected range, without checking that they are consistent with each other.

    Raises ValueError while the credential cache is still empty.
    """
    if len(self.__credential_cache) <= 0:
        raise ValueError("Cannot sample a connect action until the agent discovers more potential target nodes.")

    rng = self.np_random
    source = rng.choice(self.__get__owned_nodes_indices())
    target = rng.integers(0, len(self.__discovered_nodes))
    port = rng.integers(0, self.__bounds.port_count)
    # The credential space is sparse, so only indices of the credentials
    # discovered so far are sampled.
    credential = rng.integers(0, len(self.__credential_cache))

    return Action(connect=numpy.array([source, target, port, credential], numpy.int32))
958
959
def sample_action_in_range(self, kinds: Optional[List[int]] = None) -> Action:
    """Sample an action in the expected component ranges but
    not necessarily verifying inter-component constraints.
    (e.g., may return a local_vulnerability action that is not
    supported by the node)

    - kinds -- A list of elements in {0,1,2} indicating what kind of
    action to sample (0:local, 1:remote, 2:connect). Defaults to all three.
    Kind 2 is dropped while the credential cache is empty.
    """
    if kinds is None:
        kinds = [0, 1, 2]

    if len(self.__credential_cache) == 0:
        # cannot generate a connect action if no cred in the cache
        kinds = [kind for kind in kinds if kind != 2]

    assert kinds, "Kinds list cannot be empty"

    choice_random = self.action_space.union_np_random
    kind = choice_random.choice(kinds)

    if kind == 2:
        return self.sample_connect_action_in_expected_range()

    # Kinds 0 and 1 were previously swapped (kind 1 produced a local action),
    # contradicting the documented mapping 0:local, 1:remote.
    if kind == 1:
        return Action(
            remote_vulnerability=numpy.array(
                [
                    choice_random.choice(self.__get__owned_nodes_indices()),
                    choice_random.integers(0, len(self.__discovered_nodes)),
                    choice_random.integers(0, self.__bounds.remote_attacks_count),
                ],
                numpy.int32,
            )
        )

    return Action(
        local_vulnerability=numpy.array(
            [
                choice_random.choice(self.__get__owned_nodes_indices()),
                choice_random.integers(0, self.__bounds.local_attacks_count),
            ],
            numpy.int32,
        )
    )
1008
1009
def is_node_owned(self, node: int):
    """Tell whether the discovered node at external index `node` is owned by
    the attacker agent, i.e. its privilege level is above NoAccess."""
    internal_id = self.__internal_node_id_from_external_node_index(node)
    return self._actuator.get_node_privilegelevel(internal_id) > PrivilegeLevel.NoAccess
1015
1016
def is_action_valid(self, action, action_mask: Optional[ActionMask] = None) -> bool:
    """Determine if an action is valid.

    The action is valid when every index parameter is non-negative and below
    its bound (discovered nodes, attack counts, port count, credential cache
    size), its source node is owned by the attacker, and `apply_mask`
    accepts it under `action_mask`.
    """
    assert 1 == len(action.keys())

    kind = DiscriminatedUnion.kind(action)
    n_discovered_nodes = len(self.__discovered_nodes)
    bounds = self.__bounds
    in_range = False
    # Lower bounds are checked explicitly. A negative index would satisfy the
    # upper-bound comparisons on their own and reach the lookups below.
    if kind == "local_vulnerability":
        source_node, vulnerability_index = action["local_vulnerability"]
        in_range = (
            0 <= source_node < n_discovered_nodes
            and self.is_node_owned(source_node)
            and 0 <= vulnerability_index < bounds.local_attacks_count
        )
    elif kind == "remote_vulnerability":
        source_node, target_node, vulnerability_index = action["remote_vulnerability"]
        in_range = (
            0 <= source_node < n_discovered_nodes
            and self.is_node_owned(source_node)
            and 0 <= target_node < n_discovered_nodes
            and 0 <= vulnerability_index < bounds.remote_attacks_count
        )
    elif kind == "connect":
        source_node, target_node, port_index, credential_cache_index = action["connect"]
        in_range = (
            0 <= source_node < n_discovered_nodes
            and self.is_node_owned(source_node)
            and 0 <= target_node < n_discovered_nodes
            and 0 <= port_index < bounds.port_count
            and 0 <= credential_cache_index < len(self.__credential_cache)
        )

    return in_range and self.apply_mask(action, action_mask)
1040
1041
def sample_valid_action(self, kinds=None) -> Action:
    """Repeatedly sample in-range actions of the given kinds until one passes the current action mask."""
    action_mask = self.compute_action_mask()
    while True:
        candidate = self.sample_action_in_range(kinds)
        if self.apply_mask(candidate, action_mask):
            return candidate
1048
1049
def sample_valid_action_with_luck(self) -> Action:
    """Draw from the full action space until an action passes the current action mask."""
    action_mask = self.compute_action_mask()
    while True:
        candidate = self.action_space.sample()
        if self.apply_mask(candidate, action_mask):
            return candidate
1056
1057
def __get_explored_network(self) -> networkx.DiGraph:
    """Return a copy of the subgraph induced by the nodes discovered so far.

    Each node is annotated with its privilege level, its discovered property
    list (`flags`) and its property vector (`flags_bits`). A node's `data`
    attribute is replaced by None when the agent is not installed on it.
    """
    discovered_ids = [node_id for node_id, _ in self._actuator.discovered_nodes()]
    explored = self.__environment.network.subgraph(discovered_ids).copy()
    assert isinstance(explored, networkx.DiGraph)

    for node_id in explored.nodes:
        attributes = explored.nodes[node_id]
        node_info: Optional[model.NodeInfo] = attributes["data"]
        # hide info for nodes that the agent does not own
        if node_info is not None and not node_info.agent_installed:
            attributes["data"] = None

        attributes["privilege_level"] = int(self._actuator.get_node_privilegelevel(node_id))
        attributes["flags"] = list(self._actuator.get_discovered_properties(node_id))
        attributes["flags_bits"] = self.__property_vector(node_id, node_info)

    return explored
1079
1080
def __attacker_goal_reached(self) -> bool:
    """Check whether every criterion of the attacker goal is met.

    The criteria are: the cumulative episode reward, the absolute and relative
    number of owned nodes, and (only when a defender agent is present) network
    availability dropping below the goal's `low_availability`. Returns False
    when no goal is set.
    """
    goal = self.__attacker_goal
    if not goal or numpy.sum(self.__episode_rewards) < goal.reward:
        return False

    owned_count = len(self.__get__owned_nodes_indices())
    if owned_count < goal.own_atleast or owned_count / self.__node_count < goal.own_atleast_percent:
        return False

    availability_still_high = (
        self.__defender_agent is not None
        and self._defender_actuator.network_availability >= goal.low_availability
    )
    return not availability_still_high
1102
1103
def __defender_constraints_broken(self) -> bool:
    """Report whether the defender violated its constraint.

    A violation means network availability fell below `maintain_sla`. This is
    only checked when a defender agent is present.
    """
    constraint = self.__defender_constraint
    if self.__defender_agent is None:
        return False
    return bool(self._defender_actuator.network_availability < constraint.maintain_sla)
1111
1112
def __defender_goal_reached(self) -> bool:
    """Check if the defender's goal is reached.

    The goal is reached when eviction is required and the attacker no longer
    owns any node.
    """
    eviction_required = self.__defender_goal.eviction
    return eviction_required and len(self.__get__owned_nodes_indices()) == 0
1117
1118
def get_explored_network_as_numpy(self, observation: Observation) -> numpy.ndarray:
    """Return the adjacency matrix (N x N, N = nodes discovered so far) of the
    observation's explored network, weighted by the `kind_as_float` edge attribute."""
    explored_graph = observation["_explored_network"]
    return convert_matrix.to_numpy_array(explored_graph, weight="kind_as_float")
1123
1124
def get_explored_network_node_properties_bitmap_as_numpy(self, observation: Observation) -> numpy.ndarray:
    """Return the adjacency matrix and the node-property bitmap joined side by side.

    With N discovered nodes and P properties, the result has the form:

      ^    <---- N -----><------ P ------>
      |    (             |               )
      N    (  Adjacency  | Node-Properties )
      |    (  Matrix     |    Bitmap     )
      V    (             |               )

    The adjacency part is weighted by the `kind_as_float` edge attribute.
    """
    adjacency = convert_matrix.to_numpy_array(observation["_explored_network"], weight="kind_as_float")
    properties = numpy.array(observation["discovered_nodes_properties"])
    return numpy.block([adjacency, properties])
1144
1145
def step(self, action: Action) -> Tuple[Observation, float, bool, bool, StepInfo]:  # type: ignore
    """Play one attacker action, then let the defender agent (if any) take its step.

    Returns the Gymnasium 5-tuple (observation, reward, terminated, truncated,
    info); `truncated` is always False. The episode terminates:
      - with the winning reward, when the attacker goal is reached or a
        defender constraint is broken;
      - with the losing reward, when the defender goal is reached.
    Otherwise negative rewards are clipped to 0.

    Raises RuntimeError if the episode is already over (call reset() first).
    An OutOfBoundIndexError raised while processing the action is logged, and
    the step then returns a blank observation with reward 0.
    """
    if self.__done:
        raise RuntimeError("new episode must be started with env.reset()")

    self.__stepcount += 1
    # Wall-clock time since `self.__start_time`. time.time() differences are in
    # seconds, so convert to milliseconds to match the `duration_in_ms` field.
    duration_in_ms = (time.time() - self.__start_time) * 1000.0
    try:
        result = self.__execute_action(action)
        observation, reward = self.__observation_reward_from_action_result(result)

        # Execute the defender step if provided
        if self.__defender_agent:
            self._defender_actuator.on_attacker_step_taken()
            self.__defender_agent.step(self.__environment, self._defender_actuator, self.__stepcount)

        # Ownership may have changed during this step. Drop the memoized
        # owned-node indices before the goal checks below consult them.
        self.__owned_nodes_indices_cache = None

        if self.__attacker_goal_reached() or self.__defender_constraints_broken():
            self.__done = True
            reward = self.__WINNING_REWARD
        elif self.__defender_goal_reached():
            self.__done = True
            reward = self.__LOSING_REWARD
        else:
            reward = max(0.0, reward)

    except OutOfBoundIndexError as error:
        LOGGER.warning("Invalid entity index: %s", error)
        # The error may be raised after the state was already modified, so make
        # sure the owned-node indices are recomputed on next use.
        self.__owned_nodes_indices_cache = None
        observation = self.__get_blank_observation()
        reward = 0.0

    info = StepInfo(
        description="CyberBattle simulation",
        duration_in_ms=duration_in_ms,
        step_count=self.__stepcount,
        network_availability=self._defender_actuator.network_availability,
        credential_cache=self.__credential_cache,
    )
    self.__episode_rewards.append(reward)

    return observation, reward, self.__done, False, info
1186
1187
def reset(
    self,
    *,
    seed: Optional[int] = None,
    options: Optional[dict] = None,
) -> Tuple[Observation, StepInfo]:
    """Start a new episode; return the initial observation and step info.

    - seed -- if given, re-seeds `self.np_random`. If omitted, the current
      generator is kept, so a run seeded once stays reproducible across
      episodes (Gymnasium convention). A fresh entropy-seeded generator is
      created only when none exists yet.
    - options -- accepted for Gymnasium API compatibility; not used.
    """
    LOGGER.info("Resetting the CyberBattle environment")

    # Seed before resetting the state, in case the reset draws from the generator.
    # Previously a reset without a seed re-seeded from entropy every time, which
    # discarded the seed given to an earlier reset.
    if seed is not None or not hasattr(self, "np_random"):
        self.np_random, _ = seeding.np_random(seed)

    self.__reset_environment()
    # Clear the memoized owned-node indices before computing the initial observation.
    self.__owned_nodes_indices_cache = None

    observation = self.__get_blank_observation()
    observation["action_mask"] = self.compute_action_mask()
    observation["discovered_nodes_properties"] = self.__get_property_matrix()
    observation["nodes_privilegelevel"] = self.__get_privilegelevel_array()
    info = StepInfo(
        description="CyberBattle simulation",
        duration_in_ms=0,
        step_count=self.__stepcount,
        network_availability=self._defender_actuator.network_availability,
        credential_cache=self.__credential_cache,
    )
    return observation, info
1210
1211
def render_as_fig(self):
    """Build a plotly figure with two panels side by side.

    The left panel shows the cumulative episode reward; the right panel shows
    the network as rendered by EnvironmentDebugging. All attacks are printed
    through the actuator as a side effect.
    """
    debug = commandcontrol.EnvironmentDebugging(self._actuator)
    self._actuator.print_all_attacks()

    fig = make_subplots(rows=1, cols=2)
    cumulative_rewards = numpy.cumsum(numpy.array(self.__episode_rewards))
    fig.add_trace(Scatter(y=cumulative_rewards, name="cumulative reward"), row=1, col=1)

    network_traces, network_layout = debug.network_as_plotly_traces(xref="x2", yref="y2")
    for trace in network_traces:
        fig.add_trace(trace, row=1, col=2)
    fig.update_layout(network_layout)
    return fig
1227
1228
def render(self, mode: str = "human") -> None:
    """Show the figure from render_as_fig() using the renderer configured for this environment.

    `mode` is accepted for API compatibility and is not used.
    """
    self.render_as_fig().show(renderer=self.__renderer)
1231
1232
def close(self) -> None:
    """Release environment resources; nothing is held, so this is a no-op."""
    return None
1234
1235