Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
microsoft
GitHub Repository: microsoft/CyberBattleSim
Path: blob/main/cyberbattle/simulation/model.py
597 views
1
# Copyright (c) Microsoft Corporation.
2
# Licensed under the MIT License.
3
4
"""Data model for the simulation environment.
5
6
The simulation environment is given by the directed graph
7
formally defined by:
8
9
Node := NodeID x ListeningService[] x Value x Vulnerability[] x FirewallConfig
10
Edge := NodeID x NodeID x PortName
11
12
where:
13
- NodeID: string
14
- ListeningService : Name x AllowedCredentials
15
- AllowedCredentials : string[] # credential pair represented by just a
16
string ID
17
- Value : [0...100] # Intrinsic value of reaching this node
18
- Vulnerability : VulnerabilityID x Type x Precondition x Outcome x Rates
19
- VulnerabilityID : string
20
- Rates : ProbingDetectionRate x ExploitDetectionRate x SuccessRate
21
- FirewallConfig: {
22
outgoing : FirewallRule[]
23
incoming : FirewallRule[] }
24
- FirewallRule: PortName x { ALLOW, BLOCK }
25
"""
26
27
from datetime import datetime
28
from typing import NamedTuple, List, Dict, Optional, Union, Tuple, Iterator
29
import dataclasses
30
from dataclasses import dataclass, field
31
import matplotlib.pyplot as plt # type:ignore
32
from enum import Enum, IntEnum
33
from boolean import boolean
34
import networkx as nx
35
import yaml
36
import random
37
38
import matplotlib # type: ignore
39
40
matplotlib.use("Agg")
41
42
VERSION_TAG = "0.1.0"
43
44
ALGEBRA = boolean.BooleanAlgebra()
45
46
# Type alias for identifiers
47
NodeID = str
48
49
# A unique identifier
50
ID = str
51
52
# a (login,password/token) credential pair is abstracted as just a unique
53
# string identifier
54
CredentialID = str
55
56
# Intrinsic value of a reaching a given node in [0,100]
57
NodeValue = int
58
59
60
PortName = str
61
62
63
@dataclass
class ListeningService:
    """A service port on a given node accepting connections initiated
    with the specified allowed credentials.

    All attributes below are dataclass fields, so each of them can be
    supplied to the constructor and takes part in the generated
    `__repr__` and `__eq__`.
    """

    # Name of the port the service is listening to
    name: PortName
    # credentials allowed to authenticate with the service
    allowedCredentials: List[CredentialID] = dataclasses.field(default_factory=list)
    # whether the service is running or stopped
    running: bool = True
    # Weight used to evaluate the cost of not running the service.
    # The annotation is required: without it the dataclass decorator treats
    # the name as a plain class attribute rather than a per-instance field.
    sla_weight: float = 1.0
76
77
78
x = ListeningService(name="d")
79
VulnerabilityID = str
80
81
# Probability rate
82
Probability = float
83
84
# The name of a node property indicating the presence of a
85
# service, component, feature or vulnerability on a given node.
86
PropertyName = str
87
88
89
class Rates(NamedTuple):
    """Probability rates attached to a given vulnerability.

    Every field has a default: `Rates()` yields detection rates of 0.0
    and a success rate of 1.0.
    """

    # rate associated with detection of a probing attempt
    probingDetectionRate: Probability = 0.0
    # rate associated with detection of an exploit attempt
    exploitDetectionRate: Probability = 0.0
    # rate associated with the exploit succeeding
    successRate: Probability = 1.0
95
96
97
class VulnerabilityType(Enum):
    """Whether a vulnerability is exploitable locally or remotely."""

    LOCAL = 1
    REMOTE = 2
102
103
104
class PrivilegeLevel(IntEnum):
    """Access privilege level on a given node.

    Members are integers, so levels can be compared and ordered directly.
    """

    NoAccess = 0
    LocalUser = 1
    Admin = 2
    System = 3
    # Same value as `System`, hence an alias of that member
    MAXIMUM = 3
