Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
microsoft
GitHub Repository: microsoft/CyberBattleSim
Path: blob/main/cyberbattle/simulation/actions.py
597 views
1
# Copyright (c) Microsoft Corporation.
2
# Licensed under the MIT License.
3
4
"""
5
actions.py
6
This file contains the class and associated methods for the AgentActions
7
class which interacts directly with the environment. It is the class
8
which both the user and RL agents should manipulate the environment.
9
"""
10
11
from dataclasses import dataclass
12
import dataclasses
13
import datetime
14
from boolean import boolean
15
from collections import OrderedDict
16
import logging
17
from enum import Enum
18
from typing import (
19
Iterator,
20
List,
21
NamedTuple,
22
Optional,
23
Set,
24
Tuple,
25
Dict,
26
TypedDict,
27
cast,
28
)
29
from IPython.display import display
30
import pandas as pd
31
32
from cyberbattle.simulation.model import (
33
FirewallRule,
34
MachineStatus,
35
PrivilegeLevel,
36
PropertyName,
37
VulnerabilityID,
38
VulnerabilityType,
39
)
40
from . import model
41
42
43
logger = logging.getLogger(__name__)

# Rewards (and penalties, when negative) are plain floating point numbers
Reward = float


class DiscoveredNodeInfo(TypedDict):
    """Entry returned by `AgentActions.list_nodes` for each node known to the attacker."""

    # identifier of the node
    id: model.NodeID
    # "owned" if the attacker agent is installed on the node, "discovered" otherwise
    status: str
47
48
49
class Penalty:
    """Penalties (=negative reward) returned for some actions taken in the simulation"""

    # penalty for generic suspiciousness
    # (historical, misspelled name kept for backward compatibility)
    SUPSPICIOUSNESS = -5.0
    # correctly spelled alias of SUPSPICIOUSNESS
    SUSPICIOUSNESS = SUPSPICIOUSNESS

    # penalty for attempting a connection to a port that was not open
    SCANNING_UNOPEN_PORT = -10.0

    # penalty for repeating the same exploit attempt
    REPEAT = -1

    # penalty for a local exploit whose preconditions do not hold on the node
    LOCAL_EXPLOIT_FAILED = -20
    # penalty for a remote exploit whose preconditions do not hold on the target
    FAILED_REMOTE_EXPLOIT = -50

    # penalty for attempting to connect or execute an action on a node that's not in running state
    MACHINE_NOT_RUNNING = 0

    # penalty for attempting a connection with an invalid password
    WRONG_PASSWORD = -10

    # traffic blocked by outgoing rule in a local firewall
    BLOCKED_BY_LOCAL_FIREWALL = -10

    # traffic blocked by incoming rule in a remote firewall
    BLOCKED_BY_REMOTE_FIREWALL = -10

    # invalid action (e.g., running an attack from a node that's not owned)
    # (Used only if `throws_on_invalid_actions` is set to False)
    INVALID_ACTION = -1
79
80
81
# Reward granted when a local or remote attack succeeds for the first time
# since the target node was last re-imaged.
# NOTE: the cost of the attack is subtracted from this reward.
NEW_SUCCESSFULL_ATTACK_REWARD = 7

# Fixed rewards granted for each newly discovered entity:
NODE_DISCOVERED_REWARD = 5  # ... per new node
CREDENTIAL_DISCOVERED_REWARD = 3  # ... per new credential
PROPERTY_DISCOVERED_REWARD = 2  # ... per new node property
94
95
96
class EdgeAnnotation(Enum):
    """Kinds of edges added to the network graph while the simulation is played.

    Values are ordered: when an edge is annotated several times the
    annotation with the highest value is kept.
    """

    # the source node revealed the existence of (or a credential for) the target
    KNOWS = 0
    # the target was successfully attacked with a remote exploit from the source
    REMOTE_EXPLOIT = 1
    # the source connected to the target using a credential
    LATERAL_MOVE = 2
102
103
104
class ActionResult(NamedTuple):
    """Outcome of an attacker action: the reward earned and what the action produced."""

    # reward earned by the action (negative values are penalties)
    reward: Reward
    # effect of the action; None for actions that were invalid, blocked or had no effect
    outcome: Optional[model.VulnerabilityOutcome]
