Path: blob/main/cyberbattle/_env/cyberbattle_env.py
597 views
# Copyright (c) Microsoft Corporation.1# Licensed under the MIT License.23"""Anatares OpenGym Environment"""45import time6import copy7import logging8import networkx9from networkx import convert_matrix10from typing import NamedTuple, Optional, Tuple, List, Dict, TypeVar, TypedDict, cast1112from gymnasium import spaces, Env13from gymnasium.utils import seeding1415import numpy1617from plotly.graph_objects import Scatter # type: ignore18from plotly.subplots import make_subplots # type: ignore1920from cyberbattle._env.defender import DefenderAgent21from cyberbattle.simulation.model import PortName, PrivilegeLevel22from ..simulation import commandcontrol, model, actions23from .discriminatedunion import DiscriminatedUnion24import numpy as np2526LOGGER = logging.getLogger(__name__)2728# Used to allocate a discrete space value representing a field that29# is 'Not Applicable' (of value 0 by convention)30NA = 13132# Value defining an unused space slot33UNUSED_SLOT = numpy.int32(0)34# Value defining a used space slot35USED_SLOT = numpy.int32(1)363738# The type of a sample from the Action space39Action = TypedDict(40"Action",41{42"local_vulnerability": numpy.ndarray,43# adding the generic type causes runtime44# TypeError `'type' object is not subscriptable'`45"remote_vulnerability": numpy.ndarray,46"connect": numpy.ndarray,47},48total=False,49)5051# Type of a sample from the ActionMask space52ActionMask = TypedDict(53"ActionMask",54{55"local_vulnerability": numpy.ndarray,56"remote_vulnerability": numpy.ndarray,57"connect": numpy.ndarray,58},59)6061# Type of a sample from the Observation space62Observation = TypedDict(63"Observation",64{65# ---------------------------------------------------------66# Outcome of the action just executed67# ---------------------------------------------------------68# number of new nodes discovered69"newly_discovered_nodes_count": numpy.int32,70# whether a lateral move was just performed71"lateral_move": numpy.int32,72# whether customer data were just discovered73"customer_data_found": numpy.int32,74# 0 if there were no probing attempt75# 1 if an attempted probing failed76# 2 if an attempted probing succeeded77"probe_result": numpy.int32,78# whether an escalation was completed and to which level79"escalation": numpy.int32,80# credentials that were just discovered after executing an action81"leaked_credentials": Tuple[numpy.ndarray, ...], # type: ignore82# bitmask indicating which action are valid in the current state83"action_mask": ActionMask,84# ---------------------------------------------------------85# State information aggregated over all actions executed so far86# ---------------------------------------------------------87# size of the credential stack (number of tuples in `credential_cache_matrix` that are not zeros)88"credential_cache_length": int,89# total nodes discovered so far90"discovered_node_count": int,91# Matrix of properties for all the discovered nodes92"discovered_nodes_properties": numpy.ndarray,93# Node privilege level on every discovered node (e.g., 0 if not owned, 1 owned, 2 admin, 3 for system)94"nodes_privilegelevel": numpy.ndarray,95# Tuple encoding of the credential cache matrix.96# It consists of `bounds.maximum_total_credentials` tuples97# of numpy array of shape (2)98# where only the first `credential_cache_length` tuples are populated.99#100# Each tuple represent a discovered credential,101# the credential index is given by its tuple index (i.e., order of discovery)102# Each tuple is of the form: (target_node_discover_index, port_index)103"credential_cache_matrix": Tuple[numpy.ndarray, ...],104# ---------------------------------------------------------105# Raw information fields coming from the simulation environment106# that are not encoded as gym spaces (were previously in the 'info' field)107# ---------------------------------------------------------108# Mapping node index to internal IDs of all nodes discovered so far.109# The external node index used by the agent to refer to a node110# is defined as the index of the node in this array111"_discovered_nodes": List[model.NodeID],112# The subgraph of nodes discovered so far with annotated edges113# representing interactions that took place during the simulation. (See114# actions.EdgeAnnotation)115"_explored_network": networkx.DiGraph,116},117)118119120# Information returned to gym by the step function121StepInfo = TypedDict(122"StepInfo",123{124"description": str,125"duration_in_ms": float,126"step_count": int,127"network_availability": float,128# internal IDs of the credentials in the cache129"credential_cache": List[model.CachedCredential],130},131)132133134class OutOfBoundIndexError(Exception):135"""The agent attempted to reference an entity (node or a vulnerability) with an invalid index"""136137138Key = TypeVar("Key")139Value = TypeVar("Value")140141142def inverse_dict(self: Dict[Key, Value]) -> Dict[Value, Key]:143"""Inverse a dictionary"""144return {v: k for k, v in self.items()}145146147class DummySpace(spaces.Space):148"""This class ensures that the values in the gym.spaces.Dict space are derived from gymnasium.Space"""149150def __init__(self, sample: object):151self._sample = sample152153def contains(self, x: object) -> bool:154return True155156def sample(self, mask=None) -> object:157return self._sample158159160def sourcenode_of_action(x: Action) -> int:161"""Return the source node of a given action"""162if "local_vulnerability" in x:163return x["local_vulnerability"][0]164elif "remote_vulnerability" in x:165return x["remote_vulnerability"][0]166167assert "connect" in x168return x["connect"][0]169170171class EnvironmentBounds(NamedTuple):172"""Define global bounds posisibly shared by a set of CyberBattle gym environments173174maximum_node_count - Maximum number of nodes in a given network175maximum_total_credentials - Maximum number of credentials in a given network176maximum_discoverable_credentials_per_action - Maximum number of credentials177that can be returned at a time by any action178179port_count - Unique protocol ports180property_count - Unique node property names181local_attacks_count - Unique local vulnerabilities182remote_attacks_count - Unique remote vulnerabilities183"""184185maximum_total_credentials: np.int32186maximum_node_count: np.int32187maximum_discoverable_credentials_per_action: np.int32188189port_count: np.int32190property_count: np.int32191local_attacks_count: np.int32192remote_attacks_count: np.int32193194@classmethod195def of_identifiers(196cls,197identifiers: model.Identifiers,198maximum_total_credentials: int,199maximum_node_count: int,200maximum_discoverable_credentials_per_action: Optional[int] = None,201):202203maximum_discoverable_credentials_per_action = maximum_discoverable_credentials_per_action or maximum_total_credentials204205assert np.can_cast(maximum_total_credentials, np.int32), "maximum_total_credentials must be a 32-bit integer"206assert np.can_cast(maximum_node_count, np.int32), "maximum_node_count must be a 32-bit integer"207assert maximum_total_credentials > 0, "maximum_total_credentials must be positive"208assert maximum_node_count > 0, "maximum_node_count must be positive"209assert np.can_cast(len(identifiers.ports), np.int32), "port_count must be a 32-bit integer"210assert np.can_cast(len(identifiers.properties), np.int32), "property_count must be a 32-bit integer"211assert np.can_cast(len(identifiers.local_vulnerabilities), np.int32), "local_attacks_count must be a 32-bit integer"212assert np.can_cast(len(identifiers.remote_vulnerabilities), np.int32), "remote_attacks_count must be a 32-bit integer"213assert np.can_cast(maximum_discoverable_credentials_per_action, np.int32), "maximum_discoverable_credentials_per_action must be a 32-bit integer"214215return EnvironmentBounds(216maximum_total_credentials=np.int32(maximum_total_credentials),217maximum_node_count=np.int32(maximum_node_count),218maximum_discoverable_credentials_per_action=np.int32(maximum_discoverable_credentials_per_action),219port_count=np.int32(len(identifiers.ports)),220property_count=np.int32(len(identifiers.properties)),221local_attacks_count=np.int32(len(identifiers.local_vulnerabilities)),222remote_attacks_count=np.int32(len(identifiers.remote_vulnerabilities)),223)224225226class AttackerGoal(NamedTuple):227"""Define conditions to be simultanesouly met for the attacker to win.228If field values are not specified the default is to target full ownership229of the network nodes.230"""231232# Include goal to reach at least the specifed cumulative total reward after233reward: float = 0.0234# Include goal to bring the availability to lower that the specified SLA value235low_availability: float = 1.0236# Include goal to own at least the specified number of nodes.237own_atleast: int = 0238# Include goal to own at least the specified percentage of the network nodes.239# Set to 1.0 to define goal as the ownership of all network nodes.240own_atleast_percent: float = 1.0241242243class DefenderGoal(NamedTuple):244"""Define conditions to be simultanesouly met for the defender to win."""245246# Met if attacker is evicted from all the network nodes247eviction: bool248249250class DefenderConstraint(NamedTuple):251"""Define constraints to be maintained by the defender at all time."""252253maintain_sla: float254255256class ObservationSpaceType(spaces.Dict):257def __init__(self, bounds: EnvironmentBounds):258super().__init__(259{260# how many new nodes were discovered261"newly_discovered_nodes_count": spaces.Discrete(NA + bounds.maximum_node_count),262# successuflly moved to the target node (1) or not (0)263"lateral_move": spaces.Discrete(2),264# boolean: 1 if customer secret data were discovered, 0 otherwise265"customer_data_found": spaces.Discrete(2),266# whether an attempted probing succeeded or not267"probe_result": spaces.Discrete(3),268# Esclation result269"escalation": spaces.Discrete(model.PrivilegeLevel.MAXIMUM + 1),270# Array of slots describing credentials that were leaked271"leaked_credentials": spaces.Tuple(272# the 1st component indicates if the slot is used or not (SLOT_USED or SLOT_UNSUED)273# the 2nd component gives the credential unique index (external identifier exposed to the agent)274# the 3rd component gives the target node ID275# the 4th component gives the port number276#277# The actual credential secret is not returned by the environment.278# To use the credential as a parameter to another action the agent should refer to it by its index279# e.g. (UNUSED_SLOT,_,_,_) encodes an empty slot280# (USED_SLOT,1,56,22) encodes a leaked credential identified by its index 1,281# that was used to authenticat to target node 56 on port number 22 (e.g. SSH)282[283spaces.MultiDiscrete(284np.array([285NA + 1,286bounds.maximum_total_credentials,287bounds.maximum_node_count,288bounds.port_count,289], dtype=np.int32)290)291]292* bounds.maximum_discoverable_credentials_per_action293),294# Boolean bitmasks defining the subset of valid actions in the current state.295# (1 for valid, 0 for invalid). Note: a valid action is not necessariliy guaranteed to succeed.296# For instance it is a valid action to attempt to connect to a remote node with incorrect credentials,297# even though such action would 'fail' and potentially yield a negative reward.298"action_mask": spaces.Dict(299{300"local_vulnerability": spaces.MultiBinary(np.array([bounds.maximum_node_count, bounds.local_attacks_count])),301"remote_vulnerability": spaces.MultiBinary(np.array([bounds.maximum_node_count, bounds.maximum_node_count, bounds.remote_attacks_count])),302"connect": spaces.MultiBinary(np.array([bounds.maximum_node_count, bounds.maximum_node_count, bounds.port_count, bounds.maximum_total_credentials], dtype=np.int32))303}304),305# size of the credential stack306"credential_cache_length": spaces.Discrete(bounds.maximum_total_credentials),307# total nodes discovered so far308"discovered_node_count": spaces.Discrete(bounds.maximum_node_count),309# Matrix of properties for all the discovered nodes310# 3 values for each matrix cell: set, unset, unknown311"discovered_nodes_properties": spaces.MultiDiscrete(np.full(shape=(bounds.maximum_node_count, bounds.property_count), fill_value=3)),312# Escalation level on every discovered node (e.g., 0 if not owned, 1 for admin, 2 for system)313"nodes_privilegelevel": spaces.MultiDiscrete([CyberBattleEnv.privilege_levels] * bounds.maximum_node_count),314# Encoding of the credential cache of shape: (credential_cache_length, 2)315#316# Each row represent a discovered credential,317# the credential index is given by the row index (i.e. order of discovery)318# A row is of the form: (target_node_discover_index, port_index)319"credential_cache_matrix": spaces.Tuple([spaces.MultiDiscrete(np.array([bounds.maximum_node_count, bounds.port_count],dtype=np.int32))] * bounds.maximum_total_credentials),320# ---------------------------------------------------------321# Fields that were previously in the 'info' dict:322# ---------------------------------------------------------323# internal IDs of nodes discovered so far324"_discovered_nodes": DummySpace(sample=["node1", "node0", "node2"]),325# The subgraph of nodes discovered so far with annotated edges326# representing interactions that took place during the simulation. (See327# actions.EdgeAnnotation)328"_explored_network": DummySpace(sample=networkx.DiGraph()),329}330)331332333class CyberBattleSpaceKind(Env[Observation, Action]):334action_space: DiscriminatedUnion # type: ignore335observation_space: ObservationSpaceType # type: ignore336337338class CyberBattleEnv(CyberBattleSpaceKind):339"""OpenAI Gym environment interface to the CyberBattle simulation.340341# Actions342343Run a local attack: `(source_node x local_vulnerability_to_exploit)`344Run a remote attack command: `(source_node x target_node x remote_vulnerability_to_exploit)`345Connect to a remote node: `(source_node x target_node x target_port x credential_index_from_cache)`346347# Observation348349See type `Observation` for a full description of the observation space.350It includes:351- How many new nodes were discovered352- Whether lateral move succeeded353- Whether customer data were found354- Whehter escalation attempt succeeded355- Matrix of all node properties discovered so far356- List of leaked credentials357358# Information359- Action mask indicating the subset of valid actions at the current state360361# Termination362363The simulation ends if either the attacker reaches its goal (e.g. full network ownership),364the defender reaches its goal (e.g. full eviction of the attacker)365or if one of the defender's constraints is not met (e.g. SLA).366"""367368metadata = {"render_modes": ["human"]}369370@property371def environment(self) -> model.Environment:372return self.__environment373374def __reset_environment(self) -> None:375self.__environment: model.Environment = copy.deepcopy(self.__initial_environment)376self.__discovered_nodes: List[model.NodeID] = []377self.__owned_nodes_indices_cache: Optional[List[int]] = None378self.__credential_cache: List[model.CachedCredential] = []379self.__episode_rewards: List[float] = []380# The actuator used to execute actions in the simulation environment381self._actuator = actions.AgentActions(382self.__environment,383throws_on_invalid_actions=self.__throws_on_invalid_actions,384)385self._defender_actuator = actions.DefenderAgentActions(self.__environment)386387self.__stepcount = 0388self.__start_time = time.time()389self.__done = False390391for node_id, node_data in self.__environment.nodes():392if node_data.agent_installed:393self.__discovered_nodes.append(node_id)394395@property396def name(self) -> str:397return "CyberBattleEnv"398399@property400def identifiers(self) -> model.Identifiers:401return self.__environment.identifiers402403@property404def bounds(self) -> EnvironmentBounds:405return self.__bounds406407def validate_environment(self, environment: model.Environment):408"""Validate that the size of the network and associated constants fits within409the dimensions bounds set for the CyberBattle gym environment"""410assert environment.identifiers.ports411assert environment.identifiers.properties412assert environment.identifiers.local_vulnerabilities413assert environment.identifiers.remote_vulnerabilities414415node_count = len(environment.network.nodes.items())416if node_count > self.__bounds.maximum_node_count:417raise ValueError(f"Network node count ({node_count}) exceeds " f"the specified limit of {self.__bounds.maximum_node_count}.")418419# Maximum number of credentials that can possibly be returned by any action420effective_maximum_credentials_per_action = max(421[422len(vulnerability.outcome.credentials)423for _, node_info in environment.nodes()424for _, vulnerability in node_info.vulnerabilities.items()425if isinstance(vulnerability.outcome, model.LeakedCredentials)426]427)428429if effective_maximum_credentials_per_action > self.__bounds.maximum_discoverable_credentials_per_action:430raise ValueError(431f"Some action in the environment returns {effective_maximum_credentials_per_action} "432f"credentials which exceeds the maximum number of discoverable credentials "433f"of {self.__bounds.maximum_discoverable_credentials_per_action}"434)435436refeerenced_ports = model.collect_ports_from_environment(environment)437undefined_ports = set(refeerenced_ports).difference(environment.identifiers.ports)438if undefined_ports:439raise ValueError(f"The network has references to undefined port names: {undefined_ports}")440441referenced_properties = model.collect_properties_from_nodes(model.iterate_network_nodes(environment.network))442undefined_properties = set(referenced_properties).difference(environment.identifiers.properties)443if undefined_properties:444raise ValueError(f"The network has references to undefined property names: {undefined_properties}")445446local_vulnerabilities = model.collect_vulnerability_ids_from_nodes_bytype(447environment.nodes(),448environment.vulnerability_library,449model.VulnerabilityType.LOCAL,450)451452undefined_local_vuln = set(local_vulnerabilities).difference(environment.identifiers.local_vulnerabilities)453if undefined_local_vuln:454raise ValueError(f"The network has references to undefined local" f" vulnerability names: {undefined_local_vuln}")455456remote_vulnerabilities = model.collect_vulnerability_ids_from_nodes_bytype(457environment.nodes(),458environment.vulnerability_library,459model.VulnerabilityType.REMOTE,460)461462undefined_remote_vuln = set(remote_vulnerabilities).difference(environment.identifiers.remote_vulnerabilities)463if undefined_remote_vuln:464raise ValueError(f"The network has references to undefined remote" f" vulnerability names: {undefined_remote_vuln}")465466# number of distinct privilege levels467privilege_levels = model.PrivilegeLevel.MAXIMUM + 1468469def __init__(470self,471initial_environment: model.Environment,472maximum_total_credentials: int = 1000,473maximum_node_count: int = 100,474maximum_discoverable_credentials_per_action: int = 5,475defender_agent: Optional[DefenderAgent] = None,476attacker_goal: Optional[AttackerGoal] = AttackerGoal(own_atleast_percent=1.0),477defender_goal=DefenderGoal(eviction=True),478defender_constraint=DefenderConstraint(maintain_sla=0.0),479winning_reward=5000.0,480losing_reward=0.0,481renderer="",482observation_padding=True,483throws_on_invalid_actions=True,484):485"""Arguments486===========487environment - The CyberBattle network simulation environment488maximum_total_credentials - Maximum total number of credentials used in a network489maximum_node_count - Largest possible size of the network490maximum_discoverable_credentials_per_action - Maximum number of credentials returned by a given action491attacker_goal - Target goal for the attacker to win and stop the simulation.492defender_goal - Target goal for the defender to win and stop the simulation.493defender_constraint - Constraint to be maintain by the defender to keep the simulation running.494winning_reward - Reward granted to the attacker if the simulation ends because the attacker's goal is reached.495losing_reward - Reward granted to the attacker if the simulation ends because the Defender's goal is reached.496renderer - the matplotlib renderer (e.g. 'png')497observation_padding - whether to pad all the observation fields to their maximum size. For instance this will pad the credential matrix498to fit in `maximum_node_count` rows. Turn on this flag for gym agent that expects observations of fixed sizes.499must be set to True with gym >=0.26500throws_on_invalid_actions - whether to raise an exception if the step function attempts an invalid action (e.g., running an attack from a node that's not owned)501if set to False a negative reward is returned instead.502"""503504# maximum number of entities in a given environment505self.__bounds = EnvironmentBounds.of_identifiers(506maximum_total_credentials=maximum_total_credentials,507maximum_node_count=maximum_node_count,508maximum_discoverable_credentials_per_action=maximum_discoverable_credentials_per_action,509identifiers=initial_environment.identifiers,510)511512self.validate_environment(initial_environment)513self.__attacker_goal: Optional[AttackerGoal] = attacker_goal514self.__defender_goal: DefenderGoal = defender_goal515self.__defender_constraint: DefenderConstraint = defender_constraint516self.__WINNING_REWARD = winning_reward517self.__LOSING_REWARD = losing_reward518self.__renderer = renderer519self.__observation_padding = observation_padding520self.__throws_on_invalid_actions = throws_on_invalid_actions521522self.viewer = None523524self.__initial_environment: model.Environment = initial_environment525526# number of entities in the environment network527self.__defender_agent = defender_agent528529self.__reset_environment()530531self.__node_count = len(initial_environment.network.nodes.items())532533# The Space object defining the valid actions of an attacker.534local_vulnerabilities_count = self.__bounds.local_attacks_count535remote_vulnerabilities_count = self.__bounds.remote_attacks_count536maximum_node_count_int32 = self.__bounds.maximum_node_count537port_count = self.__bounds.port_count538539action_spaces = {540"local_vulnerability": spaces.MultiDiscrete(541# source_node_id, vulnerability_id542np.array([maximum_node_count_int32, local_vulnerabilities_count], dtype=np.int32)543),544"remote_vulnerability": spaces.MultiDiscrete(545# source_node_id, target_node_id, vulnerability_id546np.array([maximum_node_count_int32, maximum_node_count_int32, remote_vulnerabilities_count], dtype=np.int32)547),548"connect": spaces.MultiDiscrete(549# source_node_id, target_node_id, target_port, credential_id550# (by index of discovery: 0 for initial node, 1 for first discovered node, ...)551np.array([552maximum_node_count_int32,553maximum_node_count_int32,554port_count,555maximum_total_credentials,556], dtype=np.int32)557),558}559560self.action_space = DiscriminatedUnion[Action](cast(dict, action_spaces)) # type: ignore561562self.observation_space = ObservationSpaceType(self.__bounds)563564# reward_range: A tuple corresponding to the min and max possible rewards565self.reward_range = (-float("inf"), float("inf"))566567def __index_to_local_vulnerabilityid(self, vulnerability_index: int) -> model.VulnerabilityID:568"""Return the local vulnerability identifier from its internal encoding index"""569return self.__initial_environment.identifiers.local_vulnerabilities[vulnerability_index]570571def __index_to_remote_vulnerabilityid(self, vulnerability_index: int) -> model.VulnerabilityID:572"""Return the remote vulnerability identifier from its internal encoding index"""573return self.__initial_environment.identifiers.remote_vulnerabilities[vulnerability_index]574575def __index_to_port_name(self, port_index: int) -> model.PortName:576"""Return the port name identifier from its internal encoding index"""577return self.__initial_environment.identifiers.ports[port_index]578579def __portname_to_index(self, port_name: PortName) -> int:580"""Return the internal encoding index of a given port name"""581return self.__initial_environment.identifiers.ports.index(port_name)582583def __internal_node_id_from_external_node_index(self, node_external_index: int) -> model.NodeID:584""" "Return the internal environment node ID corresponding to the specified585external node index that is exposed to the Gym agent5860 -> ID of inital node5871 -> ID of first discovered node588...589590"""591# Ensures that the specified node is known by the agent592if node_external_index < 0:593raise OutOfBoundIndexError(f"Node index must be positive, given {node_external_index}")594595length = len(self.__discovered_nodes)596if node_external_index >= length:597raise OutOfBoundIndexError(f"Node index ({node_external_index}) is invalid; only {length} nodes discovered so far.")598599node_id = self.__discovered_nodes[node_external_index]600return node_id601602def __find_external_index(self, node_id: model.NodeID) -> int:603"""Find the external index associated with the specified node ID"""604return self.__discovered_nodes.index(node_id)605606def __agent_owns_node(self, node_id: model.NodeID) -> bool:607node = self.__environment.get_node(node_id)608pwned: bool = node.agent_installed609return pwned610611def apply_mask(self, action: Action, mask: Optional[ActionMask] = None) -> bool:612"""Apply the action mask to a specific action. Returns true just if the action613is permitted."""614if mask is None:615mask = self.compute_action_mask()616field_name = DiscriminatedUnion.kind(action)617field_mask, coordinates = mask[field_name], action[field_name] # type: ignore618return bool(field_mask[tuple(coordinates)])619620def __get_blank_action_mask(self) -> ActionMask:621"""Return a blank action mask"""622max_node_count = self.bounds.maximum_node_count623local_vulnerabilities_count = self.__bounds.local_attacks_count624remote_vulnerabilities_count = self.__bounds.remote_attacks_count625port_count = self.__bounds.port_count626local = numpy.zeros(shape=(max_node_count, local_vulnerabilities_count), dtype=numpy.int8)627remote = numpy.zeros(628shape=(max_node_count, max_node_count, remote_vulnerabilities_count),629dtype=numpy.int8,630)631connect = numpy.zeros(632shape=(633max_node_count,634max_node_count,635port_count,636self.__bounds.maximum_total_credentials,637),638dtype=numpy.int8,639)640return ActionMask(local_vulnerability=local, remote_vulnerability=remote, connect=connect)641642def __update_action_mask(self, bitmask: ActionMask) -> None:643"""Update an action mask based on the current state"""644local_vulnerabilities_count = self.__bounds.local_attacks_count645remote_vulnerabilities_count = self.__bounds.remote_attacks_count646port_count = self.__bounds.port_count647648# Compute the vulnerability action bitmask649#650# The agent may attempt exploiting vulnerabilities651# from any node that it owns652for source_node_id in self.__discovered_nodes:653if self.__agent_owns_node(source_node_id):654source_index = self.__find_external_index(source_node_id)655656# Local: since the agent owns the node, all its local vulnerabilities are visible to it657for vulnerability_index in range(local_vulnerabilities_count):658vulnerability_id = self.__index_to_local_vulnerabilityid(vulnerability_index)659node_vulnerable = vulnerability_id in self.__environment.vulnerability_library or vulnerability_id in self.__environment.get_node(source_node_id).vulnerabilities660661if node_vulnerable:662bitmask["local_vulnerability"][source_index, vulnerability_index] = 1663664# Remote: Any other node discovered so far is a potential remote target665for target_node_id in self.__discovered_nodes:666target_index = self.__find_external_index(target_node_id)667bitmask["remote_vulnerability"][source_index, target_index, :remote_vulnerabilities_count] = 1668669# the agent may attempt to connect to any port670# and use any credential from its cache (though it's not guaranteed to succeed)671bitmask["connect"][672source_index,673target_index,674:port_count,675: len(self.__credential_cache),676] = 1677678def compute_action_mask(self) -> ActionMask:679"""Compute the action mask for the current state"""680bitmask = self.__get_blank_action_mask()681self.__update_action_mask(bitmask)682return bitmask683684def pretty_print_internal_action(self, action: Action) -> str:685"""Pretty print an action with internal node and vulnerability identifiers"""686assert 1 == len(action.keys())687assert DiscriminatedUnion.kind(action) != ""688if "local_vulnerability" in action:689source_node_index, vulnerability_index = action["local_vulnerability"]690return f"local_vulnerability(`{self.__internal_node_id_from_external_node_index(source_node_index)}, {self.__index_to_local_vulnerabilityid(vulnerability_index)})"691elif "remote_vulnerability" in action:692source_node, target_node, vulnerability_index = action["remote_vulnerability"]693source_node_id = self.__internal_node_id_from_external_node_index(source_node)694target_node_id = self.__internal_node_id_from_external_node_index(target_node)695return f"remote_vulnerability(`{source_node_id}, `{target_node_id}, {self.__index_to_remote_vulnerabilityid(vulnerability_index)})"696elif "connect" in action:697source_node, target_node, port_index, credential_cache_index = action["connect"]698assert credential_cache_index >= 0699if credential_cache_index >= len(self.__credential_cache):700return "connect(invalid)"701source_node_id = self.__internal_node_id_from_external_node_index(source_node)702target_node_id = self.__internal_node_id_from_external_node_index(target_node)703return f"connect(`{source_node_id}, `{target_node_id}, {self.__index_to_port_name(port_index)}, {self.__credential_cache[credential_cache_index].credential})"704raise ValueError("Invalid discriminated union value: " + str(action))705706def __execute_action(self, action: Action) -> actions.ActionResult:707# Assert that the specified action is consistent (i.e., defining a single action type)708assert 1 == len(action.keys())709710assert DiscriminatedUnion.kind(action) != ""711712if "local_vulnerability" in action:713source_node_index, vulnerability_index = action["local_vulnerability"]714715return self._actuator.exploit_local_vulnerability(716self.__internal_node_id_from_external_node_index(source_node_index),717self.__index_to_local_vulnerabilityid(vulnerability_index),718)719720elif "remote_vulnerability" in action:721source_node, target_node, vulnerability_index = action["remote_vulnerability"]722source_node_id = self.__internal_node_id_from_external_node_index(source_node)723target_node_id = self.__internal_node_id_from_external_node_index(target_node)724725result = self._actuator.exploit_remote_vulnerability(726source_node_id,727target_node_id,728self.__index_to_remote_vulnerabilityid(vulnerability_index),729)730731return result732733elif "connect" in action:734source_node, target_node, port_index, credential_cache_index = action["connect"]735if credential_cache_index < 0 or credential_cache_index >= len(self.__credential_cache):736return actions.ActionResult(reward=-1, outcome=None)737738source_node_id = self.__internal_node_id_from_external_node_index(source_node)739target_node_id = self.__internal_node_id_from_external_node_index(target_node)740741result = self._actuator.connect_to_remote_machine(742source_node_id,743target_node_id,744self.__index_to_port_name(port_index),745self.__credential_cache[credential_cache_index].credential,746)747748return result749750raise ValueError("Invalid discriminated union value: " + str(action))751752def __get_blank_observation(self) -> Observation:753observation = Observation(754newly_discovered_nodes_count=numpy.int32(0),755leaked_credentials=tuple([numpy.array([UNUSED_SLOT, 0, 0, 0], dtype=numpy.int32)] * self.__bounds.maximum_discoverable_credentials_per_action),756lateral_move=numpy.int32(0),757customer_data_found=numpy.int32(0),758escalation=numpy.int32(PrivilegeLevel.NoAccess),759action_mask=self.__get_blank_action_mask(),760probe_result=numpy.int32(0),761credential_cache_matrix=tuple([numpy.zeros((2), dtype=numpy.int64)] * self.__bounds.maximum_total_credentials),762credential_cache_length=0,763discovered_node_count=len(self.__discovered_nodes),764discovered_nodes_properties=numpy.full((self.__bounds.maximum_node_count, self.__bounds.property_count,), 2, dtype=numpy.int32),765nodes_privilegelevel=numpy.zeros((self.bounds.maximum_node_count,), dtype=numpy.int32),766# raw data not actually encoded as a proper gym numeric space767# (were previously returned in the 'info' dict)768_discovered_nodes=self.__discovered_nodes,769_explored_network=self.__get_explored_network(),770)771772return observation773774def __pad_array_if_requested(self, o, pad_value, desired_length) -> numpy.ndarray:775"""Pad an array observation with provided padding if the padding option is enabled776for this environment"""777if self.__observation_padding:778padding = numpy.full((desired_length - len(o)), pad_value, dtype=numpy.int32)779return numpy.concatenate((o, padding))780else:781return o782783def __pad_tuple_if_requested(self, o, row_shape, desired_length) -> Tuple[numpy.ndarray, ...]:784"""Pad a tuple observation with provided padding if the padding option is enabled785for this environment"""786if self.__observation_padding:787padding = [numpy.zeros(row_shape, dtype=numpy.int32)] * (desired_length - len(o))788return tuple(o + padding)789else:790return tuple(o)791792def __property_vector(self, node_id: model.NodeID, node_info: model.NodeInfo) -> numpy.ndarray:793"""Property vector for specified node794each cell is either 1 if the property is set, 0 if unset, and 2 if unknown (node is not owned by the agent yet)795"""796properties_indices = list(self._actuator.get_discovered_properties(node_id))797798is_owned = self._actuator.get_node_privilegelevel(node_id) >= PrivilegeLevel.LocalUser799800if is_owned:801# if the node is owned then we know all its properties802vector = numpy.full((self.__bounds.property_count), 0, dtype=numpy.int32)803else:804# otherwise we don't know anything about not discovered properties => 0 should be the default value805vector = numpy.zeros((self.__bounds.property_count), dtype=numpy.int32)806807vector[properties_indices] = 1808return vector809810def __get_property_matrix(self) -> numpy.ndarray:811"""Return the Node-Property matrix,812where 0 means the property is not set for that node8131 means the property is set for that node8142 means the property status is unknown815816e.g.: [ 1 0 0 1 ]8172 2 2 28180 1 0 1 ]8191st row: set and unset properties for the 1st discovered and owned node8202nd row: no known properties for the 2nd discovered node8213rd row: properties of 3rd discovered and owned node"""822property_discovered = [self.__property_vector(node_id, node_info) for node_id, node_info in self._actuator.discovered_nodes()]823as_numpy = numpy.array(self.__pad_tuple_if_requested(824property_discovered,825self.__bounds.property_count,826self.__bounds.maximum_node_count,827))828assert as_numpy.shape == (self.__bounds.maximum_node_count, self.__bounds.property_count)829return as_numpy830831def __get__owned_nodes_indices(self) -> List[int]:832"""Get list of indices of all owned nodes"""833if self.__owned_nodes_indices_cache is None:834owned_nodeids = self._actuator.get_nodes_with_atleast_privilegelevel(PrivilegeLevel.LocalUser)835self.__owned_nodes_indices_cache = [self.__find_external_index(n) for n in owned_nodeids]836837return self.__owned_nodes_indices_cache838839def __get_privilegelevel_array(self) -> numpy.ndarray:840"""Return the node escalation level array,841where 0 means that the node is not owned8421 if the node is owned8432 if the node is owned and escalated to admin8443 if the node is owned and escalated to SYSTEM845... further escalation levels defined by the network846"""847privilegelevel_array = numpy.array(848[int(self._actuator.get_node_privilegelevel(node)) for node in self.__discovered_nodes],849dtype=numpy.int32,850)851852return self.__pad_array_if_requested(853privilegelevel_array,854PrivilegeLevel.NoAccess,855self.__bounds.maximum_node_count,856)857858def __observation_reward_from_action_result(self, result: actions.ActionResult) -> Tuple[Observation, float]:859obs = self.__get_blank_observation()860outcome = result.outcome861862if isinstance(outcome, model.LeakedNodesId):863# update discovered nodes864newly_discovered_nodes_count = 0865for node in outcome.nodes:866if node not in self.__discovered_nodes:867self.__discovered_nodes.append(node)868newly_discovered_nodes_count += 1869870obs["newly_discovered_nodes_count"] = numpy.int32(newly_discovered_nodes_count)871872elif isinstance(outcome, model.LeakedCredentials):873# update discovered nodes and credentials874newly_discovered_nodes_count = 0875newly_discovered_creds: List[Tuple[int, model.CachedCredential]] = []876for cached_credential in outcome.credentials:877if cached_credential.node not in self.__discovered_nodes:878self.__discovered_nodes.append(cached_credential.node)879newly_discovered_nodes_count += 1880881if cached_credential not in self.__credential_cache:882self.__credential_cache.append(cached_credential)883added_credential_index = len(self.__credential_cache) - 1884newly_discovered_creds.append((added_credential_index, cached_credential))885886obs["newly_discovered_nodes_count"] = numpy.int32(newly_discovered_nodes_count)887888# Encode the returned credentials in the format expected by the gym agent889leaked_credentials = [890numpy.array(891[892USED_SLOT,893cache_index,894self.__find_external_index(cached_credential.node),895self.__portname_to_index(cached_credential.port),896],897numpy.int32,898)899for cache_index, cached_credential in newly_discovered_creds900]901902obs["leaked_credentials"] = self.__pad_tuple_if_requested(903leaked_credentials,9044,905self.__bounds.maximum_discoverable_credentials_per_action,906)907908elif isinstance(outcome, model.LateralMove):909obs["lateral_move"] = numpy.int32(1)910elif isinstance(outcome, model.CustomerData):911obs["customer_data_found"] = numpy.int32(1)912elif isinstance(outcome, model.ProbeSucceeded):913obs["probe_result"] = numpy.int32(2)914elif isinstance(outcome, model.ProbeFailed):915obs["probe_result"] = numpy.int32(1)916elif isinstance(outcome, model.PrivilegeEscalation):917obs["escalation"] = numpy.int32(outcome.level)918919cache = [numpy.array([self.__find_external_index(c.node), self.__portname_to_index(c.port)]) for c in self.__credential_cache]920921obs["credential_cache_matrix"] = self.__pad_tuple_if_requested(cache, 2, self.__bounds.maximum_total_credentials)922923# Dynamic statistics to be refreshed924obs["credential_cache_length"] = len(self.__credential_cache)925obs["discovered_node_count"] = len(self.__discovered_nodes)926obs["discovered_nodes_properties"] = self.__get_property_matrix()927obs["nodes_privilegelevel"] = self.__get_privilegelevel_array()928obs["_discovered_nodes"] = self.__discovered_nodes929obs["_explored_network"] = self.__get_explored_network()930931self.__update_action_mask(obs["action_mask"])932return obs, result.reward933934def sample_connect_action_in_expected_range(self) -> Action:935"""Sample an action of type 'connect' where the parameters936are in the the expected ranges but not necessarily verifying937inter-component constraints.938"""939discovered_credential_count = len(self.__credential_cache)940941if discovered_credential_count <= 0:942raise ValueError("Cannot sample a connect action until the agent discovers more potential target nodes.")943944return Action(945connect=numpy.array(946[947self.np_random.choice(self.__get__owned_nodes_indices()),948self.np_random.integers(0, len(self.__discovered_nodes)),949self.np_random.integers(0, self.__bounds.port_count),950# credential space is sparse so we force sampling951# from the set of credentials that were discovered so far952self.np_random.integers(0, len(self.__credential_cache)),953],954numpy.int32,955)956)957958def sample_action_in_range(self, kinds: Optional[List[int]] = None) -> Action:959"""Sample an action in the expected component ranges but960not necessarily verifying inter-component constraints.961(e.g., may return a local_vulnerability action that is not962supported by the node)963964- kinds -- A list of elements in {0,1,2} indicating what kind of965action to sample (0:local, 1:remote, 2:connect)966"""967968discovered_credential_count = len(self.__credential_cache)969970if kinds is None:971kinds = [0, 1, 2]972973if discovered_credential_count == 0:974# cannot generate a connect action if no cred in the cache975kinds = [t for t in kinds if t != 2]976977assert kinds, "Kinds list cannot be empty"978979choice_random = self.action_space.union_np_random980kind = choice_random.choice(kinds)981982if kind == 2:983action = self.sample_connect_action_in_expected_range()984elif kind == 1:985action = Action(986local_vulnerability=numpy.array(987[988choice_random.choice(self.__get__owned_nodes_indices()),989choice_random.integers(0, self.__bounds.local_attacks_count),990],991numpy.int32,992)993)994else:995action = Action(996remote_vulnerability=numpy.array(997[998choice_random.choice(self.__get__owned_nodes_indices()),999choice_random.integers(0, len(self.__discovered_nodes)),1000choice_random.integers(0, self.__bounds.remote_attacks_count),1001],1002numpy.int32,1003)1004)10051006return action10071008def is_node_owned(self, node: int):1009"""Return true if a discovered node (specified by its external node index)1010is owned by the attacker agent"""1011node_id = self.__internal_node_id_from_external_node_index(node)1012node_owned = self._actuator.get_node_privilegelevel(node_id) > PrivilegeLevel.NoAccess1013return node_owned10141015def is_action_valid(self, action, action_mask: Optional[ActionMask] = None) -> bool:1016"""Determine if an action is valid (i.e. parameters are in expected ranges)"""1017assert 1 == len(action.keys())10181019kind = DiscriminatedUnion.kind(action)1020in_range = False1021n_discovered_nodes = len(self.__discovered_nodes)1022if kind == "local_vulnerability":1023source_node, vulnerability_index = action["local_vulnerability"]1024in_range = source_node < n_discovered_nodes and self.is_node_owned(source_node) and vulnerability_index < self.__bounds.local_attacks_count1025elif kind == "remote_vulnerability":1026source_node, target_node, vulnerability_index = action["remote_vulnerability"]1027in_range = source_node < n_discovered_nodes and self.is_node_owned(source_node) and target_node < n_discovered_nodes and vulnerability_index < self.__bounds.remote_attacks_count1028elif kind == "connect":1029source_node, target_node, port_index, credential_cache_index = action["connect"]1030in_range = (1031source_node < n_discovered_nodes1032and self.is_node_owned(source_node)1033and target_node < n_discovered_nodes1034and port_index < self.__bounds.port_count1035and credential_cache_index < len(self.__credential_cache)1036)10371038return in_range and self.apply_mask(action, action_mask)10391040def sample_valid_action(self, kinds=None) -> Action:1041"""Sample an action within the expected ranges until getting a valid one"""1042action_mask = self.compute_action_mask()1043action = self.sample_action_in_range(kinds)1044while not self.apply_mask(action, action_mask):1045action = self.sample_action_in_range(kinds)1046return action10471048def sample_valid_action_with_luck(self) -> Action:1049"""Sample an action until getting a valid one"""1050action_mask = self.compute_action_mask()1051action = self.action_space.sample()1052while not self.apply_mask(action, action_mask):1053action = self.action_space.sample()1054return action10551056def __get_explored_network(self) -> networkx.DiGraph:1057"""Returns the graph of nodes discovered so far1058with annotated edges representing interactions1059that took place during the simulation.1060"""1061known_nodes = [node_id for node_id, _ in self._actuator.discovered_nodes()]1062network = self.__environment.network1063subgraph = network.subgraph(known_nodes).copy()1064assert isinstance(subgraph, networkx.DiGraph)10651066# hide info for nodes that the agent does not own1067for node_id in subgraph.nodes:1068node_info: Optional[model.NodeInfo] = subgraph.nodes[node_id]["data"]1069if node_info is not None:1070if not node_info.agent_installed:1071subgraph.nodes[node_id]["data"] = None10721073subgraph.nodes[node_id]["privilege_level"] = int(self._actuator.get_node_privilegelevel(node_id))1074subgraph.nodes[node_id]["flags"] = list(self._actuator.get_discovered_properties(node_id))1075subgraph.nodes[node_id]["flags_bits"] = self.__property_vector(node_id, node_info)10761077return subgraph10781079def __attacker_goal_reached(self) -> bool:1080goal = self.__attacker_goal10811082if not goal:1083return False10841085if numpy.sum(self.__episode_rewards) < goal.reward:1086return False10871088nodes_owned = self.__get__owned_nodes_indices()1089owned_count = len(nodes_owned)10901091if owned_count < goal.own_atleast:1092return False10931094if owned_count / self.__node_count < goal.own_atleast_percent:1095return False10961097if self.__defender_agent is not None and self._defender_actuator.network_availability >= goal.low_availability:1098return False10991100return True11011102def __defender_constraints_broken(self) -> bool:1103"""Check if any of the defender's constraint is not met"""1104constraint = self.__defender_constraint11051106if self.__defender_agent is not None and self._defender_actuator.network_availability < constraint.maintain_sla:1107return True11081109return False11101111def __defender_goal_reached(self) -> bool:1112"""Check if defender's goal is reached(e.g. full eviction of attacker)"""1113goal = self.__defender_goal11141115return goal.eviction and not (self.__get__owned_nodes_indices())11161117def get_explored_network_as_numpy(self, observation: Observation) -> numpy.ndarray:1118"""Return the explored network graph adjacency matrix1119as an numpy array of shape (N,N) where1120N is the number of nodes discovered so far"""1121return convert_matrix.to_numpy_array(observation["_explored_network"], weight="kind_as_float")11221123def get_explored_network_node_properties_bitmap_as_numpy(self, observation: Observation) -> numpy.ndarray:1124"""Return a combined the matrix of adjacencies (left part) and1125node properties bitmap (right part).1126Suppose N is the number of discovered nodes and1127P is the total number of properties then1128Then the return matrix is of the form:11291130^ <---- N -----><------ P ------>1131| ( | )1132N ( Adjacency | Node-Properties )1133| ( Matrix | Bitmap )1134V ( | )11351136"""1137return numpy.block(1138[1139convert_matrix.to_numpy_array(observation["_explored_network"], weight="kind_as_float"),1140numpy.array(observation["discovered_nodes_properties"]),1141]1142)11431144def step(self, action: Action) -> Tuple[Observation, float, bool, bool, StepInfo]: # type: ignore1145if self.__done:1146raise RuntimeError("new episode must be started with env.reset()")11471148self.__stepcount += 11149duration = time.time() - self.__start_time1150try:1151result = self.__execute_action(action)1152observation, reward = self.__observation_reward_from_action_result(result)11531154# Execute the defender step if provided1155if self.__defender_agent:1156self._defender_actuator.on_attacker_step_taken()1157self.__defender_agent.step(self.__environment, self._defender_actuator, self.__stepcount)11581159self.__owned_nodes_indices_cache = None11601161if self.__attacker_goal_reached() or self.__defender_constraints_broken():1162self.__done = True1163reward = self.__WINNING_REWARD1164elif self.__defender_goal_reached():1165self.__done = True1166reward = self.__LOSING_REWARD1167else:1168reward = max(0.0, reward)11691170except OutOfBoundIndexError as error:1171logging.warning("Invalid entity index: " + error.__str__())1172observation = self.__get_blank_observation()1173reward = 0.011741175info = StepInfo(1176description="CyberBattle simulation",1177duration_in_ms=duration,1178step_count=self.__stepcount,1179network_availability=self._defender_actuator.network_availability,1180credential_cache=self.__credential_cache,1181)1182self.__episode_rewards.append(reward)11831184return observation, reward, self.__done, False, info11851186def reset(1187self,1188*,1189seed: Optional[int] = None,1190options: Optional[dict] = None,1191) -> Tuple[Observation, StepInfo]:1192LOGGER.info("Resetting the CyberBattle environment")1193self.__reset_environment()1194self.np_random, seed = seeding.np_random(seed)11951196observation = self.__get_blank_observation()1197observation["action_mask"] = self.compute_action_mask()1198observation["discovered_nodes_properties"] = self.__get_property_matrix()1199observation["nodes_privilegelevel"] = self.__get_privilegelevel_array()1200self.__owned_nodes_indices_cache = None1201info = StepInfo(1202description="CyberBattle simulation",1203duration_in_ms=0,1204step_count=self.__stepcount,1205network_availability=self._defender_actuator.network_availability,1206credential_cache=self.__credential_cache,1207)1208return observation, info12091210def render_as_fig(self):1211debug = commandcontrol.EnvironmentDebugging(self._actuator)1212self._actuator.print_all_attacks()12131214# plot the cumulative reward and network side by side using plotly1215fig = make_subplots(rows=1, cols=2)1216fig.add_trace(1217Scatter(y=numpy.array(self.__episode_rewards).cumsum(), name="cumulative reward"),1218row=1,1219col=1,1220)1221traces, layout = debug.network_as_plotly_traces(xref="x2", yref="y2")1222for t in traces:1223fig.add_trace(t, row=1, col=2)1224fig.update_layout(layout)1225return fig12261227def render(self, mode: str = "human") -> None:1228fig = self.render_as_fig()1229fig.show(renderer=self.__renderer)12301231def close(self) -> None:1232return None123312341235