Path: blob/main/notebooks/notebook_benchmark.py
# ---
# jupyter:
#   jupytext:
#     cell_metadata_filter: tags,-all
#     cell_metadata_json: true
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.16.4
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---

# %% {"tags": []}
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Benchmark all the baseline agents
on a given CyberBattleSim environment and compare
them to the dumb 'random agent' baseline.

NOTE: You can run this `.py`-notebook directly from VSCode.
You can also generate a traditional Jupyter Notebook
using the VSCode command `Export Current Python File As Jupyter Notebook`.
"""

# pylint: disable=invalid-name

# %% {"tags": []}
import sys
import os
import logging
from typing import cast
import gymnasium as gym
import cyberbattle.agents.baseline.learner as learner
import cyberbattle.agents.baseline.plotting as p
import cyberbattle.agents.baseline.agent_wrapper as w
import cyberbattle.agents.baseline.agent_randomcredlookup as rca
import cyberbattle.agents.baseline.agent_tabularqlearning as tqa
import cyberbattle.agents.baseline.agent_dql as dqla
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
from cyberbattle._env.cyberbattle_env import CyberBattleEnv

logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")
# %% {"tags": []}
# %matplotlib inline
# %% {"tags": ["parameters"]}
# Papermill notebook parameters
gymid = "CyberBattleChain-v0"
env_size = 10
iteration_count = 9000
training_episode_count = 50
eval_episode_count = 5
maximum_node_count = 22
maximum_total_credentials = 22
plots_dir = "output/plots"

# %% {"tags": []}
os.makedirs(plots_dir, exist_ok=True)

# Load the Gym environment
if env_size:
    _gym_env = gym.make(gymid, size=env_size)
else:
    _gym_env = gym.make(gymid)

gym_env = cast(CyberBattleEnv, _gym_env.unwrapped)
assert isinstance(gym_env, CyberBattleEnv), f"Expected CyberBattleEnv, got {type(gym_env)}"

ep = w.EnvironmentBounds.of_identifiers(maximum_node_count=maximum_node_count, maximum_total_credentials=maximum_total_credentials, identifiers=gym_env.identifiers)

# %% {"tags": []}
debugging = False
if debugging:
    print(f"port_count = {ep.port_count}, property_count = {ep.property_count}")

    gym_env.environment
    # training_env.environment.plot_environment_graph()
    gym_env.environment.network.nodes
    gym_env.action_space
    gym_env.action_space.sample()
    gym_env.observation_space.sample()
    o0, _ = gym_env.reset()
    o_test, r, d, t, i = gym_env.step(gym_env.sample_valid_action())
    o0, _ = gym_env.reset()

    o0.keys()

    fe_example = w.RavelEncoding(ep, [w.Feature_active_node_properties(ep), w.Feature_discovered_node_count(ep)])
    a = w.StateAugmentation(o0)
    w.Feature_discovered_ports(ep).get(a)
    fe_example.encode_at(a, 0)

# %% {"tags": []}
# Evaluate a random agent that opportunistically exploits
# credentials gathered in its local cache
credlookup_run = learner.epsilon_greedy_search(
    gym_env,
    ep,
    learner=rca.CredentialCacheExploiter(),
    episode_count=10,
    iteration_count=iteration_count,
    epsilon=0.90,
    render=False,
    epsilon_exponential_decay=10000,
    epsilon_minimum=0.10,
    verbosity=Verbosity.Quiet,
    title="Credential lookups (ϵ-greedy)",
)
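
# %% {"tags": []}
# For reference, a minimal sketch of the kind of exponential epsilon decay that the
# `epsilon`, `epsilon_exponential_decay` and `epsilon_minimum` arguments above configure.
# The exact schedule is an assumption here (the authoritative formula lives in
# `learner.epsilon_greedy_search`); this cell only illustrates the shape of the decay.
import math


def assumed_epsilon_schedule(step: int, initial: float = 0.90, minimum: float = 0.10, decay: float = 10000.0) -> float:
    """Assumed exploration rate after `step` iterations: exponential decay from `initial` towards `minimum`."""
    return minimum + (initial - minimum) * math.exp(-step / decay)


if debugging:
    # Print a few sample points of the assumed schedule
    for s in (0, 1000, 5000, 10000, 50000):
        print(f"step={s:>6}  epsilon≈{assumed_epsilon_schedule(s):.3f}")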
# %% {"tags": []}
# Evaluate a Tabular Q-learning agent
tabularq_run = learner.epsilon_greedy_search(
    gym_env,
    ep,
    learner=tqa.QTabularLearner(ep, gamma=0.015, learning_rate=0.01, exploit_percentile=100),
    episode_count=training_episode_count,
    iteration_count=iteration_count,
    epsilon=0.90,
    epsilon_exponential_decay=5000,
    epsilon_minimum=0.01,
    verbosity=Verbosity.Quiet,
    render=False,
    plot_episodes_length=False,
    title="Tabular Q-learning",
)

# %% {"tags": []}
# Evaluate an agent that exploits the Q-table learnt above
tabularq_exploit_run = learner.epsilon_greedy_search(
    gym_env,
    ep,
    learner=tqa.QTabularLearner(ep, trained=tabularq_run["learner"], gamma=0.0, learning_rate=0.0, exploit_percentile=90),
    episode_count=eval_episode_count,
    iteration_count=iteration_count,
    epsilon=0.0,
    render=False,
    verbosity=Verbosity.Quiet,
    title="Exploiting Q-matrix",
)

# %% {"tags": []}
# Evaluate the Deep Q-learning agent
dql_run = learner.epsilon_greedy_search(
    cyberbattle_gym_env=gym_env,
    environment_properties=ep,
    learner=dqla.DeepQLearnerPolicy(
        ep=ep,
        gamma=0.015,
        replay_memory_size=10000,
        target_update=10,
        batch_size=512,
        # torch default learning rate is 1e-2;
        # a large value helps converge in fewer episodes
        learning_rate=0.01,
    ),
    episode_count=training_episode_count,
    iteration_count=iteration_count,
    epsilon=0.90,
    epsilon_exponential_decay=5000,
    epsilon_minimum=0.10,
    verbosity=Verbosity.Quiet,
    render=False,
    plot_episodes_length=False,
    title="DQL",
)

# %% {"tags": []}
# Evaluate an agent that exploits the Q-function learnt above
dql_exploit_run = learner.epsilon_greedy_search(
    gym_env,
    ep,
    learner=dql_run["learner"],
    episode_count=eval_episode_count,
    iteration_count=iteration_count,
    epsilon=0.0,
    epsilon_minimum=0.00,
    render=False,
    plot_episodes_length=False,
    verbosity=Verbosity.Quiet,
    title="Exploiting DQL",
)


# %% {"tags": []}
# Evaluate the random agent
random_run = learner.epsilon_greedy_search(
    gym_env,
    ep,
    learner=learner.RandomPolicy(),
    episode_count=eval_episode_count,
    iteration_count=iteration_count,
    epsilon=1.0,  # purely random
    render=False,
    verbosity=Verbosity.Quiet,
    plot_episodes_length=False,
    title="Random search",
)

# %% {"tags": []}
# Compare and plot results for all the agents
all_runs = [random_run, credlookup_run, tabularq_run, tabularq_exploit_run, dql_run, dql_exploit_run]

# Plot averaged cumulative rewards for all the runs
themodel = dqla.CyberBattleStateActionModel(ep)
p.plot_averaged_cummulative_rewards(
    all_runs=all_runs,
    title=f"Benchmark -- max_nodes={ep.maximum_node_count}, episodes={eval_episode_count},\n"
    f"State: {[f.name() for f in themodel.state_space.feature_selection]} "
    f"({len(themodel.state_space.feature_selection)})\n"
    f"Action: abstract_action ({themodel.action_space.flat_size()})",
    save_at=os.path.join(plots_dir, f"benchmark-{gymid}-cumrewards.png"),
)

# %% {"tags": []}
contenders = [credlookup_run, tabularq_run, dql_run, dql_exploit_run]
p.plot_episodes_length(contenders)
p.plot_averaged_cummulative_rewards(
    title=f"Agent Benchmark top contenders\nmax_nodes:{ep.maximum_node_count}\n",
    all_runs=contenders,
    save_at=os.path.join(plots_dir, f"benchmark-{gymid}-cumreward_contenders.png"),
)


# %% {"tags": []}
# Plot cumulative rewards for all episodes
for r in contenders:
    p.plot_all_episodes(r)
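
# %% {"tags": []}
# The cell tagged "parameters" near the top of this notebook is what makes it
# parameterizable with papermill. An illustrative invocation (the paired `.ipynb`
# filename is an assumption based on the `formats: ipynb,py:percent` pairing
# declared in the jupytext header):
#
#   papermill notebook_benchmark.ipynb benchmark_output.ipynb \
#       -p gymid "CyberBattleChain-v0" -p env_size 4 -p iteration_count 1000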