Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
microsoft
GitHub Repository: microsoft/CyberBattleSim
Path: blob/main/cyberbattle/_env/graph_wrapper.py
597 views
1
# Copyright (c) Microsoft Corporation.
2
# Licensed under the MIT License.
3
4
from typing import Union, Tuple
5
6
import gymnasium as gym
7
import numpy as onp
8
import networkx as nx
9
10
from .graph_spaces import DiGraph
11
12
13
# An action is ``(kind, *indicators)``: kind 0 (local vulnerability) carries 2
# indicators, kind 1 (remote vulnerability) carries 3, and kind 2 (connect) carries 4,
# so a complete action is a 3-, 4- or 5-tuple of ints.
Action = Union[Tuple[int, int, int], Tuple[int, int, int, int], Tuple[int, int, int, int, int]]
14
15
16
class CyberBattleGraph(gym.Wrapper):
    """
    A wrapper for CyberBattleSim that maintains the agent's
    knowledge graph containing information about the subset
    of the network that was explored so far.

    Currently the nodes of this graph are a subset of the environment nodes.
    Eventually we will add new node types to represent various entities
    like credentials and users. Edges will represent relationships between those entities
    (e.g. user X is authenticated with machine Y using credential Z).

    Actions
    -------

    Actions are of the form:

    .. code:: python

        (kind, *indicators)

    The ``kind`` is one of:

    .. code:: python

        # kind
        0: Local Vulnerability
        1: Remote Vulnerability
        2: Connect

    The indicators vary in meaning and length, depending on the ``kind``:

    .. code:: python

        # kind=0 (Local Vulnerability)
        indicators = (node_id, local_vulnerability_id)

        # kind=1 (Remote Vulnerability)
        indicators = (from_node_id, to_node_id, remote_vulnerability_id)

        # kind=2 (Connect)
        indicators = (from_node_id, to_node_id, port_id, credential_id)

    The node ids can be obtained from the graph returned as the observation, e.g.

    .. code:: python

        node_ids = list(observation.nodes)

    The other indicators are passed through unchanged to the wrapped environment,
    which interprets them. For example:

    .. code:: python

        # local_vulnerability_ids
        0: ScanBashHistory
        1: ScanExplorerRecentFiles
        2: SudoAttempt
        3: CrackKeepPassX
        4: CrackKeepPass

        # remote_vulnerability_ids
        0: ProbeLinux
        1: ProbeWindows

        # port_ids
        0: HTTPS
        1: GIT
        2: SSH
        3: RDP
        4: PING
        5: MySQL
        6: SSH-key
        7: su

    Examples
    ~~~~~~~~
    Here are some example actions:

    .. code:: python

        a = (0, 5, 3)        # try local vulnerability "CrackKeepPassX" on node 5
        a = (1, 5, 7, 1)     # try remote vulnerability "ProbeWindows" from node 5 to node 7
        a = (2, 5, 7, 3, 2)  # try to connect from node 5 to node 7 using credential 2 over RDP port

    Observations
    ------------

    Observations are ``networkx.DiGraph`` graphs of the nodes that have been discovered
    so far. Each node is annotated with a dict of properties of the form:

    .. code:: python

        node_properties = {
            'name': 'FooServer',                         # human-readable identifier
            'privilege_level': 1,                        # 0: not owned, 1: admin, 2: system
            'flags': array([-1, 0, 1, 0, 0, ..., 0]),    # 1: set, -1: unset, 0: unknown
            'credentials': array([-1, 5, -1, ..., -1]),  # array of ports (-1 means no cred)
            'has_leaked_creds': True,                    # whether node has leaked any credentials so far
        }

        # flag_ids
        0: Windows
        1: Linux
        2: ApacheWebSite
        3: IIS_2019
        4: IIS_2020_patched
        5: MySql
        6: Ubuntu
        7: nginx/1.10.3
        8: SMB_vuln
        9: SMB_vuln_patched
        10: SQLServer
        11: Win10
        12: Win10Patched
        13: FLAG:Linux

    Note that the **position** of a non-trivial port number in ``'credentials'`` corresponds to the
    credential id. Therefore, for the node in the example above, we have a known credential on
    :code:`port_id=5` with :code:`credential_id=1` (the position in the array).
    """

    # Maps the action ``kind`` (index) to the key expected by the wrapped env's step().
    __kinds = ("local_vulnerability", "remote_vulnerability", "connect")

    def __init__(self, env, maximum_total_credentials=22, maximum_node_count=22):
        """Wrap ``env`` and start with an empty knowledge graph.

        Args:
            env: The CyberBattleSim environment to wrap. Its ``bounds`` attribute supplies
                ``maximum_node_count`` and ``maximum_total_credentials``.
            maximum_total_credentials: Unused; the value is read from ``env.bounds``.
                Kept only for backward compatibility of the signature.
            maximum_node_count: Unused; the value is read from ``env.bounds``.
                Kept only for backward compatibility of the signature.
        """
        super().__init__(env)
        self._bounds = env.bounds
        self.__graph = nx.DiGraph()
        self.observation_space = DiGraph(self._bounds.maximum_node_count)

    def reset(self, **kwargs):
        """Reset the wrapped environment and rebuild the knowledge graph from scratch.

        Args:
            **kwargs: Forwarded unchanged to the wrapped environment's ``reset``.

        Returns:
            A ``(graph, info)`` pair: the freshly built ``networkx.DiGraph`` and the info
            dict reported by the wrapped environment (an empty dict if it reports none).
        """
        result = self.env.reset(**kwargs)
        # Gymnasium environments return ``(observation, info)``; tolerate environments
        # that still return the bare observation (a dict, never a 2-tuple).
        if isinstance(result, tuple) and len(result) == 2:
            observation, info = result
        else:
            observation, info = result, None
        self.__graph = nx.DiGraph()
        self.__add_node(observation)
        self.__update_nodes(observation)
        return self.__graph, dict(info or {})

    def step(self, action: Action):
        """
        Take a step in the MDP.

        Args:
            action: An *abstract* action ``(kind, *indicators)``; see the class docstring.

        Returns:
            observation: The updated knowledge graph (``networkx.DiGraph``).
            reward: The reward associated with the given action (and previous observation).
            done: Whether the next-step observation is a terminal state.
            truncated: The truncation flag reported by the wrapped environment.
            info: Some additional info.
        """
        kind_id, *indicators = action
        observation, reward, done, truncated, info = self.env.step({self.__kinds[kind_id]: indicators})
        # __add_node appends every discovered node the graph does not hold yet, so a
        # single call covers any number of newly discovered nodes.
        self.__add_node(observation)
        # TODO: do we need to update edges and nodes every time?
        self.__update_edges(observation)
        self.__update_nodes(observation)
        return self.__graph, reward, done, truncated, info

    def __add_node(self, observation):
        """Append a graph node for each discovered environment node not yet in the graph.

        Node ids are consecutive integers matching positions in
        ``observation["_discovered_nodes"]``. Every attribute except ``name`` starts
        as a placeholder that ``__update_nodes`` fills in.
        """
        while self.__graph.number_of_nodes() < observation["discovered_node_count"]:
            node_index = self.__graph.number_of_nodes()
            # One slot per possible credential id; -1 means "no known credential".
            creds = onp.full(self._bounds.maximum_total_credentials, -1, dtype=onp.int8)
            self.__graph.add_node(
                node_index,
                name=observation["_discovered_nodes"][node_index],
                privilege_level=None,
                flags=None,  # these are set by __update_nodes()
                credentials=creds,
                has_leaked_creds=False,
            )

    def __update_edges(self, observation):
        """Copy the edges of the environment's explored network into the knowledge graph.

        Node names are translated to graph node ids via their position in
        ``observation["_discovered_nodes"]``. Edge attributes are copied over;
        ``add_edge`` updates the attributes of edges that already exist.
        """
        g_orig = observation["_explored_network"]
        node_ids = {name: index for index, name in enumerate(observation["_discovered_nodes"])}
        for (from_name, to_name), edge_properties in g_orig.edges.items():
            self.__graph.add_edge(node_ids[from_name], node_ids[to_name], **edge_properties)

    def __update_nodes(self, observation):
        """Refresh privilege level, flags and leaked-credential info on the graph nodes."""
        node_properties = zip(
            observation["nodes_privilegelevel"],
            observation["discovered_nodes_properties"],
        )
        for node_id, (privilege_level, flags) in enumerate(node_properties):
            # NOTE(review): an earlier comment said these values also live under a
            # 'data' node attribute; this graph never sets 'data' - possibly it meant
            # observation["_explored_network"]. Confirm before relying on it.
            self.__graph.nodes[node_id]["privilege_level"] = privilege_level
            self.__graph.nodes[node_id]["flags"] = flags

        # NOTE(review): assumes every row of credential_cache_matrix is a real cached
        # credential (no padding rows) - confirm against the environment.
        for cred_id, (node_id, port_id) in enumerate(observation["credential_cache_matrix"]):
            node_id, port_id = int(node_id), int(port_id)
            # NOTE: this code ignores situations where the same cred_id is
            # used for two different ports (This can be the case, even on the same node for two different ports.)
            self.__graph.nodes[node_id]["credentials"][cred_id] = port_id
            # Mark the node as having leaked credentials.
            self.__graph.nodes[node_id]["has_leaked_creds"] = True
218
219