Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
microsoft
GitHub Repository: microsoft/CyberBattleSim
Path: blob/main/cyberbattle/simulation/actions_test.py
597 views
1
# Copyright (c) Microsoft Corporation.
2
# Licensed under the MIT License.
3
4
"""
5
This is the set of tests for actions.py which implements the actions an agent can take
6
in this simulation.
7
"""
8
9
import random
from datetime import datetime
from typing import Dict, List, Optional

import pytest
import networkx as nx

from . import model, actions
17
18
# Property tags that the privilege-escalation outcomes add to a node's
# properties; the vulnerability preconditions below are written in terms of them.
ADMINTAG = model.AdminEscalation().tag
SYSTEMTAG = model.SystemEscalation().tag

# pylint: disable=redefined-outer-name, protected-access

# Type alias used to annotate test parameters: every fixture in this module
# returns an AgentActions instance.
Fixture = actions.AgentActions
23
24
# An empty vulnerability library.
# NOTE(review): neither empty_vuln_dict nor SINGLE_VULNERABILITIES is referenced
# by the tests in this section — confirm they are still needed.
empty_vuln_dict: Dict[model.VulnerabilityID, model.VulnerabilityInfo] = {}

# A one-entry library holding only UACME #61: a local admin escalation whose
# precondition is a Windows 10 host that is not yet admin or system.
SINGLE_VULNERABILITIES = dict(
    UACME61=model.VulnerabilityInfo(
        description="UACME UAC bypass #61",
        type=model.VulnerabilityType.LOCAL,
        URL="https://github.com/hfiref0x/UACME",
        precondition=model.Precondition(f"Windows&Win10&(~({ADMINTAG}|{SYSTEMTAG}))"),
        outcome=model.AdminEscalation(),
        rates=model.Rates(0, 0.2, 1.0),
    ),
)
35
36
# Sample vulnerability library shared by the fixtures. It holds:
#   * two local UAC bypasses for unprivileged Windows 10 hosts
#     (UACME61 -> admin escalation, UACME67 -> fake system escalation),
#   * one local credential dump that needs an admin/system tag
#     (it leaks an empty credential list),
#   * one remote RDP brute force that moves laterally to Windows hosts with RDP open.
# Temporary, for development only: remove once the full list of
# vulnerabilities is put together.
SAMPLE_VULNERABILITIES = dict(
    UACME61=model.VulnerabilityInfo(
        description="UACME UAC bypass #61",
        type=model.VulnerabilityType.LOCAL,
        URL="https://github.com/hfiref0x/UACME",
        precondition=model.Precondition(f"Windows&Win10&(~({ADMINTAG}|{SYSTEMTAG}))"),
        outcome=model.AdminEscalation(),
        rates=model.Rates(0, 0.2, 1.0),
    ),
    UACME67=model.VulnerabilityInfo(
        description="UACME UAC bypass #67 (fake system escalation) ",
        type=model.VulnerabilityType.LOCAL,
        URL="https://github.com/hfiref0x/UACME",
        precondition=model.Precondition(f"Windows&Win10&(~({ADMINTAG}|{SYSTEMTAG}))"),
        outcome=model.SystemEscalation(),
        rates=model.Rates(0, 0.2, 1.0),
    ),
    MimikatzLogonpasswords=model.VulnerabilityInfo(
        description="Mimikatz sekurlsa::logonpasswords.",
        type=model.VulnerabilityType.LOCAL,
        URL="https://github.com/gentilkiwi/mimikatz",
        precondition=model.Precondition(f"Windows&({ADMINTAG}|{SYSTEMTAG})"),
        outcome=model.LeakedCredentials([]),
        rates=model.Rates(0, 1.0, 1.0),
    ),
    RDPBF=model.VulnerabilityInfo(
        description="RDP Brute Force",
        type=model.VulnerabilityType.REMOTE,
        URL="https://attack.mitre.org/techniques/T1110/",
        precondition=model.Precondition("Windows&PortRDPOpen"),
        outcome=model.LateralMove(),
        rates=model.Rates(0, 0.2, 1.0),
        cost=1.0,
    ),
)
74
75
# Identifiers (vulnerability IDs, port names, node properties) shared by all
# test environments. Each identifier is listed exactly once.
# NOTE(review): NODES also uses the "SQL" and "WMI" ports and the
# "PortWMIOpen" property, which are not listed here — confirm whether these
# identifier lists are meant to be exhaustive.
ENV_IDENTIFIERS = model.Identifiers(
    local_vulnerabilities=["UACME61", "UACME67", "MimikatzLogonpasswords"],
    remote_vulnerabilities=["RDPBF"],
    ports=["RDP", "HTTP", "HTTPS", "SSH"],
    properties=[
        "Linux",
        "PortSSHOpen",
        "PortSQLOpen",
        "Windows",
        "Win10",
        "PortRDPOpen",
        "PortHTTPOpen",
        "PortHTTPsOpen",
        "SharepointLeakingPassword",
    ],
)
91
92
93
def sample_random_firwall_configuration(rng: Optional[random.Random] = None) -> model.FirewallConfiguration:
    """Sample a random firewall configuration over the environment's ports.

    The outgoing and incoming directions each get a random-size subset
    (possibly empty, possibly all) of ``ENV_IDENTIFIERS.ports``. Every rule
    uses ``RulePermission.ALLOW``. Ports are drawn without replacement, so no
    port gets two rules in the same direction.

    Args:
        rng: optional random generator, e.g. a seeded ``random.Random``, for
            reproducible configurations. Defaults to the global ``random``
            module, which matches the previous behaviour of calling with no
            arguments.

    Returns:
        A ``model.FirewallConfiguration`` holding the sampled rules.
    """
    source = random if rng is None else rng
    ports = ENV_IDENTIFIERS.ports

    def _sample_rules() -> List[model.FirewallRule]:
        # FirewallRule's first argument is a port name, so rules are drawn from
        # the port identifiers rather than from node properties.
        chosen = source.sample(ports, k=source.randint(0, len(ports)))
        return [model.FirewallRule(port, permission=model.RulePermission.ALLOW) for port in chosen]

    # Outgoing is sampled before incoming, same order as the keyword arguments.
    return model.FirewallConfiguration(outgoing=_sample_rules(), incoming=_sample_rules())
