# Path: blob/main/cyberbattle/_env/graph_wrapper.py
# (597 views)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from typing import Union, Tuple

import gymnasium as gym
import numpy as onp
import networkx as nx

from .graph_spaces import DiGraph


# An action is ``(kind, *indicators)``: 3 ints for a local vulnerability,
# 4 ints for a remote vulnerability and 5 ints for a connect action
# (see the ``CyberBattleGraph`` docstring for the meaning of each field).
Action = Union[
    Tuple[int, int, int],
    Tuple[int, int, int, int],
    Tuple[int, int, int, int, int],
]


class CyberBattleGraph(gym.Wrapper):
    """

    A wrapper for CyberBattleSim that maintains the agent's
    knowledge graph containing information about the subset
    of the network that was explored so far.

    Currently the nodes of this graph are a subset of the environment nodes.
    Eventually we will add new node types to represent various entities
    like credentials and users. Edges will represent relationships between those entities
    (e.g. user X is authenticated with machine Y using credential Z).

    Actions
    -------

    Actions are of the form:

    .. code:: python

        (kind, *indicators)


    The ``kind`` is one of

    .. code:: python

        # kind
        0: Local Vulnerability
        1: Remote Vulnerability
        2: Connect

    The indicators vary in meaning and length, depending on the ``kind``:

    .. code:: python

        # kind=0 (Local Vulnerability)
        indicators = (node_id, local_vulnerability_id)

        # kind=1 (Remote Vulnerability)
        indicators = (from_node_id, to_node_id, remote_vulnerability_id)

        # kind=2 (Connect)
        indicators = (from_node_id, to_node_id, port_id, credential_id)

    The node ids are the nodes of the observation graph, e.g.

    .. code:: python

        node_ids = list(observation.nodes)

    The other indicators are listed below.

    .. code:: python

        # local_vulnerability_ids
        0: ScanBashHistory
        1: ScanExplorerRecentFiles
        2: SudoAttempt
        3: CrackKeepPassX
        4: CrackKeepPass

        # remote_vulnerability_ids
        0: ProbeLinux
        1: ProbeWindows

        # port_ids
        0: HTTPS
        1: GIT
        2: SSH
        3: RDP
        4: PING
        5: MySQL
        6: SSH-key
        7: su

    Examples
    ~~~~~~~~
    Here are some example actions:

    .. code:: python

        a = (0, 5, 3)  # try local vulnerability "CrackKeepPassX" on node 5
        a = (1, 5, 7, 1)  # try remote vulnerability "ProbeWindows" from node 5 to node 7
        a = (2, 5, 7, 3, 2)  # try to connect from node 5 to node 7 using credential 2 over RDP port


    Observations
    ------------

    Observations are graphs of the nodes that have been discovered so far. Each node is annotated
    with a dict of properties of the form:

    .. code:: python

        node_properties = {
            'name': 'FooServer',  # human-readable identifier
            'privilege_level': 1,  # 0: not owned, 1: admin, 2: system
            'flags': array([-1, 0, 1, 0, 0, ..., 0]),  # 1: set, -1: unset, 0: unknown
            'credentials': array([-1, 5, -1, ..., -1]),  # array of ports (-1 means no cred)
            'has_leaked_creds': True,  # whether node has leaked any credentials so far
        }

        # flag_ids
        0: Windows
        1: Linux
        2: ApacheWebSite
        3: IIS_2019
        4: IIS_2020_patched
        5: MySql
        6: Ubuntu
        7: nginx/1.10.3
        8: SMB_vuln
        9: SMB_vuln_patched
        10: SQLServer
        11: Win10
        12: Win10Patched
        13: FLAG:Linux

    Note that the **position** of a non-trivial port number in ``'credentials'`` corresponds to the
    credential id. Therefore, for the node in the example above, we have a known credential on
    :code:`port_id=5` with :code:`credential_id=1` (the position in the array).

    Edges are copied, with their properties, from the environment's explored network.
    The same graph object is updated in place by every ``step`` and replaced by ``reset``.

    """

    __kinds = ("local_vulnerability", "remote_vulnerability", "connect")

    def __init__(self, env, maximum_total_credentials=22, maximum_node_count=22):
        """Wrap ``env`` and start with an empty knowledge graph.

        Args:
            env: The CyberBattleSim environment to wrap. It must expose a ``bounds``
                attribute providing ``maximum_node_count`` and ``maximum_total_credentials``.
            maximum_total_credentials: Unused; kept for backward compatibility.
                The credential array size is taken from ``env.bounds``.
            maximum_node_count: Unused; kept for backward compatibility.
                The node-count bound is taken from ``env.bounds``.
        """
        super().__init__(env)
        self._bounds = env.bounds
        self.__graph = nx.DiGraph()
        self.observation_space = DiGraph(self._bounds.maximum_node_count)

    def reset(self, **kwargs):
        """Reset the wrapped environment and rebuild the knowledge graph from scratch.

        Args:
            **kwargs: Forwarded unchanged to the wrapped environment's ``reset``.

        Returns:
            A ``(graph, info)`` pair, where ``graph`` is a fresh ``nx.DiGraph`` holding
            the initially discovered nodes and ``info`` is the wrapped env's info dict
            (or ``{}`` if the wrapped env returns the observation alone).
        """
        result = self.env.reset(**kwargs)
        # gymnasium's Env.reset returns (observation, info); older gym-style
        # environments returned just the observation. Accept both.
        if isinstance(result, tuple):
            observation, info = result
        else:
            observation, info = result, {}
        self.__graph = nx.DiGraph()
        self.__add_node(observation)
        self.__update_nodes(observation)
        return self.__graph, info

    def step(self, action: Action):
        """

        Take a step in the MDP.

        Args:
            action: An *abstract* action ``(kind, *indicators)``; see the class docstring.

        Returns:
            observation: The next-step observation (the updated knowledge graph).
            reward: The reward associated with the given action (and previous observation).
            done: Whether the next-step observation is a terminal state.
            truncated: Whether the episode was truncated, as reported by the wrapped env.
            info: Some additional info.

        Raises:
            IndexError: If ``kind`` is not a valid action kind.

        """
        kind_id, *indicators = action
        # Reject negative kinds explicitly: plain tuple indexing would silently
        # wrap them around to another action kind.
        if not 0 <= kind_id < len(self.__kinds):
            raise IndexError(f"action kind must be in [0, {len(self.__kinds)}), got {kind_id}")
        kind = self.__kinds[kind_id]
        observation, reward, done, truncated, info = self.env.step({kind: indicators})
        # __add_node already adds *all* discovered-but-missing nodes, so a single call
        # suffices. It must run before __update_edges so edge endpoints already exist
        # as properly annotated nodes.
        self.__add_node(observation)
        # TODO: do we need to update edges and nodes every time?
        self.__update_edges(observation)
        self.__update_nodes(observation)
        return self.__graph, reward, done, truncated, info

    def __add_node(self, observation):
        """Add a graph node for every discovered node not yet in the graph.

        Nodes are numbered in discovery order, so graph node ``i`` is named
        ``observation["_discovered_nodes"][i]``. ``privilege_level`` and ``flags``
        start as ``None`` and are filled in by ``__update_nodes``.
        """
        # A bounded range (rather than re-reading number_of_nodes() in a while loop)
        # cannot spin forever if a node id already exists in the graph.
        first_new = self.__graph.number_of_nodes()
        for node_index in range(first_new, observation["discovered_node_count"]):
            self.__graph.add_node(
                node_index,
                name=observation["_discovered_nodes"][node_index],
                privilege_level=None,
                flags=None,  # these are set by __update_nodes()
                credentials=onp.full(self._bounds.maximum_total_credentials, -1, dtype=onp.int8),
                has_leaked_creds=False,
            )

    def __update_edges(self, observation):
        """Copy every edge (with its properties) from the explored network into the graph.

        Edge endpoints are translated from node names to graph node indices via
        their position in ``observation["_discovered_nodes"]``.
        """
        explored_network = observation["_explored_network"]
        index_of = {name: index for index, name in enumerate(observation["_discovered_nodes"])}
        for (source_name, target_name), edge_properties in explored_network.edges.items():
            self.__graph.add_edge(index_of[source_name], index_of[target_name], **edge_properties)

    def __update_nodes(self, observation):
        """Refresh per-node privilege levels, flags and leaked credentials."""
        # Only visit nodes that exist in the graph. If the observation arrays are longer
        # than the number of discovered nodes (e.g. padded), the extra rows are ignored
        # instead of raising KeyError.
        per_node = zip(
            range(self.__graph.number_of_nodes()),
            observation["nodes_privilegelevel"],
            observation["discovered_nodes_properties"],
        )
        for node_id, privilege_level, flags in per_node:
            node = self.__graph.nodes[node_id]
            node["privilege_level"] = privilege_level
            node["flags"] = flags

        # Each row of the credential cache is (node_id, port_id); the row index is the
        # credential id. NOTE(review): if the env pads this matrix, padding rows would be
        # recorded as credentials on node 0 -- confirm against the env's observation spec.
        for cred_id, (node_id, port_id) in enumerate(observation["credential_cache_matrix"]):
            node = self.__graph.nodes[int(node_id)]
            # NOTE: this code ignores situations where the same cred_id is
            # used for two different ports (This can be the case, even on the same node for two different ports.)
            node["credentials"][cred_id] = int(port_id)
            # Mark the node as having leaked credentials.
            node["has_leaked_creds"] = True