112
113
114
def escalate(current_level, escalation_level: PrivilegeLevel) -> PrivilegeLevel:
    """Return the higher of the current level and the escalation level.

    Both arguments are converted with `int()` before being compared, and
    the larger value is converted back to a `PrivilegeLevel`.

    Raises:
        ValueError: if the larger value is not a valid `PrivilegeLevel`.
    """
    highest = max(int(current_level), int(escalation_level))
    return PrivilegeLevel(highest)
116
117
118
class VulnerabilityOutcome:
    """Base class for the outcome of exploiting a given vulnerability."""
120
121
122
class LateralMove(VulnerabilityOutcome):
    """Outcome representing a lateral movement to the target node."""

    # Annotation only: no value is assigned here, so an instance has no
    # `success` attribute until one is set on it.
    success: bool
126
127
128
class CustomerData(VulnerabilityOutcome):
    """Outcome representing access to customer data on the target node."""
130
131
132
class PrivilegeEscalation(VulnerabilityOutcome):
    """Outcome representing a privilege escalation to a given level."""

    def __init__(self, level: PrivilegeLevel):
        # privilege level reached by this escalation
        self.level = level

    @property
    def tag(self):
        """Escalation tag that gets added to node properties when
        the escalation level is reached for that node.

        The tag is the text "privilege_" followed by the formatted level.
        """
        return "privilege_{}".format(self.level)
143
144
145
class SystemEscalation(PrivilegeEscalation):
    """Privilege escalation whose level is fixed to `PrivilegeLevel.System`."""

    def __init__(self):
        super().__init__(level=PrivilegeLevel.System)
150
151
152
class AdminEscalation(PrivilegeEscalation):
    """Privilege escalation whose level is fixed to `PrivilegeLevel.Admin`
    (local administrator privileges)."""

    def __init__(self):
        super().__init__(level=PrivilegeLevel.Admin)
157
158
159
class ProbeSucceeded(VulnerabilityOutcome):
    """Outcome of a successful probe."""

    def __init__(self, discovered_properties: List[PropertyName]):
        # property names carried by this outcome (stored as given, not copied)
        self.discovered_properties = discovered_properties
164
165
166
class ProbeFailed(VulnerabilityOutcome):
    """Outcome of a failed probe."""
168
169
170
class ExploitFailed(VulnerabilityOutcome):
    """Outcome used for situations where the exploit fails."""
172
173
174
class CachedCredential(NamedTuple):
    """A (machine, port, credential) triplet."""

    # ID of the node the credential applies to
    node: NodeID
    # name of the port the credential applies to
    port: PortName
    # ID of the credential itself
    credential: CredentialID
180
181
182
class LeakedCredentials(VulnerabilityOutcome):
    """Outcome carrying a set of credentials obtained by exploiting a
    vulnerability."""

    credentials: List[CachedCredential]

    def __init__(self, credentials: List[CachedCredential]):
        # stored as given (the list is not copied)
        self.credentials = credentials
189
190
191
class LeakedNodesId(VulnerabilityOutcome):
    """Outcome carrying a set of node IDs obtained by exploiting a
    vulnerability."""

    def __init__(self, nodes: List[NodeID]):
        # stored as given (the list is not copied)
        self.nodes = nodes
196
197
198
# Union of outcome classes; used as the type of `AttackResult.expected_outcome`.
# NOTE(review): `ProbeSucceeded` and `ProbeFailed` are not part of this union
# -- confirm whether that is intentional.
VulnerabilityOutcomes = Union[LeakedCredentials, LeakedNodesId, PrivilegeEscalation, AdminEscalation, SystemEscalation, CustomerData, LateralMove, ExploitFailed]
199
200
201
class AttackResult:
    """The result of attempting a specific attack (either local or remote).

    Both attributes are declared by annotation only; no value is assigned
    at class level.
    """

    # whether the attack succeeded
    success: bool
    # outcome attached to the attack, or None
    expected_outcome: Optional[VulnerabilityOutcomes]