111
112
113
# Minimal one-node network: node "a" is a Windows 10 host serving RDP, HTTP
# and HTTPS, with a randomly sampled firewall and no agent installed.
# It declares no vulnerabilities of its own.
SINGLE_NODE = {
    "a": model.NodeInfo(
        services=[model.ListeningService(port) for port in ("RDP", "HTTP", "HTTPS")],
        value=70,
        properties=["Windows", "Win10", "PortRDPOpen", "PortHTTPOpen", "PortHTTPsOpen"],
        firewall=sample_random_firwall_configuration(),
        agent_installed=False,
    )
}
127
128
# Five-node test network:
#   a          - Windows web server (agent installed) whose local
#                vulnerabilities reveal b, c and dc and leak Sharepoint creds
#   b          - Linux SQL server
#   c          - Windows workstation (agent installed)
#   dc         - domain controller
#   Sharepoint - HTTPS site that only accepts "ADPrincipalCreds"; its firewall
#                blocks incoming RDP
# NOTE(review): the credential leaked by ScanSharepointParentDirectory targets
# "AzureResourceManager", which is not a node here — confirm that is intended.
NODES = {
    "a": model.NodeInfo(
        services=[
            model.ListeningService("RDP"),
            model.ListeningService("HTTP"),
            model.ListeningService("HTTPS"),
        ],
        value=70,
        properties=["Windows", "Win10", "PortRDPOpen", "PortHTTPOpen", "PortHTTPsOpen"],
        vulnerabilities={
            "ListNeighbors": model.VulnerabilityInfo(
                description="reveal other nodes",
                type=model.VulnerabilityType.LOCAL,
                outcome=model.LeakedNodesId(nodes=["b", "c", "dc"]),
            ),
            "DumpCreds": model.VulnerabilityInfo(
                description="leaking some creds",
                type=model.VulnerabilityType.LOCAL,
                outcome=model.LeakedCredentials(
                    [
                        model.CachedCredential("Sharepoint", "HTTPS", "ADPrincipalCreds"),
                        model.CachedCredential("Sharepoint", "HTTPS", "cred"),
                    ]
                ),
            ),
        },
        agent_installed=True,
    ),
    "b": model.NodeInfo(
        services=[model.ListeningService("SSH"), model.ListeningService("SQL")],
        value=80,
        properties=["Linux", "PortSSHOpen", "PortSQLOpen"],
        agent_installed=False,
    ),
    "c": model.NodeInfo(
        services=[
            model.ListeningService("RDP"),
            model.ListeningService("HTTP"),
            model.ListeningService("HTTPS"),
        ],
        value=40,
        properties=["Windows", "Win10", "PortRDPOpen", "PortHTTPOpen", "PortHTTPsOpen"],
        agent_installed=True,
    ),
    "dc": model.NodeInfo(
        services=[model.ListeningService("RDP"), model.ListeningService("WMI")],
        value=100,
        properties=["Windows", "Win10", "PortRDPOpen", "PortWMIOpen"],
        agent_installed=False,
    ),
    "Sharepoint": model.NodeInfo(
        services=[model.ListeningService("HTTPS", allowedCredentials=["ADPrincipalCreds"])],
        value=100,
        properties=["SharepointLeakingPassword"],
        firewall=model.FirewallConfiguration(
            incoming=[
                model.FirewallRule(port="SSH", permission=model.RulePermission.ALLOW),
                model.FirewallRule(port="HTTPS", permission=model.RulePermission.ALLOW),
                model.FirewallRule(port="HTTP", permission=model.RulePermission.ALLOW),
                model.FirewallRule(port="RDP", permission=model.RulePermission.BLOCK),
            ],
            outgoing=[],
        ),
        vulnerabilities={
            "ScanSharepointParentDirectory": model.VulnerabilityInfo(
                description="Navigate to SharePoint site, browse parent directory",
                type=model.VulnerabilityType.REMOTE,
                outcome=model.LeakedCredentials(
                    credentials=[
                        model.CachedCredential(
                            node="AzureResourceManager",
                            port="HTTPS",
                            credential="ADPrincipalCreds",
                        )
                    ]
                ),
                rates=model.Rates(successRate=1.0),
                cost=1.0,
            )
        },
    ),
}
213
214
215
# Module-level environment over the NODES network with an empty vulnerability
# library. NOTE(review): not referenced by the tests in this section.
ENV = model.Environment(
    network=model.create_network(NODES),
    vulnerability_library={},
    identifiers=ENV_IDENTIFIERS,
    creationTime=datetime.utcnow(),
    lastModified=datetime.utcnow(),
)
223
224
225
def _actions_for(network: nx.Graph) -> actions.AgentActions:
    """Wrap ``network`` in an Environment and return an AgentActions over it.

    The environment always uses SAMPLE_VULNERABILITIES as its vulnerability
    library and ENV_IDENTIFIERS as its identifiers, so the fixtures below
    differ only in the network they pass in.
    """
    # NOTE(review): model.create_network may store the NodeInfo objects from the
    # module-level node dicts directly. If so, state changed by one test (e.g.
    # privilege tags added to node properties) is visible to later tests — verify.
    env = model.Environment(
        network=network,
        version=model.VERSION_TAG,
        vulnerability_library=SAMPLE_VULNERABILITIES,
        identifiers=ENV_IDENTIFIERS,
        creationTime=datetime.utcnow(),
        lastModified=datetime.utcnow(),
    )
    return actions.AgentActions(env)


