Path: blob/main/cyberbattle/agents/baseline/agent_dql.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# Function DeepQLearnerPolicy.optimize_model:
# Copyright (c) 2017, Pytorch contributors
# All rights reserved.
# https://github.com/pytorch/tutorials/blob/master/LICENSE

"""Deep Q-learning agent applied to chain network (notebook)

This notebook can be run directly from VSCode. To generate a
traditional Jupyter Notebook to open in your browser,
run the VSCode command `Export Current Python File As Jupyter Notebook`.

Requirements:
Nvidia CUDA drivers for WSL2: https://docs.nvidia.com/cuda/wsl-user-guide/index.html
PyTorch
"""

# pylint: disable=invalid-name

# %% [markdown]
# # Chain network CyberBattle Gym played by a Deep Q-learning agent

# %%
from numpy import ndarray
from cyberbattle._env import cyberbattle_env
import numpy as np
from typing import List, NamedTuple, Optional, Tuple, Union
import random

# deep learning packages
from torch import Tensor
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
import torch
import torch.cuda
from torch.nn.utils.clip_grad import clip_grad_norm_

from .learner import Learner
from .agent_wrapper import EnvironmentBounds
import cyberbattle.agents.baseline.agent_wrapper as w
from .agent_randomcredlookup import CredentialCacheExploiter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class CyberBattleStateActionModel:
    """Define an abstraction of the state and action space
    for a CyberBattle environment, to be used to train a Q-function.
    """

    def __init__(self, ep: EnvironmentBounds):
        self.ep = ep

        self.global_features = w.ConcatFeatures(
            ep,
            [
                # w.Feature_discovered_node_count(ep),
                # w.Feature_owned_node_count(ep),
                w.Feature_discovered_notowned_node_count(ep, None)
                # w.Feature_discovered_ports(ep),
                # w.Feature_discovered_ports_counts(ep),
                # w.Feature_discovered_ports_sliding(ep),
                # w.Feature_discovered_credential_count(ep),
                # w.Feature_discovered_nodeproperties_sliding(ep),
            ],
        )

        self.node_specific_features = w.ConcatFeatures(
            ep,
            [
                # w.Feature_actions_tried_at_node(ep),
                w.Feature_success_actions_at_node(ep),
                w.Feature_failed_actions_at_node(ep),
                w.Feature_active_node_properties(ep),
                w.Feature_active_node_age(ep),
                # w.Feature_active_node_id(ep)
            ],
        )

        self.state_space = w.ConcatFeatures(
            ep,
            self.global_features.feature_selection + self.node_specific_features.feature_selection,
        )

        self.action_space = w.AbstractAction(ep)

    def get_state_astensor(self, state: w.StateAugmentation):
        state_vector = self.state_space.get(state, node=None)
        state_vector_float = np.array(state_vector, dtype=np.float32)
        state_tensor = torch.from_numpy(state_vector_float).unsqueeze(0)
        return state_tensor

    def implement_action(
        self,
        wrapped_env: w.AgentWrapper,
        actor_features: ndarray,
        abstract_action: np.int32,
    ) -> Tuple[str, Optional[cyberbattle_env.Action], Optional[int]]:
        """Specialize an abstract model action into a CyberBattle gym action.

        actor_features -- the desired features of the actor to use (source CyberBattle node)
        abstract_action -- the desired type of attack (connect, local, remote).

        Returns a gym action implementing the desired attack from a node with the desired feature encoding.
        """

        observation = wrapped_env.state.observation

        # Pick a source node at random among the owned nodes with the desired feature encoding
        potential_source_nodes = [
            from_node
            for from_node in w.owned_nodes(observation)
            if np.all(actor_features == self.node_specific_features.get(wrapped_env.state, from_node))
        ]

        if len(potential_source_nodes) > 0:
            source_node = np.random.choice(potential_source_nodes)

            gym_action = self.action_space.specialize_to_gymaction(source_node, observation, np.int32(abstract_action))

            if not gym_action:
                return "exploit[undefined]->explore", None, None
            elif wrapped_env.env.is_action_valid(gym_action, observation["action_mask"]):
                return "exploit", gym_action, source_node
            else:
                return "exploit[invalid]->explore", None, None
        else:
            return "exploit[no_actor]->explore", None, None
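

# Illustrative sketch (not part of the original module): how the abstraction above is
# typically consumed. The global features are computed once (node=None) and the
# node-specific features once per owned node; the learner further below concatenates
# the two into the flat vector fed to the Q-network. `_example_actor_state_vector`
# is a hypothetical helper added here for illustration only.
def _example_actor_state_vector(model: CyberBattleStateActionModel, state: w.StateAugmentation, node_id: int) -> ndarray:
    """Build the actor state vector for `node_id` the same way DeepQLearnerPolicy does."""
    global_part = model.global_features.get(state, node=None)
    node_part = model.node_specific_features.get(state, node_id)
    return np.concatenate((np.array(global_part, dtype=np.float32), np.array(node_part, dtype=np.float32)))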


# %%

# Deep Q-learning


class Transition(NamedTuple):
    """One taken transition and its outcome"""

    state: Union[Tuple[Tensor], List[Tensor]]
    action: Union[Tuple[Tensor], List[Tensor]]
    next_state: Union[Tuple[Tensor], List[Tensor]]
    reward: Union[Tuple[Tensor], List[Tensor]]


class ReplayMemory(object):
    """Transition replay memory"""

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
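

# Illustrative sketch (not part of the original module): exercising the replay buffer
# above with dummy tensors. Once capacity is reached, `push` overwrites the oldest
# entries in circular fashion, and `sample` draws a uniform random mini-batch.
def _replay_memory_demo(capacity: int = 3) -> None:
    memory = ReplayMemory(capacity)
    for i in range(5):
        state = torch.tensor([[float(i)]])
        action = torch.tensor([[i]])
        reward = torch.tensor([float(i)])
        memory.push(state, action, state, reward)
    assert len(memory) == capacity  # only the most recent `capacity` transitions remain
    batch = memory.sample(2)  # list of 2 Transition namedtuples drawn at random
    assert all(isinstance(t, Transition) for t in batch)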


class DQN(nn.Module):
    """The Deep Neural Network used to estimate the Q function"""

    def __init__(self, ep: EnvironmentBounds):
        super(DQN, self).__init__()

        model = CyberBattleStateActionModel(ep)
        linear_input_size = len(model.state_space.dim_sizes)
        output_size = model.action_space.flat_size()

        self.hidden_layer1 = nn.Linear(linear_input_size, 1024)
        # self.bn1 = nn.BatchNorm1d(256)
        self.hidden_layer2 = nn.Linear(1024, 512)
        self.hidden_layer3 = nn.Linear(512, 128)
        # self.hidden_layer4 = nn.Linear(128, 64)
        self.head = nn.Linear(128, output_size)

    # Called with either one element to determine the next action, or a batch
    # during optimization. Returns a tensor of Q-values, one per abstract action:
    # tensor([[Q(s, a_0), Q(s, a_1), ...], ...]).
    def forward(self, x):
        x = F.relu(self.hidden_layer1(x))
        # x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.hidden_layer2(x))
        # x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.hidden_layer3(x))
        # x = F.relu(self.hidden_layer4(x))
        return self.head(x.view(x.size(0), -1))
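

# Illustrative sketch (not part of the original module): the Q-network maps a batch of
# flattened state vectors (dimension = len(state_space.dim_sizes)) to one Q-value per
# abstract action (dimension = action_space.flat_size()). `ep` is assumed to be an
# EnvironmentBounds instance obtained from the target gym environment.
def _dqn_output_shape(ep: EnvironmentBounds, batch_size: int = 4) -> torch.Size:
    model = CyberBattleStateActionModel(ep)
    net = DQN(ep).to(device)
    states = torch.zeros(batch_size, len(model.state_space.dim_sizes), device=device)
    q_values = net(states)
    assert q_values.shape == (batch_size, model.action_space.flat_size())
    return q_values.shape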


def random_argmax(array):
    """Just like `argmax` but if there are multiple elements with the max
    return a random index to break ties instead of returning the first one."""
    max_value = np.max(array)
    max_index = np.where(array == max_value)[0]

    if max_index.shape[0] > 1:
        max_index = int(np.random.choice(max_index))
    else:
        max_index = int(max_index[0])

    return max_value, max_index


class ChosenActionMetadata(NamedTuple):
    """Additional info about the action chosen by the DQN-induced policy"""

    abstract_action: np.int32
    actor_node: int
    actor_features: ndarray
    actor_state: ndarray

    def __repr__(self) -> str:
        return f"[abstract_action={self.abstract_action}, actor={self.actor_node}, state={self.actor_state}]"


class DeepQLearnerPolicy(Learner):
    """Deep Q-Learning on CyberBattle environments

    Parameters
    ==========
    ep -- global parameters of the environment
    gamma -- Q discount factor
    replay_memory_size -- size of the replay memory
    batch_size -- Deep Q-learning batch size
    target_update -- target network update frequency (in number of episodes)
    learning_rate -- the learning rate

    Parameters from the DeepDoubleQ paper:
    - learning_rate = 0.00025
    - linear epsilon decay
    - gamma = 0.99

    Pytorch code adapted from the tutorial at
    https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html
    """

    def __init__(
        self,
        ep: EnvironmentBounds,
        gamma: float,
        replay_memory_size: int,
        target_update: int,
        batch_size: int,
        learning_rate: float,
    ):
        self.stateaction_model = CyberBattleStateActionModel(ep)
        self.batch_size = batch_size
        self.gamma = gamma
        self.learning_rate = learning_rate

        self.policy_net = DQN(ep).to(device)
        self.target_net = DQN(ep).to(device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.target_update = target_update

        self.optimizer = optim.RMSprop(self.policy_net.parameters(), lr=learning_rate)  # type: ignore
        self.memory = ReplayMemory(replay_memory_size)

        self.credcache_policy = CredentialCacheExploiter()

    def parameters_as_string(self):
        return (
            f"γ={self.gamma}, lr={self.learning_rate}, replaymemory={self.memory.capacity},\n"
            f"batch={self.batch_size}, target_update={self.target_update}"
        )

    def all_parameters_as_string(self) -> str:
        model = self.stateaction_model
        return (
            f"{self.parameters_as_string()}\n"
            f"dimension={model.state_space.flat_size()}x{model.action_space.flat_size()}, "
            f"Q={[f.name() for f in model.state_space.feature_selection]} "
            f"-> 'abstract_action'"
        )
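
    # The optimization step below follows the standard DQN recipe: sample a mini-batch
    # of transitions from the replay memory and regress the policy network towards the
    # one-step bootstrapped target
    #     y = r                                    if s' is terminal
    #     y = r + gamma * max_a' Q_target(s', a')  otherwise
    # using the Huber (smooth L1) loss, with the "older" target network providing the
    # bootstrap values for stability.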
    def optimize_model(self, norm_clipping=False):
        if len(self.memory) < self.batch_size:
            return

        transitions = self.memory.sample(self.batch_size)
        # converts a batch-array of Transitions to a Transition of batch-arrays
        batch = Transition(*zip(*transitions))

        # Compute a mask of non-final states and concatenate the batch elements
        # (a final state would've been the one after which simulation ended)
        non_final_mask = torch.tensor(
            tuple(map(lambda s: s is not None, batch.next_state)),
            device=device,
            dtype=torch.bool,
        )
        non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)

        # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
        # columns of actions taken. These are the actions which would've been taken
        # for each batch state according to policy_net
        # print(f'state_batch={state_batch.shape} input={len(self.stateaction_model.state_space.dim_sizes)}')
        output = self.policy_net(state_batch)

        # print(f'output={output.shape} batch.action={transitions[0].action.shape} action_batch={action_batch.shape}')
        state_action_values = output.gather(1, action_batch)

        # Compute V(s_{t+1}) for all next states.
        # Expected values of actions for non_final_next_states are computed based
        # on the "older" target_net; selecting their best reward with max(1)[0].
        # This is merged based on the mask, such that we'll have either the expected
        # state value or 0 in case the state was final.
        next_state_values = torch.zeros(self.batch_size, device=device)
        next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1)[0].detach()
        # Compute the expected Q values
        expected_state_action_values = (next_state_values * self.gamma) + reward_batch

        # Compute Huber loss
        loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))

        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()

        # Gradient clipping
        if norm_clipping:
            clip_grad_norm_(self.policy_net.parameters(), 1.0)
        else:
            for param in self.policy_net.parameters():
                if param.grad is not None:
                    param.grad.data.clamp_(-1, 1)
        self.optimizer.step()

    def get_actor_state_vector(self, global_state: ndarray, actor_features: ndarray) -> ndarray:
        return np.concatenate(
            (
                np.array(global_state, dtype=np.float32),
                np.array(actor_features, dtype=np.float32),
            )
        )

    def update_q_function(
        self,
        reward: float,
        actor_state: ndarray,
        abstract_action: np.int32,
        next_actor_state: Optional[ndarray],
    ):
        # store the transition in the replay memory
        reward_tensor = torch.tensor([reward], device=device, dtype=torch.float)
        action_tensor = torch.tensor([[np.int_(abstract_action)]], device=device, dtype=torch.long)
        current_state_tensor = torch.as_tensor(actor_state, dtype=torch.float, device=device).unsqueeze(0)
        if next_actor_state is None:
            next_state_tensor = None
        else:
            next_state_tensor = torch.as_tensor(next_actor_state, dtype=torch.float, device=device).unsqueeze(0)
        self.memory.push(current_state_tensor, action_tensor, next_state_tensor, reward_tensor)

        # optimize the policy network
        self.optimize_model()

    def on_step(
        self,
        wrapped_env: w.AgentWrapper,
        observation,
        reward: float,
        done: bool,
        truncated: bool,
        info,
        action_metadata,
    ):
        agent_state = wrapped_env.state
        if done:
            self.update_q_function(
                reward,
                actor_state=action_metadata.actor_state,
                abstract_action=action_metadata.abstract_action,
                next_actor_state=None,
            )
        else:
            next_global_state = self.stateaction_model.global_features.get(agent_state, node=None)
            next_actor_features = self.stateaction_model.node_specific_features.get(agent_state, action_metadata.actor_node)
            next_actor_state = self.get_actor_state_vector(next_global_state, next_actor_features)

            self.update_q_function(
                reward,
                actor_state=action_metadata.actor_state,
                abstract_action=action_metadata.abstract_action,
                next_actor_state=next_actor_state,
            )

    def end_of_episode(self, i_episode, t):
        # Update the target network, copying all weights and biases from the policy DQN
        if i_episode % self.target_update == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

    def lookup_dqn(self, states_to_consider: List[ndarray]) -> Tuple[List[np.int32], List[np.int32]]:
        """Given a list of candidate states, return for each one the best action
        to take according to the policy network and its expected Q-value."""
        with torch.no_grad():
            # t.max(1) returns the largest column value of each row;
            # the second element of the result is the index of the max element,
            # i.e. the action with the largest expected reward.
            # action: np.int32 = self.policy_net(states_to_consider).max(1)[1].view(1, 1).item()

            state_batch = torch.tensor(np.array(states_to_consider)).to(device)
            dnn_output = self.policy_net(state_batch).max(1)
            action_lookups = dnn_output[1].tolist()
            expectedq_lookups = dnn_output[0].tolist()

        return action_lookups, expectedq_lookups

    def metadata_from_gymaction(self, wrapped_env, gym_action):
        current_global_state = self.stateaction_model.global_features.get(wrapped_env.state, node=None)
        actor_node = cyberbattle_env.sourcenode_of_action(gym_action)
        actor_features = self.stateaction_model.node_specific_features.get(wrapped_env.state, actor_node)
        abstract_action = self.stateaction_model.action_space.abstract_from_gymaction(gym_action)
        return ChosenActionMetadata(
            abstract_action=abstract_action,
            actor_node=actor_node,
            actor_features=actor_features,
            actor_state=self.get_actor_state_vector(current_global_state, actor_features),
        )

    def explore(self, wrapped_env: w.AgentWrapper) -> Tuple[str, cyberbattle_env.Action, object]:
        """Random exploration that avoids repeating actions previously taken in the same state"""
        # sample local and remote actions only (excludes connect action)
        gym_action = wrapped_env.env.sample_valid_action(kinds=[0, 1, 2])
        metadata = self.metadata_from_gymaction(wrapped_env, gym_action)
        return "explore", gym_action, metadata

    def try_exploit_at_candidate_actor_states(self, wrapped_env, current_global_state, actor_features, abstract_action):
        actor_state = self.get_actor_state_vector(current_global_state, actor_features)

        action_style, gym_action, actor_node = self.stateaction_model.implement_action(wrapped_env, actor_features, abstract_action)

        if gym_action:
            assert actor_node is not None, "actor_node should be set together with gym_action"

            return (
                action_style,
                gym_action,
                ChosenActionMetadata(
                    abstract_action=abstract_action,
                    actor_node=actor_node,
                    actor_features=actor_features,
                    actor_state=actor_state,
                ),
            )
        else:
            # learn the failed exploit attempt in the current state
            self.update_q_function(
                reward=0.0,
                actor_state=actor_state,
                next_actor_state=actor_state,
                abstract_action=abstract_action,
            )

            return "exploit[undefined]->explore", None, None
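
    # Exploitation strategy implemented below: compute the actor state vector for every
    # distinct set of owned-node features, query the policy network for the best abstract
    # action and its expected Q-value at each of them, then greedily try the candidates in
    # decreasing Q-value order (with random tie-breaking) until one of them specializes
    # into a valid gym action; otherwise fall back to exploration.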
    def exploit(self, wrapped_env, observation) -> Tuple[str, Optional[cyberbattle_env.Action], object]:
        # First, optionally attempt to exploit the credential cache
        # using the credcache_policy (currently disabled):
        # action_style, gym_action, _ = self.credcache_policy.exploit(wrapped_env, observation)
        # if gym_action:
        #     return action_style, gym_action, self.metadata_from_gymaction(wrapped_env, gym_action)

        # Otherwise, exploit the learnt Q-function

        current_global_state = self.stateaction_model.global_features.get(wrapped_env.state, node=None)

        # Gather the features of all the current active actors (i.e. owned nodes)
        active_actors_features: List[ndarray] = [
            self.stateaction_model.node_specific_features.get(wrapped_env.state, from_node)
            for from_node in w.owned_nodes(observation)
        ]

        unique_active_actors_features: List[ndarray] = list(np.unique(active_actors_features, axis=0))

        # array of actor state vectors, one for every distinct set of node features
        candidate_actor_state_vector: List[ndarray] = [
            self.get_actor_state_vector(current_global_state, node_features)
            for node_features in unique_active_actors_features
        ]

        remaining_action_lookups, remaining_expectedq_lookups = self.lookup_dqn(candidate_actor_state_vector)
        remaining_candidate_indices = list(range(len(candidate_actor_state_vector)))

        while remaining_candidate_indices:
            _, remaining_candidate_index = random_argmax(remaining_expectedq_lookups)
            actor_index = remaining_candidate_indices[remaining_candidate_index]
            abstract_action = remaining_action_lookups[remaining_candidate_index]

            actor_features = unique_active_actors_features[actor_index]

            action_style, gym_action, metadata = self.try_exploit_at_candidate_actor_states(wrapped_env, current_global_state, actor_features, abstract_action)

            if gym_action:
                return action_style, gym_action, metadata

            remaining_candidate_indices.pop(remaining_candidate_index)
            remaining_expectedq_lookups.pop(remaining_candidate_index)
            remaining_action_lookups.pop(remaining_candidate_index)

        return "exploit[undefined]->explore", None, None

    def stateaction_as_string(self, action_metadata) -> str:
        return ""
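

# %%
# Illustrative sketch (not part of the original module): constructing the learner.
# gamma and learning_rate follow the DeepDoubleQ values quoted in the class docstring;
# the remaining hyper-parameter values are placeholders to adjust per experiment.
def _make_dql_policy(ep: EnvironmentBounds) -> DeepQLearnerPolicy:
    return DeepQLearnerPolicy(
        ep=ep,
        gamma=0.99,  # discount factor from the DeepDoubleQ reference
        replay_memory_size=10000,  # placeholder value
        target_update=10,  # placeholder value (episodes between target-network syncs)
        batch_size=512,  # placeholder value
        learning_rate=0.00025,  # learning rate from the DeepDoubleQ reference
    )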