206
207
208
class Precondition:
    """A predicate logic expression defining the condition under which a given
    feature or vulnerability is present or not.

    The symbols used in the expression refer to properties associated with
    the corresponding node.
    E.g. 'Win7', 'Server', 'IISInstalled', 'SQLServerInstalled',
    'AntivirusInstalled' ...
    """

    expression: boolean.Expression

    def __init__(self, expression: Union[boolean.Expression, str]):
        """Store `expression`, parsing it with `ALGEBRA` first unless it
        already is a `boolean.Expression`."""
        already_parsed = isinstance(expression, boolean.Expression)
        self.expression = expression if already_parsed else ALGEBRA.parse(expression)
224
225
226
class VulnerabilityInfo(NamedTuple):
    """Definition of a known vulnerability.

    Only `description`, `type` and `outcome` are required; every other
    field has a default. The default `precondition` and `rates` objects are
    created once, when the class is defined, and shared by all instances
    that do not override them.
    """

    # an optional description of what the vulnerability is
    description: str
    # type of vulnerability
    type: VulnerabilityType
    # what happens when successfully exploiting the vulnerability
    outcome: VulnerabilityOutcome
    # a boolean expression over a node's properties determining if the
    # vulnerability is present or not
    precondition: Precondition = Precondition("true")
    # rates of success/failure associated with this vulnerability
    rates: Rates = Rates()
    # points to information about the vulnerability
    URL: str = ""
    # some cost associated with exploiting this vulnerability (e.g.
    # brute force more costly than dumping credentials)
    cost: float = 1.0
    # a string displayed when the vulnerability is successfully exploited
    reward_string: str = ""
247
248
249
# A dictionary storing information about all supported vulnerabilities
250
# or features supported by the simulation.
251
# This is to be used as a global dictionary pre-populated before
252
# starting the simulation and estimated from real-world data.
253
VulnerabilityLibrary = Dict[VulnerabilityID, VulnerabilityInfo]
254
255
256
class RulePermission(Enum):
    """Determines whether a rule blocks or allows traffic."""

    ALLOW = 0
    BLOCK = 1
261
262
263
@dataclass(frozen=True)
class FirewallRule:
    """A single firewall rule.

    Instances are frozen: fields cannot be reassigned after construction.
    """

    # A port name
    port: PortName
    # permission on this port
    permission: RulePermission
    # An optional reason for the block/allow rule
    reason: str = ""
273
274
275
def _default_firewall_rules() -> List[FirewallRule]:
    """Build a fresh list of the default ALLOW rules (RDP, SSH, HTTPS, HTTP)."""
    return [FirewallRule(port_name, RulePermission.ALLOW) for port_name in ("RDP", "SSH", "HTTPS", "HTTP")]


@dataclass
class FirewallConfiguration:
    """Firewall configuration on a given node.
    Determine if traffic should be allowed or specifically blocked
    on a given port for outgoing and incoming traffic.
    The rules are processed in order: the first rule matching a given
    port is applied and the rest are ignored.

    Ports that are not listed in the configuration
    are assumed to be blocked. (Adding an explicit block rule
    can still be useful to give a reason for the block.)
    """

    # Rules for outgoing traffic; each instance gets its own default list
    outgoing: List[FirewallRule] = field(repr=True, default_factory=_default_firewall_rules)
    # Rules for incoming traffic; each instance gets its own default list
    incoming: List[FirewallRule] = field(repr=True, default_factory=_default_firewall_rules)
306
307
308
class MachineStatus(Enum):
    """Running status of a machine."""

    Stopped = 0
    Running = 1
    Imaging = 2
