Path: blob/main/cyberbattle/agents/baseline/learner.py
597 views
# Copyright (c) Microsoft Corporation.1# Licensed under the MIT License.23"""Learner helpers and epsilon greedy search"""45import math6import sys78from .plotting import PlotTraining, plot_averaged_cummulative_rewards9from .agent_wrapper import (10AgentWrapper,11EnvironmentBounds,12Verbosity,13ActionTrackingStateAugmentation,14)15import logging16import numpy as np17from cyberbattle._env import cyberbattle_env18from typing import Tuple, Optional, TypedDict, List19import progressbar20import abc212223class Learner(abc.ABC):24"""Interface to be implemented by an epsilon-greedy learner"""2526def new_episode(self) -> None:27return None2829def end_of_episode(self, i_episode, t) -> None:30return None3132def end_of_iteration(self, t, done) -> None:33return None3435@abc.abstractmethod36def explore(self, wrapped_env: AgentWrapper) -> Tuple[str, cyberbattle_env.Action, object]:37"""Exploration function.38Returns (action_type, gym_action, action_metadata) where39action_metadata is a custom object that gets passed to the on_step callback function"""40raise NotImplementedError4142@abc.abstractmethod43def exploit(self, wrapped_env: AgentWrapper, observation) -> Tuple[str, Optional[cyberbattle_env.Action], object]:44"""Exploit function.45Returns (action_type, gym_action, action_metadata) where46action_metadata is a custom object that gets passed to the on_step callback function"""47raise NotImplementedError4849@abc.abstractmethod50def on_step(51self,52wrapped_env: AgentWrapper,53observation,54reward,55done,56truncated,57info,58action_metadata,59) -> None:60raise NotImplementedError6162def parameters_as_string(self) -> str:63return ""6465def all_parameters_as_string(self) -> str:66return ""6768def loss_as_string(self) -> str:69return ""7071def stateaction_as_string(self, action_metadata) -> str:72return ""737475class RandomPolicy(Learner):76"""A policy that does not learn and only explore"""7778def explore(self, wrapped_env: AgentWrapper) -> Tuple[str, cyberbattle_env.Action, object]:79gym_action = wrapped_env.env.sample_valid_action()80return "explore", gym_action, None8182def exploit(self, wrapped_env: AgentWrapper, observation) -> Tuple[str, Optional[cyberbattle_env.Action], object]:83raise NotImplementedError8485def on_step(86self,87wrapped_env: AgentWrapper,88observation,89reward,90done,91truncated,92info,93action_metadata,94):95return None969798Breakdown = TypedDict("Breakdown", {"local": int, "remote": int, "connect": int})99100Outcomes = TypedDict("Outcomes", {"reward": Breakdown, "noreward": Breakdown})101102Stats = TypedDict(103"Stats",104{"exploit": Outcomes, "explore": Outcomes, "exploit_deflected_to_explore": int},105)106107TrainedLearner = TypedDict(108"TrainedLearner",109{110"all_episodes_rewards": List[List[float]],111"all_episodes_availability": List[List[float]],112"learner": Learner,113"trained_on": str,114"title": str,115},116)117118119def print_stats(stats):120"""Print learning statistics"""121122def print_breakdown(stats, actiontype: str):123def ratio(kind: str) -> str:124x, y = (125stats[actiontype]["reward"][kind],126stats[actiontype]["noreward"][kind],127)128sum = x + y129if sum == 0:130return "NaN"131else:132return f"{(x / sum):.2f}"133134def print_kind(kind: str):135print(f" {actiontype}-{kind}: {stats[actiontype]['reward'][kind]}/{stats[actiontype]['noreward'][kind]} " f"({ratio(kind)})")136137print_kind("local")138print_kind("remote")139print_kind("connect")140141print(" Breakdown [Reward/NoReward (Success rate)]")142print_breakdown(stats, "explore")143print_breakdown(stats, "exploit")144print(f" exploit deflected to exploration: {stats['exploit_deflected_to_explore']}")145146147def epsilon_greedy_search(148cyberbattle_gym_env: cyberbattle_env.CyberBattleEnv,149environment_properties: EnvironmentBounds,150learner: Learner,151title: str,152episode_count: int,153iteration_count: int,154epsilon: float,155epsilon_minimum=0.0,156epsilon_multdecay: Optional[float] = None,157epsilon_exponential_decay: Optional[int] = None,158render=True,159render_last_episode_rewards_to: Optional[str] = None,160verbosity: Verbosity = Verbosity.Normal,161plot_episodes_length=True,162) -> TrainedLearner:163"""Epsilon greedy search for CyberBattle gym environments164165Parameters166==========167168- cyberbattle_gym_env -- the CyberBattle environment to train on169170- learner --- the policy learner/exploiter171172- episode_count -- Number of training episodes173174- iteration_count -- Maximum number of iterations in each episode175176- epsilon -- explore vs exploit177- 0.0 to exploit the learnt policy only without exploration178- 1.0 to explore purely randomly179180- epsilon_minimum -- epsilon decay clipped at this value.181Setting this value too close to 0 may leed the search to get stuck.182183- epsilon_decay -- epsilon gets multiplied by this value after each episode184185- epsilon_exponential_decay - if set use exponential decay. The bigger the value186is, the slower it takes to get from the initial `epsilon` to `epsilon_minimum`.187188- verbosity -- verbosity of the `print` logging189190- render -- render the environment interactively after each episode191192- render_last_episode_rewards_to -- render the environment to the specified file path193with an index appended to it each time there is a positive reward194for the last episode only195196- plot_episodes_length -- Plot the graph showing total number of steps by episode197at th end of the search.198199Note on convergence200===================201202Setting 'minimum_espilon' to 0 with an exponential decay <1203makes the learning converge quickly (loss function getting to 0),204but that's just a forced convergence, however, since when205epsilon approaches 0, only the q-values that were explored so206far get updated and so only that subset of cells from207the Q-matrix converges.208209"""210211print(212f"###### {title}\n"213f"Learning with: episode_count={episode_count},"214f"iteration_count={iteration_count},"215f"ϵ={epsilon},"216f"ϵ_min={epsilon_minimum}, "217+ (f"ϵ_multdecay={epsilon_multdecay}," if epsilon_multdecay else "")218+ (f"ϵ_expdecay={epsilon_exponential_decay}," if epsilon_exponential_decay else "")219+ f"{learner.parameters_as_string()}"220)221222initial_epsilon = epsilon223224all_episodes_rewards = []225all_episodes_availability = []226227o, _ = cyberbattle_gym_env.reset()228wrapped_env = AgentWrapper(229cyberbattle_gym_env,230ActionTrackingStateAugmentation(environment_properties, o),231)232steps_done = 0233plot_title = (234f"{title} (epochs={episode_count}, ϵ={initial_epsilon}, ϵ_min={epsilon_minimum},"235+ (f"ϵ_multdecay={epsilon_multdecay}," if epsilon_multdecay else "")236+ (f"ϵ_expdecay={epsilon_exponential_decay}," if epsilon_exponential_decay else "")237+ learner.parameters_as_string()238)239plottraining = PlotTraining(title=plot_title, render_each_episode=render)240241render_file_index = 1242243for i_episode in range(1, episode_count + 1):244print(f" ## Episode: {i_episode}/{episode_count} '{title}' " f"ϵ={epsilon:.4f}, " f"{learner.parameters_as_string()}")245246observation, _ = wrapped_env.reset()247total_reward = 0.0248all_rewards = []249all_availability = []250learner.new_episode()251252stats = Stats(253exploit=Outcomes(254reward=Breakdown(local=0, remote=0, connect=0),255noreward=Breakdown(local=0, remote=0, connect=0),256),257explore=Outcomes(258reward=Breakdown(local=0, remote=0, connect=0),259noreward=Breakdown(local=0, remote=0, connect=0),260),261exploit_deflected_to_explore=0,262)263264episode_ended_at = None265sys.stdout.flush()266267bar = progressbar.ProgressBar(268widgets=[269"Episode ",270f"{i_episode}",271"|Iteration ",272progressbar.Counter(),273"|",274progressbar.Variable(name="reward", width=6, precision=10),275"|",276progressbar.Variable(name="last_reward_at", width=4),277"|",278progressbar.Timer(),279progressbar.Bar(),280],281redirect_stdout=False,282)283284for t in bar(range(1, 1 + iteration_count)):285if epsilon_exponential_decay:286epsilon = epsilon_minimum + math.exp(-1.0 * steps_done / epsilon_exponential_decay) * (initial_epsilon - epsilon_minimum)287288steps_done += 1289290x = np.random.rand()291if x <= epsilon:292action_style, gym_action, action_metadata = learner.explore(wrapped_env)293else:294action_style, gym_action, action_metadata = learner.exploit(wrapped_env, observation)295if not gym_action:296stats["exploit_deflected_to_explore"] += 1297_, gym_action, action_metadata = learner.explore(wrapped_env)298299# Take the step300logging.debug(f"gym_action={gym_action}, action_metadata={action_metadata}")301observation, reward, done, truncated, info = wrapped_env.step(gym_action)302303action_type = "exploit" if action_style == "exploit" else "explore"304outcome = "reward" if reward > 0 else "noreward"305if "local_vulnerability" in gym_action:306stats[action_type][outcome]["local"] += 1307elif "remote_vulnerability" in gym_action:308stats[action_type][outcome]["remote"] += 1309else:310stats[action_type][outcome]["connect"] += 1311312learner.on_step(wrapped_env, observation, reward, done, truncated, info, action_metadata)313assert np.shape(reward) == ()314315all_rewards.append(reward)316all_availability.append(info["network_availability"])317total_reward += reward318bar.update(t, reward=total_reward)319if reward > 0:320bar.update(t, last_reward_at=t)321322if verbosity == Verbosity.Verbose or (verbosity == Verbosity.Normal and reward > 0):323sign = ["-", "+"][reward > 0]324325print(326f" {sign} t={t} {action_style} r={reward} cum_reward:{total_reward} "327f"a={action_metadata}-{gym_action} "328f"creds={len(observation['credential_cache_matrix'])} "329f" {learner.stateaction_as_string(action_metadata)}"330)331332if i_episode == episode_count and render_last_episode_rewards_to is not None and reward > 0:333fig = cyberbattle_gym_env.render_as_fig()334fig.write_image(f"{render_last_episode_rewards_to}-e{i_episode}-{render_file_index}.png")335render_file_index += 1336337learner.end_of_iteration(t, done)338339if done:340episode_ended_at = t341bar.finish(dirty=True)342break343344sys.stdout.flush()345346loss_string = learner.loss_as_string()347if loss_string:348loss_string = "loss={loss_string}"349350if episode_ended_at:351print(f" Episode {i_episode} ended at t={episode_ended_at} {loss_string}")352else:353print(f" Episode {i_episode} stopped at t={iteration_count} {loss_string}")354355print_stats(stats)356357all_episodes_rewards.append(all_rewards)358all_episodes_availability.append(all_availability)359360length = episode_ended_at if episode_ended_at else iteration_count361learner.end_of_episode(i_episode=i_episode, t=length)362if plot_episodes_length:363plottraining.episode_done(length)364if render:365wrapped_env.render()366367if epsilon_multdecay:368epsilon = max(epsilon_minimum, epsilon * epsilon_multdecay)369370wrapped_env.close()371print("simulation ended")372if plot_episodes_length:373plottraining.plot_end()374375return TrainedLearner(376all_episodes_rewards=all_episodes_rewards,377all_episodes_availability=all_episodes_availability,378learner=learner,379trained_on=cyberbattle_gym_env.name,380title=plot_title,381)382383384def transfer_learning_evaluation(385environment_properties: EnvironmentBounds,386trained_learner: TrainedLearner,387eval_env: cyberbattle_env.CyberBattleEnv,388eval_epsilon: float,389eval_episode_count: int,390iteration_count: int,391benchmark_policy: Learner = RandomPolicy(),392benchmark_training_args=dict(title="Benchmark", epsilon=1.0),393):394"""Evaluated a trained agent on another environment of different size"""395396eval_oneshot_all = epsilon_greedy_search(397eval_env,398environment_properties,399learner=trained_learner["learner"],400episode_count=eval_episode_count, # one shot from learnt Q matric401iteration_count=iteration_count,402epsilon=eval_epsilon,403render=False,404verbosity=Verbosity.Quiet,405title=f"One shot on {eval_env.name} - Trained on {trained_learner['trained_on']}",406)407408eval_random = epsilon_greedy_search(409eval_env,410environment_properties,411learner=benchmark_policy,412episode_count=eval_episode_count,413iteration_count=iteration_count,414render=False,415verbosity=Verbosity.Quiet,416**benchmark_training_args,417)418419plot_averaged_cummulative_rewards(420all_runs=[eval_oneshot_all, eval_random],421title=f"Transfer learning {trained_learner['trained_on']}->{eval_env.name} "422f'-- max_nodes={environment_properties.maximum_node_count}, '423f'episodes={eval_episode_count},\n'424f"{trained_learner['learner'].all_parameters_as_string()}",425)426427428