@pytest.fixture
def actions_on_empty_environment() -> actions.AgentActions:
    """
    Fixture: AgentActions over an environment whose directed network has no nodes.
    """
    return _actions_for(nx.empty_graph(0, create_using=nx.DiGraph()))


@pytest.fixture
def actions_on_single_node_environment() -> actions.AgentActions:
    """
    Fixture: AgentActions over the one-node SINGLE_NODE network.
    """
    return _actions_for(model.create_network(SINGLE_NODE))


@pytest.fixture
def actions_on_simple_environment() -> actions.AgentActions:
    """
    Fixture: AgentActions over the NODES network (nodes a, b, c, dc and Sharepoint).
    """
    return _actions_for(model.create_network(NODES))
274
275
276
def test_list_vulnerabilities_function(actions_on_single_node_environment: Fixture, actions_on_simple_environment: Fixture) -> None:
    """
    Test AgentActions.list_vulnerabilities_in_target.

    Node "a" of SINGLE_NODE and node "dc" of NODES are both Windows 10 hosts
    with RDP open, no privilege tags and no node-local vulnerabilities. From
    SAMPLE_VULNERABILITIES, only the two UAC bypasses and the RDP brute force
    should therefore apply; Mimikatz requires an admin or system tag.
    """
    expected = sorted(["UACME61", "UACME67", "RDPBF"])

    # test on an environment with a single node
    single_node_results: List[model.VulnerabilityID] = actions_on_single_node_environment.list_vulnerabilities_in_target("a")
    assert len(single_node_results) == 3
    assert sorted(single_node_results) == expected

    # test on the multi-node environment
    simple_graph_results: List[model.VulnerabilityID] = actions_on_simple_environment.list_vulnerabilities_in_target("dc")
    assert len(simple_graph_results) == 3
    assert sorted(simple_graph_results) == expected