314
315
316
@dataclass
class NodeInfo:
    """A computer node in the enterprise network.

    Only `services` is required; every other field has a default.
    """

    # List of port/protocol the node is listening to
    services: List[ListeningService]
    # List of known vulnerabilities for the node
    vulnerabilities: VulnerabilityLibrary = dataclasses.field(default_factory=dict)
    # Intrinsic value of the node (translates into a reward if the node gets owned)
    value: NodeValue = 0
    # Properties of the nodes, some of which can imply further vulnerabilities
    properties: List[PropertyName] = dataclasses.field(default_factory=list)
    # Firewall configuration of the node
    firewall: FirewallConfiguration = dataclasses.field(default_factory=FirewallConfiguration)
    # Attacker agent installed on the node? (aka the node is 'pwned')
    agent_installed: bool = False
    # Escalation level
    privilege_level: PrivilegeLevel = PrivilegeLevel.NoAccess
    # Can the node be re-imaged by a defender agent?
    reimagable: bool = True
    # Last time the node was reimaged
    last_reimaging: Optional[datetime] = None
    # String displayed when the node gets owned
    owned_string: str = ""
    # Relative node weight used to calculate the cost of stopping this machine
    # or its services
    sla_weight: float = 1.0
    # Machine status: running or stopped.
    # Annotated so that it is a real dataclass field (settable through the
    # constructor, included in repr/eq). It is declared last so that the
    # positional order of the pre-existing fields is unchanged.
    status: MachineStatus = MachineStatus.Running
345
346
347
class Identifiers(NamedTuple):
    """Define the global set of identifiers used
    in the definition of a given environment.
    Such set defines a common vocabulary possibly
    shared across multiple environments, thus
    ensuring a consistent numbering convention
    that a machine learning model can learn from.

    NOTE: the default lists below are created once, when the class is
    defined, and are shared by every instance that relies on them; treat
    them as read-only.
    """

    # Array of all possible node property identifiers
    properties: List[PropertyName] = []
    # Array of all possible port names
    ports: List[PortName] = ["Null"]
    # Array of all possible local vulnerabilities names
    local_vulnerabilities: List[VulnerabilityID] = []
    # Array of all possible remote vulnerabilities names
    remote_vulnerabilities: List[VulnerabilityID] = []
363
364
365
def iterate_network_nodes(network: nx.graph.Graph) -> Iterator[Tuple[NodeID, NodeInfo]]:
    """Yield an `(id, info)` pair for every node of the network.

    The info object is read from the "data" entry of the node's attribute
    dictionary. This is a generator, so the result can be traversed only once.

    Raises:
        KeyError: when a visited node has no "data" attribute.
    """
    node_view = network.nodes
    for node_id in node_view:
        yield node_id, node_view[node_id]["data"]
370
371
372
# NOTE: Using `NamedTuple` instead of `dataclass` breaks deserialization
# with PyYaml 2.8.1 due to a new recursive reference to the networkx graph in the field
#   edges: !!python/object:networkx.classes.reportviews.EdgeView
#     _adjdict: *id018
#     _graph: *id019
@dataclass
class Environment:
    """The static graph defining the network of computers"""

    # graph whose nodes carry a `NodeInfo` under the "data" attribute
    network: nx.DiGraph
    # vulnerabilities defined globally for the environment
    vulnerability_library: VulnerabilityLibrary
    # vocabulary of identifiers used by the environment
    identifiers: Identifiers
    # Timestamps default to the (naive, UTC) time at which the instance is
    # created. A default_factory is required here: a plain
    # `= datetime.utcnow()` default would be evaluated only once, when the
    # class is defined, and every instance would share that import-time value.
    creationTime: datetime = field(default_factory=datetime.utcnow)
    lastModified: datetime = field(default_factory=datetime.utcnow)
    # a version tag indicating the environment schema version
    version: str = VERSION_TAG

    def nodes(self) -> Iterator[Tuple[NodeID, NodeInfo]]:
        """Iterates over the nodes in the network"""
        return iterate_network_nodes(self.network)

    def get_node(self, node_id: NodeID) -> NodeInfo:
        """Retrieve info for the node with the specified ID.

        Raises:
            KeyError: if the node is unknown or has no "data" attribute.
        """
        node_info: NodeInfo = self.network.nodes[node_id]["data"]
        return node_info

    def plot_environment_graph(self) -> None:
        """Plot the full environment graph, coloring each node by its value"""
        node_values = [attributes["data"].value for attributes in self.network.nodes.values()]
        nx.draw(self.network, with_labels=True, node_color=node_values, cmap=plt.cm.Oranges)  # type:ignore