109
110
111
# Shared boolean algebra; AgentActions._check_prerequisites uses it to build the
# `true`/`false` constants substituted into vulnerability precondition expressions.
ALGEBRA: boolean.BooleanAlgebra = boolean.BooleanAlgebra()
112
113
114
@dataclass
class NodeTrackingInformation:
    """Per-node bookkeeping of what the attacker learned and did during the simulation"""

    # Time of the most recent attack keyed by (vulnerability id, is_local),
    # where is_local is True for a local attack and False for a remote one
    last_attack: Dict[Tuple[model.VulnerabilityID, bool], datetime.datetime] = dataclasses.field(default_factory=dict)

    # When the attacker agent last took ownership of the node (None if never)
    last_owned_at: Optional[datetime.datetime] = None

    # Indices (into the environment's property identifiers) of the node properties discovered so far
    discovered_properties: Set[int] = dataclasses.field(default_factory=set)
125
126
127
class AgentActions:
    """
    This is the AgentActions class. It interacts with and makes changes to the environment.

    It records what the attacker has learned so far (discovered nodes, gathered
    credentials, discovered node properties, time of past attacks) and applies the
    effects of attacker actions (local exploits, remote exploits and credential-based
    lateral moves) to the environment's network graph.
    """

    def __init__(self, environment: model.Environment, throws_on_invalid_actions=True):
        """
        AgentActions Constructor

        environment - CyberBattleSim environment parameters
        throws_on_invalid_actions - whether to raise an exception when executing an invalid action
            (e.g., running an attack from a node that's not owned);
            if set to False a negative reward (Penalty.INVALID_ACTION) is returned instead.
        """
        self._environment = environment
        self._gathered_credentials: Set[model.CredentialID] = set()
        self._discovered_nodes: "OrderedDict[model.NodeID, NodeTrackingInformation]" = OrderedDict()
        self._throws_on_invalid_actions = throws_on_invalid_actions

        # List of all special tags indicating a privilege level reached on a node
        self.privilege_tags = [model.PrivilegeEscalation(p).tag for p in list(PrivilegeLevel)]

        # Nodes where the agent is already installed start out owned (and therefore discovered)
        for i, node in environment.nodes():
            if node.agent_installed:
                self.__mark_node_as_owned(i, PrivilegeLevel.LocalUser)

    def discovered_nodes(self) -> Iterator[Tuple[model.NodeID, model.NodeInfo]]:
        """Yield (node id, node info) for every discovered node, in discovery order."""
        for node_id in self._discovered_nodes:
            yield (node_id, self._environment.get_node(node_id))

    def _check_prerequisites(self, target: model.NodeID, vulnerability: model.VulnerabilityInfo) -> bool:
        """
        Return True if the precondition of `vulnerability` holds on node `target`.

        Each symbol of the precondition expression is replaced by `true` when the node
        has a property with the same name and by `false` otherwise; the resulting
        expression is then simplified and compared to `true`.
        """
        node: model.NodeInfo = self._environment.network.nodes[target]["data"]
        node_flags = node.properties
        expr = vulnerability.precondition.expression

        true_value = ALGEBRA.parse("true")
        false_value = ALGEBRA.parse("false")
        mapping = {i: true_value if str(i) in node_flags else false_value for i in expr.get_symbols()}
        is_true: bool = cast(boolean.Expression, expr.subs(mapping)).simplify() == true_value
        return is_true

    def list_vulnerabilities_in_target(
        self,
        target: model.NodeID,
        type_filter: Optional[model.VulnerabilityType] = None,
    ) -> List[model.VulnerabilityID]:
        """
        Return the IDs of the vulnerabilities whose precondition holds on node `target`.

        Both the global vulnerability library and the node's own vulnerabilities are
        considered. If `type_filter` is given, only vulnerabilities of that type are kept.
        The result lists library vulnerabilities first, then node-specific ones, without
        duplicates; the order is deterministic (it does not depend on string hashing).

        Raises ValueError if `target` is not a node of the network.
        """
        if not self._environment.network.has_node(target):
            raise ValueError(f"invalid node id '{target}'")

        target_node_data: model.NodeInfo = self._environment.get_node(target)

        # A dict is used as an insertion-ordered set so that the result order is reproducible
        applicable: Dict[model.VulnerabilityID, None] = {}
        for vulnerabilities in (self._environment.vulnerability_library, target_node_data.vulnerabilities):
            for vuln_id, vulnerability in vulnerabilities.items():
                if (type_filter is None or vulnerability.type == type_filter) and self._check_prerequisites(target, vulnerability):
                    applicable[vuln_id] = None

        return list(applicable)

    def __annotate_edge(
        self,
        source_node_id: model.NodeID,
        target_node_id: model.NodeID,
        new_annotation: EdgeAnnotation,
    ) -> None:
        """Create the edge if it does not already exist, and annotate with the maximum
        of the existing annotation and a specified new annotation"""
        edge_annotation = self._environment.network.get_edge_data(source_node_id, target_node_id)
        if edge_annotation is not None and "kind" in edge_annotation:
            new_annotation = EdgeAnnotation(max(edge_annotation["kind"].value, new_annotation.value))
        self._environment.network.add_edge(
            source_node_id,
            target_node_id,
            kind=new_annotation,
            kind_as_float=float(new_annotation.value),
        )

    def get_discovered_properties(self, node_id: model.NodeID) -> Set[int]:
        """Return the indices of the properties discovered so far on a node.

        Raises KeyError if the node has not been discovered."""
        return self._discovered_nodes[node_id].discovered_properties

    def __mark_node_as_discovered(self, node_id: model.NodeID) -> bool:
        """Record the node as discovered; return True if it was not known before."""
        logger.info("discovered node: %s", node_id)
        newly_discovered = node_id not in self._discovered_nodes
        if newly_discovered:
            self._discovered_nodes[node_id] = NodeTrackingInformation()
        return newly_discovered

    def __mark_nodeproperties_as_discovered(self, node_id: model.NodeID, properties: List[PropertyName]) -> int:
        """Record the given properties of a node as discovered (discovering the node too if needed).

        Privilege-level tags are not recorded. Return the number of newly discovered properties."""
        properties_indices = [self._environment.identifiers.properties.index(p) for p in properties if p not in self.privilege_tags]

        if node_id in self._discovered_nodes:
            before_count = len(self._discovered_nodes[node_id].discovered_properties)
            self._discovered_nodes[node_id].discovered_properties = self._discovered_nodes[node_id].discovered_properties.union(properties_indices)
        else:
            before_count = 0
            self._discovered_nodes[node_id] = NodeTrackingInformation(discovered_properties=set(properties_indices))

        newly_discovered_properties = len(self._discovered_nodes[node_id].discovered_properties) - before_count
        return newly_discovered_properties

    def __mark_allnodeproperties_as_discovered(self, node_id: model.NodeID) -> int:
        """Record all the properties of a node as discovered; return how many are new."""
        node_info: model.NodeInfo = self._environment.network.nodes[node_id]["data"]
        return self.__mark_nodeproperties_as_discovered(node_id, node_info.properties)

    def __mark_node_as_owned(
        self,
        node_id: model.NodeID,
        privilege: PrivilegeLevel = model.PrivilegeLevel.LocalUser,
    ) -> Tuple[Optional[datetime.datetime], bool]:
        """Mark a node as owned with (at least) the given privilege level.

        Return the time it was previously owned (or None) and whether it was already owned.
        """
        node_info = self._environment.get_node(node_id)

        last_owned_at, is_currently_owned = self.__is_node_owned_history(node_id, node_info)

        if is_currently_owned:
            # The node is already owned but the privilege level can still increase
            # (e.g. a local privilege-escalation exploit). Previously the requested
            # privilege was silently dropped in this case.
            node_info.privilege_level = model.escalate(node_info.privilege_level, privilege)
            self._environment.network.nodes[node_id].update({"data": node_info})
            return last_owned_at, is_currently_owned

        if node_id not in self._discovered_nodes:
            self._discovered_nodes[node_id] = NodeTrackingInformation()
        node_info.agent_installed = True
        node_info.privilege_level = model.escalate(node_info.privilege_level, privilege)
        self._environment.network.nodes[node_id].update({"data": node_info})

        self.__mark_allnodeproperties_as_discovered(node_id)

        # Record that the node just got owned at the current time
        self._discovered_nodes[node_id].last_owned_at = datetime.datetime.now()

        return last_owned_at, is_currently_owned

    def __mark_discovered_entities(self, reference_node: model.NodeID, outcome: model.VulnerabilityOutcome) -> Tuple[int, float, int]:
        """Mark discovered entities as such and return
        the number of newly discovered nodes, their total value and the number of newly discovered credentials.

        Every node revealed by the outcome gets a KNOWS edge from `reference_node`.
        """
        newly_discovered_nodes = 0
        newly_discovered_nodes_value = 0
        newly_discovered_credentials = 0

        if isinstance(outcome, model.LeakedCredentials):
            for credential in outcome.credentials:
                if self.__mark_node_as_discovered(credential.node):
                    newly_discovered_nodes += 1
                    newly_discovered_nodes_value += self._environment.get_node(credential.node).value

                if credential.credential not in self._gathered_credentials:
                    newly_discovered_credentials += 1
                    self._gathered_credentials.add(credential.credential)

                logger.info("discovered credential: " + str(credential))
                self.__annotate_edge(reference_node, credential.node, EdgeAnnotation.KNOWS)

        elif isinstance(outcome, model.LeakedNodesId):
            for node_id in outcome.nodes:
                if self.__mark_node_as_discovered(node_id):
                    newly_discovered_nodes += 1
                    newly_discovered_nodes_value += self._environment.get_node(node_id).value

                self.__annotate_edge(reference_node, node_id, EdgeAnnotation.KNOWS)

        return (
            newly_discovered_nodes,
            newly_discovered_nodes_value,
            newly_discovered_credentials,
        )

    def get_node_privilegelevel(self, node_id: model.NodeID) -> model.PrivilegeLevel:
        """Return the last recorded privilege level of the specified node"""
        node_info = self._environment.get_node(node_id)
        return node_info.privilege_level

    def get_nodes_with_atleast_privilegelevel(self, level: PrivilegeLevel) -> List[model.NodeID]:
        """Return all nodes with at least the specified privilege level"""
        return [n for n, info in self._environment.nodes() if info.privilege_level >= level]

    def is_node_discovered(self, node_id: model.NodeID) -> bool:
        """Returns true if previous actions have revealed the specified node ID"""
        return node_id in self._discovered_nodes

    def __process_outcome(
        self,
        expected_type: VulnerabilityType,
        vulnerability_id: VulnerabilityID,
        node_id: model.NodeID,
        node_info: model.NodeInfo,
        local_or_remote: bool,
        failed_penalty: float,
        throw_if_vulnerability_not_present: bool,
    ) -> Tuple[bool, ActionResult]:
        """Apply the effect of exploiting `vulnerability_id` on node `node_id` and compute the reward.

        expected_type - the vulnerability must be of this type, otherwise ValueError is raised
        local_or_remote - True for a local attack, False for a remote one (key of the attack history)
        failed_penalty - reward returned when the vulnerability preconditions do not hold
        throw_if_vulnerability_not_present - raise ValueError (instead of returning a
            suspiciousness penalty) when the vulnerability is unknown to the node

        Return (succeeded, result).
        """
        if node_info.status != model.MachineStatus.Running:
            logger.info("target machine not in running state")
            return False, ActionResult(reward=Penalty.MACHINE_NOT_RUNNING, outcome=None)

        is_global_vulnerability = vulnerability_id in self._environment.vulnerability_library
        is_inplace_vulnerability = vulnerability_id in node_info.vulnerabilities

        if is_global_vulnerability:
            vulnerabilities = self._environment.vulnerability_library
        elif is_inplace_vulnerability:
            vulnerabilities = node_info.vulnerabilities
        else:
            if throw_if_vulnerability_not_present:
                raise ValueError(f"Vulnerability '{vulnerability_id}' not supported by node='{node_id}'")
            else:
                logger.info(f"Vulnerability '{vulnerability_id}' not supported by node '{node_id}'")
                return False, ActionResult(reward=Penalty.SUPSPICIOUSNESS, outcome=None)

        vulnerability = vulnerabilities[vulnerability_id]

        outcome = vulnerability.outcome

        if vulnerability.type != expected_type:
            raise ValueError(f"vulnerability id '{vulnerability_id}' is for an attack of type {vulnerability.type}, expecting: {expected_type}")

        # check vulnerability prerequisites
        if not self._check_prerequisites(node_id, vulnerability):
            return False, ActionResult(reward=failed_penalty, outcome=model.ExploitFailed())

        reward = 0

        # if the vulnerability type is a privilege escalation
        # and if the escalation level is not already reached on that node,
        # then add the escalation tag to the node properties
        if isinstance(outcome, model.PrivilegeEscalation):
            if outcome.tag in node_info.properties:
                return False, ActionResult(reward=Penalty.REPEAT, outcome=outcome)

            last_owned_at, is_currently_owned = self.__mark_node_as_owned(node_id, outcome.level)

            if not last_owned_at:
                reward += float(node_info.value)

            node_info.properties.append(outcome.tag)

        elif isinstance(outcome, model.LateralMove):
            last_owned_at, is_currently_owned = self.__mark_node_as_owned(node_id)

            if not last_owned_at:
                reward += float(node_info.value)

        elif isinstance(outcome, model.ProbeSucceeded):
            for p in outcome.discovered_properties:
                assert p in node_info.properties, f"Discovered property {p} must belong to the set of properties associated with the node."

            newly_discovered_properties = self.__mark_nodeproperties_as_discovered(node_id, outcome.discovered_properties)
            reward += newly_discovered_properties * PROPERTY_DISCOVERED_REWARD

        if node_id not in self._discovered_nodes:
            self._discovered_nodes[node_id] = NodeTrackingInformation()

        lookup_key = (vulnerability_id, local_or_remote)

        already_executed = lookup_key in self._discovered_nodes[node_id].last_attack

        if already_executed:
            last_time = self._discovered_nodes[node_id].last_attack[lookup_key]
            # Repeating an attack is only penalized if the node was not re-imaged since
            if node_info.last_reimaging is None or last_time >= node_info.last_reimaging:
                reward += Penalty.REPEAT
        else:
            reward += NEW_SUCCESSFULL_ATTACK_REWARD

        self._discovered_nodes[node_id].last_attack[lookup_key] = datetime.datetime.now()

        (
            newly_discovered_nodes,
            discovered_nodes_value,
            newly_discovered_credentials,
        ) = self.__mark_discovered_entities(node_id, outcome)

        # Note: `discovered_nodes_value` should not be added to the reward
        # unless the discovered nodes got owned, but this case is already covered above
        reward += newly_discovered_nodes * NODE_DISCOVERED_REWARD
        reward += newly_discovered_credentials * CREDENTIAL_DISCOVERED_REWARD

        reward -= vulnerability.cost

        logger.info("GOT REWARD: " + vulnerability.reward_string)
        return True, ActionResult(reward=reward, outcome=outcome)

    def exploit_remote_vulnerability(
        self,
        node_id: model.NodeID,
        target_node_id: model.NodeID,
        vulnerability_id: model.VulnerabilityID,
    ) -> ActionResult:
        """
        Attempt to exploit a remote vulnerability
        from a source node to another node using the specified
        vulnerability.

        The source node must be owned and the target node discovered; otherwise
        ValueError is raised (or Penalty.INVALID_ACTION returned, see constructor).
        """
        if node_id not in self._environment.network.nodes:
            raise ValueError(f"invalid node id '{node_id}'")
        if target_node_id not in self._environment.network.nodes:
            raise ValueError(f"invalid target node id '{target_node_id}'")

        source_node_info: model.NodeInfo = self._environment.get_node(node_id)
        target_node_info: model.NodeInfo = self._environment.get_node(target_node_id)

        if not source_node_info.agent_installed:
            if self._throws_on_invalid_actions:
                raise ValueError("Agent does not owned the source node '" + node_id + "'")
            else:
                return ActionResult(reward=Penalty.INVALID_ACTION, outcome=None)

        if target_node_id not in self._discovered_nodes:
            if self._throws_on_invalid_actions:
                raise ValueError("Agent has not discovered the target node '" + target_node_id + "'")
            else:
                return ActionResult(reward=Penalty.INVALID_ACTION, outcome=None)

        succeeded, result = self.__process_outcome(
            model.VulnerabilityType.REMOTE,
            vulnerability_id,
            target_node_id,
            target_node_info,
            local_or_remote=False,
            failed_penalty=Penalty.FAILED_REMOTE_EXPLOIT,
            # We do not throw if the vulnerability is missing in order to
            # allow agent attempts to explore potential remote vulnerabilities
            throw_if_vulnerability_not_present=False,
        )

        if succeeded:
            self.__annotate_edge(node_id, target_node_id, EdgeAnnotation.REMOTE_EXPLOIT)

        return result

    def exploit_local_vulnerability(self, node_id: model.NodeID, vulnerability_id: model.VulnerabilityID) -> ActionResult:
        """
        Attempt to exploit a local vulnerability on a node owned by the agent.

        Takes the ID of the target node and the ID of the vulnerability, and returns
        an ActionResult (reward and outcome, the outcome being None for invalid actions
        or unsupported vulnerabilities).
        """
        graph = self._environment.network
        if node_id not in graph.nodes:
            raise ValueError(f"invalid node id '{node_id}'")

        node_info = self._environment.get_node(node_id)

        if not node_info.agent_installed:
            if self._throws_on_invalid_actions:
                raise ValueError(f"Agent does not owned the node '{node_id}'")
            else:
                return ActionResult(reward=Penalty.INVALID_ACTION, outcome=None)

        succeeded, result = self.__process_outcome(
            model.VulnerabilityType.LOCAL,
            vulnerability_id,
            node_id,
            node_info,
            local_or_remote=True,
            failed_penalty=Penalty.LOCAL_EXPLOIT_FAILED,
            throw_if_vulnerability_not_present=False,
        )

        return result

    def __is_passing_firewall_rules(self, rules: List[model.FirewallRule], port_name: model.PortName) -> bool:
        """Determine if traffic on the specified port is permitted by the specified sets of firewall rules.

        The first rule matching the port decides; traffic on a port without any rule is blocked."""
        for rule in rules:
            if rule.port == port_name:
                if rule.permission == model.RulePermission.ALLOW:
                    return True
                else:
                    logger.debug(f"BLOCKED TRAFFIC - PORT '{port_name}' Reason: {rule.reason}")
                    return False

        logger.debug(f"BLOCKED TRAFFIC - PORT '{port_name}' - Reason: no rule defined for this port.")
        return False

    def __is_node_owned_history(self, target_node_id, target_node_data):
        """Returns the last time the node got owned and whether it is still currently owned.

        A node stays owned until it gets re-imaged after the time it was owned."""
        last_previously_owned_at = self._discovered_nodes[target_node_id].last_owned_at if target_node_id in self._discovered_nodes else None

        is_currently_owned = last_previously_owned_at is not None and (target_node_data.last_reimaging is None or last_previously_owned_at >= target_node_data.last_reimaging)
        return last_previously_owned_at, is_currently_owned

    def connect_to_remote_machine(
        self,
        source_node_id: model.NodeID,
        target_node_id: model.NodeID,
        port_name: model.PortName,
        credential: model.CredentialID,
    ) -> ActionResult:
        """
        This function connects to a remote machine with credential as opposed to via an exploit.
        It takes a NodeId for the source machine, a NodeID for the target Machine, and a credential object
        for the credential.

        Checks, in order: source owned, target discovered, credential gathered, outgoing
        firewall of the source, incoming firewall of the target, target listening on the
        port, target running, and credential accepted by a running service on that port.
        """
        graph = self._environment.network
        if source_node_id not in graph.nodes:
            raise ValueError(f"invalid node id '{source_node_id}'")
        if target_node_id not in graph.nodes:
            raise ValueError(f"invalid node id '{target_node_id}'")

        target_node = self._environment.get_node(target_node_id)
        source_node = self._environment.get_node(source_node_id)

        # ensures that the source node is owned by the agent
        # and that the target node is discovered
        if not source_node.agent_installed:
            if self._throws_on_invalid_actions:
                raise ValueError(f"Agent does not owned the source node '{source_node_id}'")
            else:
                return ActionResult(reward=Penalty.INVALID_ACTION, outcome=None)

        if target_node_id not in self._discovered_nodes:
            if self._throws_on_invalid_actions:
                raise ValueError(f"Agent has not discovered the target node '{target_node_id}'")
            else:
                return ActionResult(reward=Penalty.INVALID_ACTION, outcome=None)

        if credential not in self._gathered_credentials:
            if self._throws_on_invalid_actions:
                raise ValueError(f"Agent has not discovered credential '{credential}'")
            else:
                return ActionResult(reward=Penalty.INVALID_ACTION, outcome=None)

        if not self.__is_passing_firewall_rules(source_node.firewall.outgoing, port_name):
            logger.info(f"BLOCKED TRAFFIC: source node '{source_node_id}'" + f" is blocking outgoing traffic on port '{port_name}'")
            return ActionResult(reward=Penalty.BLOCKED_BY_LOCAL_FIREWALL, outcome=None)

        if not self.__is_passing_firewall_rules(target_node.firewall.incoming, port_name):
            logger.info(f"BLOCKED TRAFFIC: target node '{target_node_id}'" + f" is blocking incoming traffic on port '{port_name}'")
            return ActionResult(reward=Penalty.BLOCKED_BY_REMOTE_FIREWALL, outcome=None)

        target_node_is_listening = port_name in [i.name for i in target_node.services]
        if not target_node_is_listening:
            logger.info(f"target node '{target_node_id}' not listening on port '{port_name}'")
            return ActionResult(reward=Penalty.SCANNING_UNOPEN_PORT, outcome=None)

        if target_node.status != model.MachineStatus.Running:
            logger.info("target machine not in running state")
            return ActionResult(reward=Penalty.MACHINE_NOT_RUNNING, outcome=None)

        # check the credentials before connecting
        if not self._check_service_running_and_authorized(target_node, port_name, credential):
            logger.info("invalid credentials supplied")
            return ActionResult(reward=Penalty.WRONG_PASSWORD, outcome=None)

        # __mark_node_as_owned also guarantees the node has a tracking entry
        last_owned_at, is_already_owned = self.__mark_node_as_owned(target_node_id)

        if is_already_owned:
            return ActionResult(reward=Penalty.REPEAT, outcome=model.LateralMove())

        self.__annotate_edge(source_node_id, target_node_id, EdgeAnnotation.LATERAL_MOVE)

        logger.info(f"Infected node '{target_node_id}' from '{source_node_id}'" + f" via {port_name} with credential '{credential}'")
        if target_node.owned_string:
            logger.info("Owned message: " + target_node.owned_string)

        # The node's value is only granted the first time it gets owned
        return ActionResult(
            reward=float(target_node.value) if last_owned_at is None else 0.0,
            outcome=model.LateralMove(),
        )

    def _check_service_running_and_authorized(
        self,
        target_node_data: model.NodeInfo,
        port_name: model.PortName,
        credential: model.CredentialID,
    ) -> bool:
        """
        Return True if the node runs a service named `port_name` that is currently
        running and accepts the given credential.
        """
        return any(
            service.running and service.name == port_name and credential in service.allowedCredentials
            for service in target_node_data.services
        )

    def list_nodes(self) -> List[DiscoveredNodeInfo]:
        """Returns the list of nodes ID that were discovered or owned by the attacker."""
        return [
            cast(
                DiscoveredNodeInfo,
                {
                    "id": node_id,
                    "status": "owned" if node_info.agent_installed else "discovered",
                },
            )
            for node_id, node_info in self.discovered_nodes()
        ]

    def list_remote_attacks(self, node_id: model.NodeID) -> List[model.VulnerabilityID]:
        """Return list of all remote attacks that may be executed onto the specified node."""
        attacks: List[model.VulnerabilityID] = self.list_vulnerabilities_in_target(node_id, model.VulnerabilityType.REMOTE)
        return attacks

    def list_local_attacks(self, node_id: model.NodeID) -> List[model.VulnerabilityID]:
        """Return list of all local attacks that may be executed onto the specified node."""
        attacks: List[model.VulnerabilityID] = self.list_vulnerabilities_in_target(node_id, model.VulnerabilityType.LOCAL)
        return attacks

    def list_attacks(self, node_id: model.NodeID) -> List[model.VulnerabilityID]:
        """Return list of all attacks that may be executed on the specified node."""
        attacks: List[model.VulnerabilityID] = self.list_vulnerabilities_in_target(node_id)
        return attacks

    def list_all_attacks(self) -> List[Dict[str, object]]:
        """List all possible attacks from all the nodes currently owned by the attacker.

        Owned nodes come first (with their properties and local attacks), followed by
        the nodes that are only discovered (remote attacks only)."""
        nodes = self.list_nodes()
        on_owned_nodes: List[Dict[str, object]] = [
            {
                "id": n["id"],
                "status": n["status"],
                "properties": self._environment.get_node(n["id"]).properties,
                "local_attacks": self.list_local_attacks(n["id"]),
                "remote_attacks": self.list_remote_attacks(n["id"]),
            }
            for n in nodes
            if n["status"] == "owned"
        ]
        on_discovered_nodes: List[Dict[str, object]] = [
            {
                "id": n["id"],
                "status": n["status"],
                "local_attacks": None,
                "remote_attacks": self.list_remote_attacks(n["id"]),
            }
            for n in nodes
            if n["status"] != "owned"
        ]
        return on_owned_nodes + on_discovered_nodes

    def print_all_attacks(self) -> None:
        """Pretty print list of all possible attacks from all the nodes currently owned by the attacker"""
        display(pd.DataFrame.from_dict(self.list_all_attacks()).set_index("id"))  # type: ignore