289
290
291
def test_exploit_remote_vulnerability(actions_on_simple_environment: Fixture) -> None:
    """
    Test AgentActions.exploit_remote_vulnerability.

    Cases covered:
    - Unknown source or target nodes, and a local-only vulnerability, raise ValueError.
    - An unknown vulnerability gives no outcome and a non-positive reward.
    - A valid RDP brute force against "c" gives a LateralMove whose reward is
      below the node's value.
    """
    # ListNeighbors on "a" reveals b, c and dc before they are targeted.
    actions_on_simple_environment.exploit_local_vulnerability("a", "ListNeighbors")

    # test with invalid source node
    with pytest.raises(ValueError, match=r"invalid node id '.*'"):
        actions_on_simple_environment.exploit_remote_vulnerability("z", "b", "RDPBF")

    # test with invalid destination node
    with pytest.raises(ValueError, match=r"invalid target node id '.*'"):
        actions_on_simple_environment.exploit_remote_vulnerability("a", "z", "RDPBF")

    # test with a local vulnerability
    with pytest.raises(ValueError, match=r"vulnerability id '.*' is for an attack of type .*"):
        actions_on_simple_environment.exploit_remote_vulnerability("a", "c", "MimikatzLogonpasswords")

    # test with an invalid vulnerability (one not there)
    result = actions_on_simple_environment.exploit_remote_vulnerability("a", "c", "HackTheGibson")
    assert result.outcome is None and result.reward <= 0

    # Test-only hack: give node "c" the sample vulnerabilities directly.
    # A copy is assigned so the module-level SAMPLE_VULNERABILITIES dict is
    # never aliased into the graph. The node's original vulnerabilities are
    # restored afterwards in case its NodeInfo object is shared with NODES.
    graph: nx.graph.Graph = actions_on_simple_environment._environment.network
    node: model.NodeInfo = graph.nodes["c"]["data"]
    original_vulnerabilities = node.vulnerabilities
    node.vulnerabilities = dict(SAMPLE_VULNERABILITIES)
    try:
        # test a valid and functional one.
        result = actions_on_simple_environment.exploit_remote_vulnerability("a", "c", "RDPBF")
        assert isinstance(result.outcome, model.LateralMove)
        assert result.reward < node.value
    finally:
        node.vulnerabilities = original_vulnerabilities