401
402
403
def create_network(nodes: Dict[NodeID, NodeInfo]) -> nx.DiGraph:
    """Build a directed graph holding the given nodes and no edges.

    Each node's `NodeInfo` is stored under the "data" key of its
    attribute dictionary.
    """
    network = nx.DiGraph()
    labelled_nodes = [(node_id, {"data": node_info}) for node_id, node_info in nodes.items()]
    network.add_nodes_from(labelled_nodes)
    return network
408
409
410
# Helpers to infer constants from an environment
411
412
413
def collect_ports_from_vuln(vuln: VulnerabilityInfo) -> List[PortName]:
    """Return all the port names referenced in a given vulnerability.

    Only `LeakedCredentials` outcomes reference ports (one per leaked
    credential); any other outcome yields an empty list.
    """
    outcome = vuln.outcome
    if not isinstance(outcome, LeakedCredentials):
        return []
    return [cached.port for cached in outcome.credentials]
419
420
421
def collect_vulnerability_ids_from_nodes_bytype(nodes: Iterator[Tuple[NodeID, NodeInfo]], global_vulnerabilities: VulnerabilityLibrary, type: VulnerabilityType) -> List[VulnerabilityID]:
    """Collect and return the sorted, de-duplicated IDs of all vulnerabilities
    of the specified type that are referenced in a given set of nodes and
    vulnerability library.

    `nodes` is traversed exactly once.
    """
    matching_ids = set()
    # vulnerabilities attached to individual nodes
    for _, node_info in nodes:
        for vuln_id, vuln in node_info.vulnerabilities.items():
            if vuln.type == type:
                matching_ids.add(vuln_id)
    # vulnerabilities from the global library
    for vuln_id, vuln in global_vulnerabilities.items():
        if vuln.type == type:
            matching_ids.add(vuln_id)
    return sorted(matching_ids)
426
427
428
def collect_properties_from_nodes(nodes: Iterator[Tuple[NodeID, NodeInfo]]) -> List[PropertyName]:
    """Collect and return the sorted, de-duplicated list of all property
    names used in a given set of nodes.

    `nodes` is traversed exactly once.
    """
    property_names = set()
    for _, node_info in nodes:
        property_names.update(node_info.properties)
    return sorted(property_names)
431
432
433
def collect_ports_from_nodes(nodes: Iterator[Tuple[NodeID, NodeInfo]], vulnerability_library: VulnerabilityLibrary) -> List[PortName]:
    """Collect and return all port names used in a given set of nodes
    and global vulnerability library.

    Ports come from three sources: the vulnerabilities of the global
    library, the vulnerabilities attached to each node, and the names of
    the services each node is listening on.

    Args:
        nodes: (node ID, node info) pairs; may be a single-pass iterator
            such as a generator.
        vulnerability_library: vulnerabilities defined globally.

    Returns:
        The sorted, de-duplicated list of port names.
    """
    ports = set()
    for vuln in vulnerability_library.values():
        ports.update(collect_ports_from_vuln(vuln))
    # `nodes` must be traversed only once: when it is a generator, a second
    # traversal would see nothing and the service ports would be silently lost.
    for _, node_info in nodes:
        for vuln in node_info.vulnerabilities.values():
            ports.update(collect_ports_from_vuln(vuln))
        ports.update(service.name for service in node_info.services)
    return sorted(ports)
445
446
447
def collect_ports_from_environment(environment: Environment) -> List[PortName]:
    """Collect and return all port names used in a given environment,
    i.e. in its nodes and in its vulnerability library."""
    node_iterator = environment.nodes()
    library = environment.vulnerability_library
    return collect_ports_from_nodes(node_iterator, library)
