Path: blob/main/cyberbattle/agents/baseline/agent_wrapper.py
597 views
# Copyright (c) Microsoft Corporation.1# Licensed under the MIT License.23"""Agent wrapper for CyberBattle envrionments exposing additional4features extracted from the environment observations"""56from abc import abstractmethod7from cyberbattle._env.cyberbattle_env import EnvironmentBounds8from typing import Optional, List, Tuple, overload9import enum10import numpy as np11from gymnasium import spaces, Wrapper12from numpy import ndarray13import cyberbattle._env.cyberbattle_env as cyberbattle_env14import logging151617class StateAugmentation:18"""Default agent state augmentation, consisting of the gym environment19observation itself and nothing more."""2021def __init__(self, observation: cyberbattle_env.Observation):22self.observation = observation2324def on_step(25self,26action: cyberbattle_env.Action,27reward: float,28truncated: bool,29done: bool,30observation: cyberbattle_env.Observation,31):32self.observation = observation3334def on_reset(self, observation: cyberbattle_env.Observation):35self.observation = observation363738# Abstract class for a feature (either global or node-specific)39class Feature(spaces.MultiDiscrete):40"""41Feature consisting of multiple discrete dimensions.42Parameters:43nvec: is a vector defining the number of possible values44for each discrete space.45"""4647def __init__(self, env_properties: EnvironmentBounds, nvec):48self.env_properties = env_properties49super().__init__(nvec)5051def flat_size(self):52return np.prod(self.nvec, dtype=int)5354def name(self):55"""Return the name of the feature"""56p = len(type(Feature(self.env_properties, [])).__name__) + 157return type(self).__name__[p:]5859def pretty_print(self, v):60return v6162@abstractmethod63def get(self, a: StateAugmentation, node: Optional[int] = None) -> np.ndarray:64"""Compute the current value of a feature value at65the current observation and specific node"""66raise NotImplementedError676869class NodeFeature(Feature):70"""71Feature consisting of multiple discrete dimensions at a specific node.72"""7374@abstractmethod75def get_at(self, a: StateAugmentation, node: int) -> np.ndarray:76"""Compute the current value of a feature value at77the current observation and specific node"""78raise NotImplementedError7980def get(self, a: StateAugmentation, node: Optional[int] = None) -> np.ndarray:81assert node is not None, "feature only valid in the context of a node"82return self.get_at(a, node)838485class GlobalFeature(Feature):86"""87Feature consisting of multiple discrete dimensions at the global level.88"""8990@abstractmethod91def get_global(self, a: StateAugmentation) -> np.ndarray:92"""Compute the current value of a feature value at93the current observation"""94raise NotImplementedError9596def get(self, a: StateAugmentation, node: Optional[int] = None) -> np.ndarray:97assert node is None, "feature only valid in the context of a node"98return self.get_global(a)99100# @staticmethod101# def get_feature_value(102# f: Union[NodeFeature, GlobalFeature], a: SA_T, node: Optional[int]103# ):104# """Return the feature value at the current observation and specific node"""105# if isinstance(f, NodeFeature):106# assert node is not None, "feature only valid in the context of a node"107# return f.get(a, node)108# elif isinstance(f, GlobalFeature):109# assert node is None, "feature only valid in the context of a node"110# return f.get(a)111112113class Feature_active_node_properties(NodeFeature):114"""Bitmask of all properties set for the active node"""115116def __init__(self, p: EnvironmentBounds):117super().__init__(p, [2] * p.property_count)118119def get_at(self, a: StateAugmentation, node) -> ndarray:120assert node is not None, "feature only valid in the context of a node"121122node_prop = a.observation["discovered_nodes_properties"]123124# list of all properties set/unset on the node125assert node < len(node_prop), f"invalid node index {node} (not discovered yet)"126127# Remap to get rid of the unknown value (2):128# 1->1, 0->0, 2->0129remapped = np.array(node_prop[node] % 2, dtype=np.int_)130return remapped131132133class Feature_active_node_age(NodeFeature):134"""How recently was this node discovered?135(measured by reverse position in the list of discovered nodes)"""136137def __init__(self, p: EnvironmentBounds):138super().__init__(p, [p.maximum_node_count])139140def get_at(self, a: StateAugmentation, node) -> ndarray:141assert node is not None, "feature only valid in the context of a node"142143discovered_node_count = a.observation["discovered_node_count"]144145assert node < discovered_node_count, f"invalid node index {node} (not discovered yet)"146147return np.array([discovered_node_count - node - 1], dtype=np.int_)148149150class Feature_active_node_id(NodeFeature):151"""Return the node id itself"""152153def __init__(self, p: EnvironmentBounds):154super().__init__(p, [p.maximum_node_count] * 1)155156def get_at(self, a: StateAugmentation, node) -> ndarray:157return np.array([node], dtype=np.int_)158159160class Feature_discovered_nodeproperties_sliding(GlobalFeature):161"""Bitmask indicating node properties seen in last few cache entries"""162163window_size = 3164165def __init__(self, p: EnvironmentBounds):166super().__init__(p, [2] * p.property_count)167168def get_global(self, a: StateAugmentation) -> ndarray:169n = a.observation["discovered_node_count"]170node_prop = a.observation["discovered_nodes_properties"][:n]171172# keep last window of entries173node_prop_window = node_prop[-self.window_size :, :]174175# Remap to get rid of the unknown value (2)176node_prop_window_remapped = np.int32(node_prop_window % 2)177178countby = np.sum(node_prop_window_remapped, axis=0)179180bitmask = (countby > 0) * 1181return bitmask182183184class Feature_discovered_ports(GlobalFeature):185"""Bitmask vector indicating each port seen so far in discovered credentials"""186187def __init__(self, p: EnvironmentBounds):188super().__init__(p, [2] * p.port_count)189190def get_global(self, a: StateAugmentation):191n = a.observation["credential_cache_length"]192known_credports = np.zeros(self.env_properties.port_count, dtype=np.int32)193if n > 0:194ccm = np.array(a.observation["credential_cache_matrix"])[:n]195known_credports[np.int32(ccm[:, 1])] = 1196return known_credports197198199class Feature_discovered_ports_sliding(GlobalFeature):200"""Bitmask indicating port seen in last few cache entries"""201202window_size = 3203204def __init__(self, p: EnvironmentBounds):205super().__init__(p, [2] * p.port_count)206207def get_global(self, a: StateAugmentation):208known_credports = np.zeros(self.env_properties.port_count, dtype=np.int32)209n = a.observation["credential_cache_length"]210if n > 0:211ccm = np.array(a.observation["credential_cache_matrix"])[:n]212known_credports[np.int32(ccm[-self.window_size :, 1])] = 1213return known_credports214215216class Feature_discovered_ports_counts(GlobalFeature):217"""Count of each port seen so far in discovered credentials"""218219def __init__(self, p: EnvironmentBounds):220super().__init__(p, [p.maximum_total_credentials + 1] * p.port_count)221222def get_global(self, a: StateAugmentation):223n = a.observation["credential_cache_length"]224if n > 0:225ccm = np.array(a.observation["credential_cache_matrix"])[:n]226ports = np.int32(ccm[:, 1])227else:228ports = np.zeros(0)229return np.bincount(ports, minlength=self.env_properties.port_count)230231232class Feature_discovered_credential_count(GlobalFeature):233"""number of credentials discovered so far"""234235def __init__(self, p: EnvironmentBounds):236super().__init__(p, [p.maximum_total_credentials + 1])237238def get_global(self, a: StateAugmentation):239n = a.observation["credential_cache_length"]240return np.array([n], dtype=np.int_)241242243class Feature_discovered_node_count(GlobalFeature):244"""number of nodes discovered so far"""245246def __init__(self, p: EnvironmentBounds):247super().__init__(p, [p.maximum_node_count + 1])248249def get_global(self, a: StateAugmentation):250return np.array([a.observation["discovered_node_count"]], dtype=np.int_)251252253class Feature_discovered_notowned_node_count(GlobalFeature):254"""number of nodes discovered that are not owned yet (optionally clipped)"""255256def __init__(self, p: EnvironmentBounds, clip: Optional[int]):257self.clip = np.int32(clip or p.maximum_node_count)258super().__init__(p, [self.clip + 1])259260def get_global(self, a: StateAugmentation):261discovered = a.observation["discovered_node_count"]262node_props = np.array(a.observation["discovered_nodes_properties"][:discovered])263# here we assume that a node is owned just if all its properties are known264owned = np.count_nonzero(np.all(node_props != 2, axis=1))265diff = np.int32(discovered - owned)266return np.array( [np.min((diff, self.clip))], dtype=np.int32)267268269class Feature_owned_node_count(GlobalFeature):270"""number of owned nodes so far"""271272def __init__(self, p: EnvironmentBounds):273super().__init__(p, [p.maximum_node_count + 1])274275def get_global(self, a: StateAugmentation):276levels = a.observation["nodes_privilegelevel"]277owned_nodes_indices = np.where(levels > 0)[0]278return np.array([len(owned_nodes_indices)], dtype=np.int_)279280281class ConcatFeatures(Feature):282"""Concatenate a list of features into a single feature283Parameters:284feature_selection - a selection of features to combine285"""286287def __init__(288self,289p: EnvironmentBounds,290feature_selection: List[Feature],291):292self.feature_selection = feature_selection293self.dim_sizes = np.concatenate([f.nvec for f in feature_selection])294super().__init__(p, [self.dim_sizes])295296def pretty_print(self, v):297return v298299def get(self, a: StateAugmentation, node: Optional[int] = None) -> np.ndarray:300"""Return the feature vector"""301feature_vector = [f.get(a, node) for f in self.feature_selection]302303return np.concatenate(feature_vector)304305306class FeatureEncoder(Feature):307"""Encode a list of features as a unique index"""308309feature_selection: List[Feature]310311def vector_to_index(self, feature_vector: np.ndarray) -> int:312raise NotImplementedError313314def feature_vector_of_observation_at(self, a: StateAugmentation, node: Optional[int]) -> np.ndarray:315"""Return the current feature vector"""316feature_vector = [f.get(a, node) for f in self.feature_selection]317# print(f'feature_vector={feature_vector} self.feature_selection={self.feature_selection}')318return np.concatenate(feature_vector)319320def feature_vector_of_observation(self, a: StateAugmentation):321return self.feature_vector_of_observation_at(a, None)322323def encode(self, a: StateAugmentation, node=None) -> int:324"""Return the index encoding of the feature"""325feature_vector_concat = self.feature_vector_of_observation_at(a, node)326return self.vector_to_index(feature_vector_concat)327328def encode_at(self, a: StateAugmentation, node: int) -> int:329"""Return the current feature vector encoding with a node context"""330feature_vector_concat = self.feature_vector_of_observation_at(a, node)331return self.vector_to_index(feature_vector_concat)332333def name(self):334"""Return a name for the feature encoding"""335n = ", ".join([f.name() for f in self.feature_selection])336return f"[{n}]"337338339class HashEncoding(FeatureEncoder):340"""Feature defined as a hash of another feature341Parameters:342feature_selection: a selection of features to combine343hash_dim: dimension after hashing with hash(str(feature_vector)) or -1 for no hashing344"""345346def __init__(347self,348p: EnvironmentBounds,349feature_selection: List[Feature],350hash_size: int,351):352self.feature_selection = feature_selection353self.hash_size = hash_size354super().__init__(p, [hash_size])355356def flat_size(self):357return self.hash_size358359def vector_to_index(self, feature_vector) -> int:360"""Hash the state vector"""361return hash(str(feature_vector)) % self.hash_size362363def pretty_print(self, v):364return f"#{v}"365366367class RavelEncoding(FeatureEncoder):368"""Combine a set of features into a single feature with a unique index369(calculated by raveling the original indices)370Parameters:371feature_selection - a selection of features to combine372"""373374def __init__(375self,376p: EnvironmentBounds,377feature_selection: List[Feature],378):379self.feature_selection = feature_selection380self.dim_sizes = np.concatenate([f.nvec for f in feature_selection])381self.ravelled_size: np.int64 = np.prod(self.dim_sizes)382assert np.shape(self.ravelled_size) == (), f"! {np.shape(self.ravelled_size)}"383super().__init__(p, [self.ravelled_size])384385def vector_to_index(self, feature_vector) -> int:386assert len(self.dim_sizes) == len(feature_vector), f"feature vector of size {len(feature_vector)}, " f"expecting {len(self.dim_sizes)}: {feature_vector} -- {self.dim_sizes}"387index_intp = np.ravel_multi_index(list(feature_vector), list(self.dim_sizes))388index = index_intp.item()389assert index < self.ravelled_size, f"feature vector out of bound ({feature_vector}, dim={self.dim_sizes}) " f"-> index={index}, max_index={self.ravelled_size-1})"390return index391392def unravel_index(self, index) -> Tuple:393return np.unravel_index(index, self.dim_sizes)394395def pretty_print(self, v):396return self.unravel_index(v)397398399def owned_nodes(observation):400"""Return the list of owned nodes"""401return np.nonzero(observation["nodes_privilegelevel"])[0]402403404def discovered_nodes_notowned(observation):405"""Return the list of discovered nodes that are not owned yet"""406return np.nonzero(observation["nodes_privilegelevel"] == 0)[0]407408409class AbstractAction(Feature):410"""An abstraction of the gym state space that reduces411the space dimension for learning use to just412- local_attack(vulnid) (source_node provided)413- remote_attack(vulnid) (source_node provided, target_node forgotten)414- connect(port) (source_node provided, target_node forgotten, credentials infered from cache)415"""416417def __init__(self, p: EnvironmentBounds):418self.n_local_actions = p.local_attacks_count419self.n_remote_actions = p.remote_attacks_count420self.n_connect_actions = p.port_count421self.n_actions = self.n_local_actions + self.n_remote_actions + self.n_connect_actions422super().__init__(p, [self.n_actions])423424def specialize_to_gymaction(self, source_node: np.int32, observation, abstract_action_index: np.int32) -> Optional[cyberbattle_env.Action]:425"""Specialize an abstract "q"-action into a gym action.426Return an adjustement weight (1.0 if the choice was deterministic, 1/n if a choice was made out of n)427and the gym action"""428429abstract_action_index_int = int(abstract_action_index)430431discovered_nodes_count = observation["discovered_node_count"]432433if abstract_action_index_int < self.n_local_actions:434vuln = abstract_action_index_int435return {"local_vulnerability": np.array([source_node, vuln])}436437abstract_action_index_int -= self.n_local_actions438if abstract_action_index_int < self.n_remote_actions:439vuln = abstract_action_index_int440441if discovered_nodes_count <= 1:442return None443444# NOTE: We can do better here than random pick: ultimately this445# should be learnt from target node properties446447# pick any node from the discovered ones448# excluding the source node itself449target = (source_node + 1 + np.random.choice(discovered_nodes_count - 1)) % discovered_nodes_count450451return {"remote_vulnerability": np.array([source_node, target, vuln])}452453abstract_action_index_int -= self.n_remote_actions454port = np.int32(abstract_action_index_int)455456n_discovered_creds = observation["credential_cache_length"]457if n_discovered_creds <= 0:458# no credential available in the cache: cannot poduce a valid connect action459return None460discovered_credentials = np.array(observation["credential_cache_matrix"])[:n_discovered_creds]461462nodes_not_owned = discovered_nodes_notowned(observation)463464# Pick a matching cred from the discovered_cred matrix465# (at random if more than one exist for this target port)466match_port = discovered_credentials[:, 1] == port467match_port_indices = np.where(match_port)[0]468469credential_indices_choices = [c for c in match_port_indices if discovered_credentials[c, 0] in nodes_not_owned]470471if credential_indices_choices:472logging.debug("found matching cred in the credential cache")473else:474logging.debug("no cred matching requested port, trying instead creds used to access other ports")475credential_indices_choices = [i for (i, n) in enumerate(discovered_credentials[:, 0]) if n in nodes_not_owned]476477if credential_indices_choices:478logging.debug("found cred in the credential cache without matching port name")479else:480logging.debug("no cred to use from the credential cache")481return None482483cred = np.int32(np.random.choice(credential_indices_choices))484target = np.int32(discovered_credentials[cred, 0])485return {"connect": np.array([source_node, target, port, cred], dtype=np.int32)}486487def abstract_from_gymaction(self, gym_action: cyberbattle_env.Action) -> np.int32:488"""Abstract a gym action into an action to be index in the Q-matrix"""489if "local_vulnerability" in gym_action:490return gym_action["local_vulnerability"][1]491elif "remote_vulnerability" in gym_action:492r = gym_action["remote_vulnerability"]493return self.n_local_actions + r[2]494495assert "connect" in gym_action496c = gym_action["connect"]497498a = self.n_local_actions + self.n_remote_actions + c[2]499assert a < self.n_actions500return np.int32(a)501502503class ActionTrackingStateAugmentation(StateAugmentation):504"""An agent state augmentation consisting of505the environment observation augmented with the following dynamic information:506- success_action_count: count of action taken and succeeded at the current node507- failed_action_count: count of action taken and failed at the current node508"""509510def __init__(self, p: EnvironmentBounds, observation: cyberbattle_env.Observation):511self.aa = AbstractAction(p)512self.success_action_count = np.zeros(shape=(p.maximum_node_count, self.aa.n_actions), dtype=np.int32)513self.failed_action_count = np.zeros(shape=(p.maximum_node_count, self.aa.n_actions), dtype=np.int32)514self.env_properties = p515super().__init__(observation)516517def on_step(518self,519action: cyberbattle_env.Action,520reward: float,521truncated,522done: bool,523observation: cyberbattle_env.Observation,524):525node = cyberbattle_env.sourcenode_of_action(action)526abstract_action = self.aa.abstract_from_gymaction(action)527if reward > 0:528self.success_action_count[node, abstract_action] += 1529else:530self.failed_action_count[node, abstract_action] += 1531super().on_step(action, reward, done, truncated, observation)532533def on_reset(self, observation: cyberbattle_env.Observation):534p = self.env_properties535self.success_action_count = np.zeros(shape=(p.maximum_node_count, self.aa.n_actions), dtype=np.int32)536self.failed_action_count = np.zeros(shape=(p.maximum_node_count, self.aa.n_actions), dtype=np.int32)537super().on_reset(observation)538539540class Feature_actions_tried_at_node(NodeFeature):541"""A bit mask indicating which actions were already tried542a the current node: 0 no tried, 1 tried"""543544def __init__(self, p: EnvironmentBounds):545super().__init__(p, [2] * AbstractAction(p).n_actions)546547@overload548def get_at(self, a: ActionTrackingStateAugmentation, node: int): ...549550@overload551def get_at(self, a: StateAugmentation, node: int): ...552553def get_at(self, a: StateAugmentation, node: int):554assert node is not None, "feature only valid in the context of a node"555assert isinstance(a, ActionTrackingStateAugmentation), "invalid state type"556return np.array(557((a.failed_action_count[node, :] + a.success_action_count[node, :]) != 0) * 1,558dtype=np.int_,559)560561562class Feature_success_actions_at_node(NodeFeature):563"""number of time each action succeeded at a given node"""564565max_action_count = 100566567def __init__(self, p: EnvironmentBounds):568super().__init__(p, [self.max_action_count] * AbstractAction(p).n_actions)569570@overload571def get_at(self, a: ActionTrackingStateAugmentation, node: int): ...572573@overload574def get_at(self, a: StateAugmentation, node: int): ...575576def get_at(self, a: StateAugmentation, node: int):577assert isinstance(a, ActionTrackingStateAugmentation), "invalid state type"578579return np.minimum(a.success_action_count[node, :], self.max_action_count - 1)580581582class Feature_failed_actions_at_node(NodeFeature):583"""number of time each action failed at a given node"""584585max_action_count = 100586587def __init__(self, p: EnvironmentBounds):588super().__init__(p, [self.max_action_count] * AbstractAction(p).n_actions)589590def get_at(self, a: StateAugmentation, node: int):591assert isinstance(a, ActionTrackingStateAugmentation), "invalid state type"592return np.minimum(a.failed_action_count[node, :], self.max_action_count - 1)593594595class Verbosity(enum.Enum):596"""Verbosity of the learning function"""597598Quiet = 0599Normal = 1600Verbose = 2601602603class AgentWrapper(Wrapper):604"""Gym wrapper to update the agent state on every step"""605606def __init__(self, env: cyberbattle_env.CyberBattleEnv, state: StateAugmentation):607super().__init__(env)608self.env = env609self.state = state610611def step(self, action: cyberbattle_env.Action): # type: ignore612observation, reward, done, truncated, info = self.env.step(action)613self.state.on_step(action, reward, done, truncated, observation)614return observation, reward, done, truncated, info615616def reset(self, **kwargs):617observation, info = self.env.reset(**kwargs)618self.state.on_reset(observation)619return observation, info620621622