679
680
681
class DefenderAgentActions:
    """Actions reserved to defender agents"""

    # Number of steps it takes to completely reimage a node
    REIMAGING_DURATION = 15

    def __init__(self, environment: model.Environment):
        # map nodes being reimaged to the remaining number of steps to completion
        self.node_reimaging_progress: Dict[model.NodeID, int] = dict()

        # Last calculated availability of the network
        self.__network_availability: float = 1.0

        self._environment = environment

    @property
    def network_availability(self):
        """Network availability computed by the last call to `on_attacker_step_taken` (1.0 initially)."""
        return self.__network_availability

    def reimage_node(self, node_id: model.NodeID):
        """Re-image a computer node.

        The attacker loses the node (agent uninstalled, privilege reset to NoAccess) and
        the node stays in Imaging state until re-imaging completes (see on_attacker_step_taken).
        Raises AssertionError if the node is not re-imageable.
        """
        # Look up and validate the node BEFORE scheduling anything, so that an invalid
        # request does not leave a stale entry in `node_reimaging_progress`.
        node_info = self._environment.get_node(node_id)
        assert node_info.reimagable, f"Node {node_id} is not re-imageable"

        # Mark the node for re-imaging and make it unavailable until re-imaging completes
        self.node_reimaging_progress[node_id] = self.REIMAGING_DURATION

        node_info.agent_installed = False
        node_info.privilege_level = model.PrivilegeLevel.NoAccess
        node_info.status = model.MachineStatus.Imaging
        node_info.last_reimaging = datetime.datetime.now()
        self._environment.network.nodes[node_id].update({"data": node_info})

    def on_attacker_step_taken(self):
        """Function to be called each time a step is taken in the simulation.

        Advances the re-imaging of nodes (bringing them back to Running when done)
        and recomputes the network availability metric."""
        for node_id in list(self.node_reimaging_progress.keys()):
            remaining_steps = self.node_reimaging_progress[node_id]
            if remaining_steps > 0:
                self.node_reimaging_progress[node_id] -= 1
            else:
                logger.info(f"Machine re-imaging completed: {node_id}")
                node_data = self._environment.get_node(node_id)
                node_data.status = model.MachineStatus.Running
                self.node_reimaging_progress.pop(node_id)

        # Calculate the network availability metric based on machines
        # and services that are running
        total_node_weights = 0
        network_node_availability = 0
        for node_id, node_info in self._environment.nodes():
            total_service_weights = 0
            running_service_weights = 0
            for service in node_info.services:
                total_service_weights += service.sla_weight
                running_service_weights += service.sla_weight * int(service.running)

            if node_info.status == MachineStatus.Running:
                adjusted_node_availability = (1 + running_service_weights) / (1 + total_service_weights)
            else:
                adjusted_node_availability = 0.0

            total_node_weights += node_info.sla_weight
            network_node_availability += adjusted_node_availability * node_info.sla_weight

        # With no node weight at all (empty network or all SLA weights zero) the
        # ratio is undefined: keep the previous value instead of dividing by zero.
        if total_node_weights != 0:
            self.__network_availability = network_node_availability / total_node_weights
        assert self.__network_availability <= 1.0 and self.__network_availability >= 0.0

    def override_firewall_rule(
        self,
        node_id: model.NodeID,
        port_name: model.PortName,
        incoming: bool,
        permission: model.RulePermission,
    ):
        """Set the permission of the node's incoming (or outgoing) firewall rules for a port.

        Existing rules on that port are replaced by a rule with the new permission;
        a new rule is appended if none exists for the port."""
        node_data = self._environment.get_node(node_id)

        def add_or_patch_rule(rules) -> List[FirewallRule]:
            new_rules = []
            has_matching_rule = False
            for r in rules:
                if r.port == port_name:
                    has_matching_rule = True
                    new_rules.append(FirewallRule(r.port, permission))
                else:
                    new_rules.append(r)

            if not has_matching_rule:
                new_rules.append(model.FirewallRule(port_name, permission))
            return new_rules

        if incoming:
            node_data.firewall.incoming = add_or_patch_rule(node_data.firewall.incoming)
        else:
            node_data.firewall.outgoing = add_or_patch_rule(node_data.firewall.outgoing)

    def block_traffic(self, node_id: model.NodeID, port_name: model.PortName, incoming: bool):
        """Block incoming (or outgoing) traffic on a port of a node."""
        return self.override_firewall_rule(node_id, port_name, incoming, permission=model.RulePermission.BLOCK)

    def allow_traffic(self, node_id: model.NodeID, port_name: model.PortName, incoming: bool):
        """Allow incoming (or outgoing) traffic on a port of a node."""
        return self.override_firewall_rule(node_id, port_name, incoming, permission=model.RulePermission.ALLOW)

    def stop_service(self, node_id: model.NodeID, port_name: model.PortName):
        """Stop every service named `port_name` on a running node."""
        node_data = self._environment.get_node(node_id)
        assert node_data.status == model.MachineStatus.Running, "Machine must be running to stop a service"
        for service in node_data.services:
            if service.name == port_name:
                service.running = False

    def start_service(self, node_id: model.NodeID, port_name: model.PortName):
        """Start every service named `port_name` on a running node."""
        node_data = self._environment.get_node(node_id)
        assert node_data.status == model.MachineStatus.Running, "Machine must be running to start a service"
        for service in node_data.services:
            if service.name == port_name:
                service.running = True
795
796