450
451
452
def infer_constants_from_nodes(nodes: Iterator[Tuple[NodeID, NodeInfo]], vulnerabilities: Dict[VulnerabilityID, VulnerabilityInfo]) -> Identifiers:
    """Infer global environment constants from a given set of nodes.

    Args:
        nodes: (node ID, node info) pairs; may be a single-pass iterator
            such as a generator.
        vulnerabilities: vulnerabilities defined globally.

    Returns:
        An `Identifiers` tuple with the property names, port names and
        local/remote vulnerability IDs found in the nodes and library.
    """
    # Materialize the nodes first: several collectors need to traverse them,
    # and a generator handed to the first collector would be exhausted
    # (and appear empty) for all the following ones.
    node_list = list(nodes)
    return Identifiers(
        properties=collect_properties_from_nodes(iter(node_list)),
        # The list itself is passed so that the nodes remain re-iterable
        # inside the ports collector.
        ports=collect_ports_from_nodes(node_list, vulnerabilities),  # type: ignore
        local_vulnerabilities=collect_vulnerability_ids_from_nodes_bytype(iter(node_list), vulnerabilities, VulnerabilityType.LOCAL),
        remote_vulnerabilities=collect_vulnerability_ids_from_nodes_bytype(iter(node_list), vulnerabilities, VulnerabilityType.REMOTE),
    )
460
461
462
def infer_constants_from_network(network: nx.Graph, vulnerabilities: Dict[VulnerabilityID, VulnerabilityInfo]) -> Identifiers:
    """Infer global environment constants from the nodes of a given network
    and a global vulnerability library."""
    node_iterator = iterate_network_nodes(network)
    return infer_constants_from_nodes(node_iterator, vulnerabilities)
