Path: blob/main/cyberbattle/agents/baseline/agent_tabularqlearning.py
597 views
# Copyright (c) Microsoft Corporation.1# Licensed under the MIT License.23"""Q-learning agent applied to chain network (notebook)4This notebooks can be run directly from VSCode, to generate a5traditional Jupyter Notebook to open in your browser6you can run the VSCode command `Export Currenty Python File As Jupyter Notebook`.7"""89# pylint: disable=invalid-name1011from typing import NamedTuple, Optional, Tuple12import numpy as np13import logging1415from cyberbattle._env import cyberbattle_env16from .agent_wrapper import EnvironmentBounds17from .agent_randomcredlookup import CredentialCacheExploiter18import cyberbattle.agents.baseline.agent_wrapper as w19import cyberbattle.agents.baseline.learner as learner202122def random_argmax(array):23"""Just like `argmax` but if there are multiple elements with the max24return a random index to break ties instead of returning the first one."""25max_value = np.max(array)26max_index = np.where(array == max_value)[0]2728if max_index.shape[0] > 1:29max_index = int(np.random.choice(max_index, size=1))30else:31max_index = int(max_index)3233return max_value, max_index343536def random_argtop_percentile(array: np.ndarray, percentile: float):37"""Just like `argmax` but if there are multiple elements with the max38return a random index to break ties instead of returning the first one."""39top_percentile = np.percentile(array, percentile)40indices = np.where(array >= top_percentile)[0]4142if len(indices) == 0:43return random_argmax(array)44elif indices.shape[0] > 1:45max_index = int(np.random.choice(indices, size=1))46else:47max_index = int(indices)4849return top_percentile, max_index505152class QMatrix:53"""Q-Learning matrix for a given state and action space54state_space - Features defining the state space55action_space - Features defining the action space56qm - Optional: initialization values for the Q matrix57"""5859# The Quality matrix60qm: np.ndarray6162def __init__(63self,64name,65state_space: w.Feature,66action_space: w.Feature,67qm: Optional[np.ndarray] = None,68):69"""Initialize the Q-matrix"""7071self.name = name72self.state_space = state_space73self.action_space = action_space74self.statedim = state_space.flat_size()75self.actiondim = action_space.flat_size()76self.qm = self.clear() if qm is None else qm7778# error calculated for the last update to the Q-matrix79self.last_error = 08081def shape(self):82return (self.statedim, self.actiondim)8384def clear(self):85"""Re-initialize the Q-matrix to 0"""86self.qm = np.zeros(shape=self.shape())87# self.qm = np.random.rand(*self.shape()) / 10088return self.qm8990def print(self):91print(f"[{self.name}]\n" f"state: {self.state_space}\n" f"action: {self.action_space}\n" f"shape = {self.shape()}")9293def update(94self,95current_state: int,96action: int,97next_state: int,98reward,99gamma,100learning_rate,101):102"""Update the Q matrix after taking `action` in state 'current_State'103and obtaining reward=R[current_state, action]"""104105maxq_atnext, max_index = random_argmax(self.qm[next_state,])106107# bellman equation for Q-learning108temporal_difference = reward + gamma * maxq_atnext - self.qm[current_state, action]109self.qm[current_state, action] += learning_rate * temporal_difference110111# The loss is calculated using the squared difference between112# target Q-Value and predicted Q-Value113square_error = temporal_difference * temporal_difference114self.last_error = square_error115116return self.qm[current_state, action]117118def exploit(self, features, percentile) -> Tuple[int, float]:119"""exploit: leverage the Q-matrix.120Returns the expected Q value and the chosen action."""121expected_q, action = random_argtop_percentile(self.qm[features, :], percentile)122return int(action), float(expected_q)123124125class QLearnAttackSource(QMatrix):126"""Top-level Q matrix to pick the attack127State space: global state info128Action space: feature encodings of suggested nodes129"""130131def __init__(self, ep: EnvironmentBounds, qm: Optional[np.ndarray] = None):132self.ep = ep133134self.state_space = w.HashEncoding(135ep,136[137# Feature_discovered_node_count(),138# Feature_discovered_credential_count(),139w.Feature_discovered_ports_sliding(ep),140w.Feature_discovered_nodeproperties_sliding(ep),141w.Feature_discovered_notowned_node_count(ep, 3),142],1435000,144) # should not be too small, pick something big to avoid collision145146self.action_space = w.RavelEncoding(ep, [w.Feature_active_node_properties(ep)])147148super().__init__("attack_source", self.state_space, self.action_space, qm)149150151class QLearnBestAttackAtSource(QMatrix):152"""Top-level Q matrix to pick the attack from a pre-chosen source node153State space: feature encodings of suggested node states154Action space: a SimpleAbstract action155"""156157def __init__(self, ep: EnvironmentBounds, qm: Optional[np.ndarray] = None) -> None:158self.state_space = w.HashEncoding(159ep,160[161w.Feature_active_node_properties(ep),162w.Feature_active_node_age(ep),163# w.Feature_actions_tried_at_node(ep)164],1657000,166)167168# NOTE: For debugging purpose it's convenient instead to use169# Ravel encoding for node properties170self.state_space_debugging = w.RavelEncoding(171ep,172[173w.HashEncoding(174ep,175[176# Feature_discovered_node_count(),177# Feature_discovered_credential_count(),178w.Feature_discovered_ports_sliding(ep),179w.Feature_discovered_nodeproperties_sliding(ep),180w.Feature_discovered_notowned_node_count(ep, 3),181],182100,183),184w.Feature_active_node_properties(ep),185],186)187188self.action_space = w.AbstractAction(ep)189190super().__init__("attack_at_source", self.state_space, self.action_space, qm)191192193# TODO: We should try scipy for sparse matrices and OpenBLAS (MKL Intel version of BLAS, faster than openBLAS) for numpy194195196# %%197class LossEval:198"""Loss evaluation for a Q-Learner,199learner -- The Q learner200"""201202def __init__(self, qmatrix: QMatrix):203self.qmatrix = qmatrix204self.this_episode = []205self.all_episodes = []206207def new_episode(self):208self.this_episode = []209210def end_of_iteration(self, t, done):211self.this_episode.append(self.qmatrix.last_error)212213def current_episode_loss(self):214return np.average(self.this_episode)215216def end_of_episode(self, i_episode, t):217"""Average out the overall loss for this episode"""218self.all_episodes.append(self.current_episode_loss())219220221class ChosenActionMetadata(NamedTuple):222"""Additional information associated with the action chosen by the agent"""223224Q_source_state: int225Q_source_expectedq: float226Q_attack_expectedq: float227source_node: int228source_node_encoding: int229abstract_action: np.int32230Q_attack_state: int231232233class QTabularLearner(learner.Learner):234"""Tabular Q-learning235236Parameters237==========238gamma -- discount factor239240learning_rate -- learning rate241242ep -- environment global properties243244trained -- another QTabularLearner that is pretrained to initialize the Q matrices from (referenced, not copied)245246exploit_percentile -- (experimental) Randomly pick actions above this percentile in the Q-matrix.247Setting 100 gives the argmax as in standard Q-learning.248249The idea is that a value less than 100 helps compensate for the250approximation made when updating the Q-matrix caused by251the abstraction of the action space (attack parameters are abstracted away252in the Q-matrix, and when an abstract action is picked, it253gets specialized via a random process.)254When running in non-learning mode (lr=0), setting this value too close to 100255may lead to get stuck, being more permissive (e.g. in the 80-90 range)256typically gives better results.257258"""259260def __init__(261self,262ep: EnvironmentBounds,263gamma: float,264learning_rate: float,265exploit_percentile: float,266trained=None, # : Optional[QTabularLearner]267):268if trained:269self.qsource = trained.qsource270self.qattack = trained.qattack271else:272self.qsource = QLearnAttackSource(ep)273self.qattack = QLearnBestAttackAtSource(ep)274275self.loss_qsource = LossEval(self.qsource)276self.loss_qattack = LossEval(self.qattack)277self.gamma = gamma278self.learning_rate = learning_rate279self.exploit_percentile = exploit_percentile280self.credcache_policy = CredentialCacheExploiter()281282def on_step(283self,284wrapped_env: w.AgentWrapper,285observation,286reward,287done,288truncated,289info,290action_metadata: ChosenActionMetadata,291):292agent_state = wrapped_env.state293294# Update the top-level Q matrix for the state of the selected source node295after_toplevel_state = self.qsource.state_space.encode(agent_state)296self.qsource.update(297action_metadata.Q_source_state,298action_metadata.source_node_encoding,299after_toplevel_state,300reward,301self.gamma,302self.learning_rate,303)304305# Update the second Q matrix for the abstract action chosen306qattack_state_after = self.qattack.state_space.encode_at(agent_state, action_metadata.source_node)307self.qattack.update(308action_metadata.Q_attack_state,309int(action_metadata.abstract_action),310qattack_state_after,311reward,312self.gamma,313self.learning_rate,314)315316def end_of_iteration(self, t, done):317self.loss_qsource.end_of_iteration(t, done)318self.loss_qattack.end_of_iteration(t, done)319320def end_of_episode(self, i_episode, t):321self.loss_qsource.end_of_episode(i_episode, t)322self.loss_qattack.end_of_episode(i_episode, t)323324def loss_as_string(self):325return f"[loss_source={self.loss_qsource.current_episode_loss():0.3f}" f" loss_attack={self.loss_qattack.current_episode_loss():0.3f}]"326327def new_episode(self):328self.loss_qsource.new_episode()329self.loss_qattack.new_episode()330331def exploit(self, wrapped_env: w.AgentWrapper, observation):332agent_state = wrapped_env.state333334qsource_state = self.qsource.state_space.encode(agent_state)335336#############337# first, attempt to exploit the credential cache338# using the crecache_policy339action_style, gym_action, _ = self.credcache_policy.exploit(wrapped_env, observation)340if gym_action:341source_node = cyberbattle_env.sourcenode_of_action(gym_action)342return (343action_style,344gym_action,345ChosenActionMetadata(346Q_source_state=qsource_state,347Q_source_expectedq=-1,348Q_attack_expectedq=-1,349source_node=source_node,350source_node_encoding=self.qsource.action_space.encode_at(agent_state, source_node),351abstract_action=np.int32(self.qattack.action_space.abstract_from_gymaction(gym_action)),352Q_attack_state=self.qattack.state_space.encode_at(agent_state, source_node),353),354)355#############356357# Pick action: pick random source state among the ones with the maximum Q-value358action_style = "exploit"359source_node_encoding, qsource_expectedq = self.qsource.exploit(qsource_state, percentile=100)360361# Pick source node at random (owned and with the desired feature encoding)362potential_source_nodes = [from_node for from_node in w.owned_nodes(observation) if source_node_encoding == self.qsource.action_space.encode_at(agent_state, from_node)]363364if len(potential_source_nodes) == 0:365logging.debug(f"No node with encoding {source_node_encoding}, fallback on explore")366# NOTE: we should make sure that it does not happen too often,367# the penalty should be much smaller than typical rewards, small nudge368# not a new feedback signal.369370# Learn the lack of node availability371self.qsource.update(372qsource_state,373source_node_encoding,374qsource_state,375reward=0,376gamma=self.gamma,377learning_rate=self.learning_rate,378)379380return "exploit-1->explore", None, None381else:382source_node = np.random.choice(potential_source_nodes)383384qattack_state = self.qattack.state_space.encode_at(agent_state, source_node)385386abstract_action, qattack_expectedq = self.qattack.exploit(qattack_state, percentile=self.exploit_percentile)387388gym_action = self.qattack.action_space.specialize_to_gymaction(source_node, observation, np.int32(abstract_action))389390assert int(abstract_action) < self.qattack.action_space.flat_size(), f"abstract_action={abstract_action} gym_action={gym_action}"391392if gym_action and wrapped_env.env.is_action_valid(gym_action, observation["action_mask"]):393logging.debug(f" exploit gym_action={gym_action} source_node_encoding={source_node_encoding}")394return (395action_style,396gym_action,397ChosenActionMetadata(398Q_source_state=qsource_state,399Q_source_expectedq=qsource_expectedq,400Q_attack_expectedq=qsource_expectedq,401source_node=source_node,402source_node_encoding=source_node_encoding,403abstract_action=np.int32(abstract_action),404Q_attack_state=qattack_state,405),406)407else:408# NOTE: We should make the penalty reward smaller than409# the average/typical non-zero reward of the env (e.g. 1/1000 smaller)410# The idea of weighing the learning_rate when taking a chance is411# related to "Inverse propensity weighting"412413# Learn the non-validity of the action414self.qsource.update(415qsource_state,416source_node_encoding,417qsource_state,418reward=0,419gamma=self.gamma,420learning_rate=self.learning_rate,421)422423self.qattack.update(424qattack_state,425int(abstract_action),426qattack_state,427reward=0,428gamma=self.gamma,429learning_rate=self.learning_rate,430)431432# fallback on random exploration433return (434("exploit[invalid]->explore" if gym_action else "exploit[undefined]->explore"),435None,436None,437)438439def explore(self, wrapped_env: w.AgentWrapper):440agent_state = wrapped_env.state441gym_action = wrapped_env.env.sample_valid_action(kinds=[0, 1, 2])442abstract_action = self.qattack.action_space.abstract_from_gymaction(gym_action)443444assert int(abstract_action) < self.qattack.action_space.flat_size(), f"Q_attack_action={abstract_action} gym_action={gym_action}"445446source_node = cyberbattle_env.sourcenode_of_action(gym_action)447448return (449"explore",450gym_action,451ChosenActionMetadata(452Q_source_state=self.qsource.state_space.encode(agent_state),453Q_source_expectedq=-1,454Q_attack_expectedq=-1,455source_node=source_node,456source_node_encoding=self.qsource.action_space.encode_at(agent_state, source_node),457abstract_action=abstract_action,458Q_attack_state=self.qattack.state_space.encode_at(agent_state, source_node),459),460)461462def stateaction_as_string(self, action_metadata) -> str:463return (464f"Qsource[state={action_metadata.Q_source_state} err={self.qsource.last_error:0.2f}"465f"Q={action_metadata.Q_source_expectedq:.2f}] "466f"Qattack[state={action_metadata.Q_attack_state} err={self.qattack.last_error:0.2f} "467f"Q={action_metadata.Q_attack_expectedq:.2f}] "468)469470def parameters_as_string(self) -> str:471return f"γ={self.gamma}," f"learning_rate={self.learning_rate}," f"Q%={self.exploit_percentile}"472473def all_parameters_as_string(self) -> str:474return (475f" dimension={self.qsource.state_space.flat_size()}x{self.qsource.action_space.flat_size()},"476f"{self.qattack.state_space.flat_size()}x{self.qattack.action_space.flat_size()}\n"477f"Q1={[f.name() for f in self.qsource.state_space.feature_selection]}"478f" -> {[f.name() for f in self.qsource.action_space.feature_selection]}\n"479f"Q2={[f.name() for f in self.qattack.state_space.feature_selection]} -> 'action'"480)481482483