325
326
327
def test_exploit_local_vulnerability(actions_on_simple_environment: Fixture) -> None:
    """
    Test AgentActions.exploit_local_vulnerability.

    Covers three cases:
    - An exploit whose precondition is unmet fails.
    - Admin and system escalations succeed and add the matching tag to the
      node's properties.
    - The credential dump succeeds once node "a" is admin.
    """
    agent = actions_on_simple_environment
    network = agent._environment.network

    # Mimikatz needs an admin/system tag that "a" does not have yet.
    unmet = agent.exploit_local_vulnerability("a", "MimikatzLogonpasswords")
    assert isinstance(unmet.outcome, model.ExploitFailed)

    # Admin escalation on "a", then system escalation on "c".
    escalations = (
        ("a", "UACME61", model.AdminEscalation, ADMINTAG),
        ("c", "UACME67", model.SystemEscalation, SYSTEMTAG),
    )
    for node_id, vuln_id, outcome_type, tag in escalations:
        escalated = agent.exploit_local_vulnerability(node_id, vuln_id)
        assert isinstance(escalated.outcome, outcome_type)
        escalated_node: model.NodeInfo = network.nodes[node_id]["data"]
        assert tag in escalated_node.properties

    # With "a" now admin, the credential dump succeeds.
    dumped = agent.exploit_local_vulnerability("a", "MimikatzLogonpasswords")
    assert isinstance(dumped.outcome, model.LeakedCredentials)
353
354
355
def test_connect_to_remote_machine(
    actions_on_empty_environment: Fixture,
    actions_on_single_node_environment: Fixture,
    actions_on_simple_environment: Fixture,
) -> None:
    """
    Test AgentActions.connect_to_remote_machine.

    Cases covered:
    - Unknown source or target nodes raise ValueError.
    - An unsupported port or an invalid credential gives no outcome and a
      non-positive reward.
    - A port blocked by the target's firewall gives a negative reward.
    - An HTTPS connection to Sharepoint with the leaked "ADPrincipalCreds"
      yields a reward of 100.
    """
    simple = actions_on_simple_environment

    # Reveal the neighbours of "a" and leak the Sharepoint credentials.
    simple.exploit_local_vulnerability("a", "ListNeighbors")
    simple.exploit_local_vulnerability("a", "DumpCreds")

    # The empty and the single-node environments have no node "b".
    for env_actions in (actions_on_empty_environment, actions_on_single_node_environment):
        with pytest.raises(ValueError, match=r"invalid node id '.*'"):
            env_actions.connect_to_remote_machine("a", "b", "RDP", "cred")

    network: nx.graph.Graph = simple._environment.network

    # Multi-node environment: unknown target, unknown source, and both unknown.
    for source, target in (("a", "f"), ("f", "dc"), ("f", "z")):
        with pytest.raises(ValueError, match=r"invalid node id '.*'"):
            simple.connect_to_remote_machine(source, target, "RDP", "cred")

    # Unsupported protocol.
    bad_port = simple.connect_to_remote_machine("a", "dc", "TCPIP", "cred")
    assert bad_port.reward <= 0 and bad_port.outcome is None

    # Invalid credential.
    bad_cred = simple.connect_to_remote_machine("a", "dc", "RDP", "cred")
    assert bad_cred.outcome is None and bad_cred.reward <= 0

    # Sharepoint's firewall blocks incoming RDP.
    blocked = simple.connect_to_remote_machine("a", "Sharepoint", "RDP", "ADPrincipalCreds")
    assert blocked.reward < 0

    # Valid connection over an allowed port with an accepted credential.
    connected = simple.connect_to_remote_machine("a", "Sharepoint", "HTTPS", "ADPrincipalCreds")
    assert connected.reward == 100

    # After the steps above, the network graph must contain an a -> dc edge.
    assert network.has_edge("a", "dc")
408
409
410
def test_check_prerequisites(actions_on_simple_environment: Fixture) -> None:
    """
    Test the private AgentActions._check_prerequisites helper directly.

    It is private but still used by the public actions, so it is worth
    covering on its own. In this fixture, node "dc" is a Windows 10 host
    without admin/system tags. Mimikatz's precondition therefore fails on it,
    while the UACME61 precondition holds.
    """
    check = actions_on_simple_environment._check_prerequisites

    # Negative case: dc lacks the privilege tag Mimikatz requires.
    assert not check("dc", SAMPLE_VULNERABILITIES["MimikatzLogonpasswords"])

    # Positive case: dc is an unprivileged Windows 10 host.
    assert check("dc", SAMPLE_VULNERABILITIES["UACME61"])
423
424