465
466
467
# Network creation
468
469
# A sample set of environment constants (only ports and properties are
# given; the vulnerability lists keep the `Identifiers` defaults)
SAMPLE_IDENTIFIERS = Identifiers(
    ports=["RDP", "SSH", "SMB", "HTTP", "HTTPS", "WMI", "SQL"], properties=["Windows", "Linux", "HyperV-VM", "Azure-VM", "Win7", "Win10", "PortRDPOpen", "GuestAccountEnabled"]
)
473
474
475
def assign_random_labels(graph: nx.DiGraph, vulnerabilities: VulnerabilityLibrary = dict([]), identifiers: Identifiers = SAMPLE_IDENTIFIERS) -> nx.DiGraph:
    """Create an environment network by randomly assigning node information
    (properties, firewall configuration, vulnerabilities)
    to the nodes of a given graph structure.

    One node is picked at random as the agent entry node: it gets value 0,
    an installed agent, Admin privileges and is not reimagable. Every other
    node gets a random value in [0, 100], no agent and no access.

    Args:
        graph: graph structure to label; its node IDs are converted to strings.
        vulnerabilities: global library from which each node receives a
            random subset (this mapping is only read, never modified).
        identifiers: vocabulary from which ports and properties are sampled.

    Returns:
        The relabeled graph, where every node has a `NodeInfo` stored under
        its "data" attribute (all other node attributes are cleared).

    Raises:
        ValueError: if the graph has no nodes (no entry node can be picked).
    """

    # convert node IDs to string
    graph = nx.relabel_nodes(graph, {i: str(i) for i in graph.nodes})

    # Index the edge targets by source node in a single pass over the edges,
    # instead of rescanning every edge of the graph for each node.
    targets_by_source: Dict[NodeID, set] = {}
    for source, target in graph.edges():
        targets_by_source.setdefault(source, set()).add(target)

    def create_random_firewall_configuration() -> FirewallConfiguration:
        """ALLOW rules for a random subset of the known ports, drawn
        separately for outgoing and then incoming traffic."""
        return FirewallConfiguration(
            outgoing=[FirewallRule(port=p, permission=RulePermission.ALLOW) for p in random.sample(identifiers.ports, k=random.randint(0, len(identifiers.ports)))],
            incoming=[FirewallRule(port=p, permission=RulePermission.ALLOW) for p in random.sample(identifiers.ports, k=random.randint(0, len(identifiers.ports)))],
        )

    def create_random_properties() -> List[PropertyName]:
        """A random subset (possibly empty) of the known properties."""
        return list(random.sample(identifiers.properties, k=random.randint(0, len(identifiers.properties))))

    def pick_random_global_vulnerabilities() -> VulnerabilityLibrary:
        """A new dictionary holding a random subset of the global vulnerabilities."""
        # one threshold per node; each vulnerability is kept when its own
        # random draw exceeds it
        threshold = random.random()
        return {k: v for (k, v) in vulnerabilities.items() if random.random() > threshold}

    def add_leak_neighbors_vulnerability(library: VulnerabilityLibrary, node_id: NodeID) -> None:
        """Create a vulnerability for each node that reveals its immediate neighbors"""
        neighbors = targets_by_source.get(node_id)
        if neighbors:
            library["RecentlyAccessedMachines"] = VulnerabilityInfo(description="AzureVM info, including public IP address", type=VulnerabilityType.LOCAL, outcome=LeakedNodesId(list(neighbors)))

    def create_random_vulnerabilities(node_id: NodeID) -> VulnerabilityLibrary:
        """Random global vulnerabilities plus the neighbor-leaking one, if any."""
        library = pick_random_global_vulnerabilities()
        add_leak_neighbors_vulnerability(library, node_id)
        return library

    # Pick a random node as the agent entry node
    entry_node_index = random.randrange(len(graph.nodes))
    entry_node_id = list(graph.nodes)[entry_node_index]
    graph.nodes[entry_node_id].clear()
    node_data = NodeInfo(
        services=[],
        value=0,
        properties=create_random_properties(),
        vulnerabilities=create_random_vulnerabilities(entry_node_id),
        firewall=create_random_firewall_configuration(),
        agent_installed=True,
        reimagable=False,
        privilege_level=PrivilegeLevel.Admin,
    )
    graph.nodes[entry_node_id].update({"data": node_data})

    def create_random_node_data(node_id: NodeID) -> NodeInfo:
        """Random `NodeInfo` for a node that is not the entry node."""
        return NodeInfo(
            services=[],
            value=random.randint(0, 100),
            properties=create_random_properties(),
            vulnerabilities=create_random_vulnerabilities(node_id),
            firewall=create_random_firewall_configuration(),
            agent_installed=False,
            privilege_level=PrivilegeLevel.NoAccess,
        )

    for node in list(graph.nodes):
        if node != entry_node_id:
            graph.nodes[node].clear()
            graph.nodes[node].update({"data": create_random_node_data(node)})

    return graph
540
541
542
# Serialization
543
544
545
def setup_yaml_serializer() -> None:
    """Setup a clean YAML formatter for object of type Environment.

    Registers representers and constructors so that `Precondition` objects
    are written as `!BooleanExpression` scalars and `VulnerabilityType`
    members as `!VulnerabilityType` scalars holding the member name. Each
    constructor is registered both on `yaml.SafeLoader` and through
    `yaml.add_constructor`.
    """

    def represent_precondition(dumper, data):
        return dumper.represent_scalar("!BooleanExpression", str(data.expression))

    def construct_precondition(loader, node):
        return Precondition(loader.construct_scalar(node))

    def represent_vulnerability_type(dumper, data):
        return dumper.represent_scalar("!VulnerabilityType", str(data.name))

    def construct_vulnerability_type(loader, node):
        # look the member up by name
        return VulnerabilityType[loader.construct_scalar(node)]

    yaml.add_representer(Precondition, represent_precondition)  # type: ignore
    yaml.SafeLoader.add_constructor("!BooleanExpression", construct_precondition)  # type: ignore
    yaml.add_constructor("!BooleanExpression", construct_precondition)  # type: ignore

    yaml.add_representer(VulnerabilityType, represent_vulnerability_type)  # type: ignore

    yaml.SafeLoader.add_constructor("!VulnerabilityType", construct_vulnerability_type)  # type: ignore
    yaml.add_constructor("!VulnerabilityType", construct_vulnerability_type)  # type: ignore
555
556