Path: blob/main/notebooks/notebook_dql_transfer.py
# ---
# jupyter:
#   jupytext:
#     cell_metadata_filter: title,-all
#     formats: py:percent,ipynb
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.16.4
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---

# %%
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# -*- coding: utf-8 -*-
# %%
"""Notebook demonstrating the transfer learning capability of the
Deep Q-learning agent trained and evaluated on chain
environments of various sizes.

NOTE: You can run this `.py`-notebook directly from VSCode.
You can also generate a traditional Jupyter Notebook
using the VSCode command `Export Current Python File As Jupyter Notebook`.
"""

# %%
import os
import sys
import logging
import gymnasium as gym
import torch

import cyberbattle.agents.baseline.learner as learner
import cyberbattle.agents.baseline.plotting as p
import cyberbattle.agents.baseline.agent_wrapper as w
import cyberbattle.agents.baseline.agent_dql as dqla
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
import cyberbattle.agents.baseline.agent_randomcredlookup as rca
import importlib
import cyberbattle._env.cyberbattle_env as cyberbattle_env
import cyberbattle._env.cyberbattle_chain as cyberbattle_chain

importlib.reload(learner)
importlib.reload(cyberbattle_env)
importlib.reload(cyberbattle_chain)

logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")

# %matplotlib inline

# %%
torch.cuda.is_available()

# %%
# To run once
# import plotly.io as pio
# pio.orca.config.use_xvfb = True
# pio.orca.config.save()

# %%
# Create chain environments of size 4, 10 and 20
cyberbattlechain_4 = gym.make("CyberBattleChain-v0", size=4, attacker_goal=cyberbattle_env.AttackerGoal(own_atleast_percent=1.0)).unwrapped
cyberbattlechain_10 = gym.make("CyberBattleChain-v0", size=10, attacker_goal=cyberbattle_env.AttackerGoal(own_atleast_percent=1.0)).unwrapped
cyberbattlechain_20 = gym.make("CyberBattleChain-v0", size=20, attacker_goal=cyberbattle_env.AttackerGoal(own_atleast_percent=1.0)).unwrapped

assert isinstance(cyberbattlechain_4, cyberbattle_env.CyberBattleEnv)
assert isinstance(cyberbattlechain_10, cyberbattle_env.CyberBattleEnv)
assert isinstance(cyberbattlechain_20, cyberbattle_env.CyberBattleEnv)

# Environment bounds shared by all three environments so the same agent can be transferred between them
ep = w.EnvironmentBounds.of_identifiers(maximum_total_credentials=22, maximum_node_count=22, identifiers=cyberbattlechain_10.identifiers)

# %% {"tags": ["parameters"]}
iteration_count = 9000
training_episode_count = 50
eval_episode_count = 10
plots_dir = "output/images"

# %%
os.makedirs(plots_dir, exist_ok=True)

# %%
# Run Deep Q-learning on the size-10 chain
best_dqn_learning_run_10 = learner.epsilon_greedy_search(
    cyberbattle_gym_env=cyberbattlechain_10,
    environment_properties=ep,
    learner=dqla.DeepQLearnerPolicy(ep=ep, gamma=0.015, replay_memory_size=10000, target_update=10, batch_size=512, learning_rate=0.01),  # torch default is 1e-2
    episode_count=training_episode_count,
    iteration_count=iteration_count,
    epsilon=0.90,
    render=False,
    # epsilon_multdecay=0.75,  # 0.999,
    epsilon_exponential_decay=5000,  # 10000
    epsilon_minimum=0.10,
    verbosity=Verbosity.Quiet,
    title="DQL",
)

# %%
# Plot episode length
p.plot_episodes_length([best_dqn_learning_run_10])

# %%
if not os.path.exists("images"):
    os.mkdir("images")

# %%
# Evaluate the trained agent on the same size-10 chain, exploitation only (epsilon=0)
dql_exploit_run = learner.epsilon_greedy_search(
    cyberbattlechain_10,
    ep,
    learner=best_dqn_learning_run_10["learner"],
    episode_count=eval_episode_count,
    iteration_count=iteration_count,
    epsilon=0.0,  # 0.35,
    render=False,
    render_last_episode_rewards_to=os.path.join(plots_dir, "dql_transfer-chain10"),
    title="Exploiting DQL",
    verbosity=Verbosity.Quiet,
)


# %%
# Purely random baseline on the same environment
random_run = learner.epsilon_greedy_search(
    cyberbattlechain_10,
    ep,
    learner=learner.RandomPolicy(),
    episode_count=eval_episode_count,
    iteration_count=iteration_count,
    epsilon=1.0,  # purely random
    render=False,
    verbosity=Verbosity.Quiet,
    title="Random search",
)

# %%
# Plot averaged cumulative rewards for DQL vs Random vs DQL-Exploit
themodel = dqla.CyberBattleStateActionModel(ep)
p.plot_averaged_cummulative_rewards(
    all_runs=[best_dqn_learning_run_10, random_run, dql_exploit_run],
    title=f"Benchmark -- max_nodes={ep.maximum_node_count}, episodes={eval_episode_count},\n"
    f"State: {[f.name() for f in themodel.state_space.feature_selection]} "
    f"({len(themodel.state_space.feature_selection)})\n"
    f"Action: abstract_action ({themodel.action_space.flat_size()})",
)


# %%
# Plot cumulative rewards for all episodes
p.plot_all_episodes(best_dqn_learning_run_10)


##################################################
# %%
# Train a second DQL agent, this time on the smaller size-4 chain
best_dqn_4 = learner.epsilon_greedy_search(
    cyberbattle_gym_env=cyberbattlechain_4,
    environment_properties=ep,
    learner=dqla.DeepQLearnerPolicy(ep=ep, gamma=0.15, replay_memory_size=10000, target_update=5, batch_size=256, learning_rate=0.01),
    episode_count=training_episode_count,
    iteration_count=iteration_count,
    epsilon=0.90,
    render=False,
    epsilon_exponential_decay=5000,
    epsilon_minimum=0.10,
    verbosity=Verbosity.Quiet,
    title="DQL",
)


# %%
# Transfer learning evaluation: agent trained on chain-10, evaluated on chain-20
learner.transfer_learning_evaluation(
    environment_properties=ep,
    trained_learner=best_dqn_learning_run_10,
    eval_env=cyberbattlechain_20,
    eval_epsilon=0.0,  # alternate with exploration to help generalization to bigger network
    eval_episode_count=eval_episode_count,
    iteration_count=iteration_count,
    benchmark_policy=rca.CredentialCacheExploiter(),
    benchmark_training_args={"epsilon": 0.90, "epsilon_exponential_decay": 10000, "epsilon_minimum": 0.10, "title": "Credential lookups (ϵ-greedy)"},
)

# Transfer learning evaluation: agent trained on chain-4, evaluated on chain-10
learner.transfer_learning_evaluation(
    environment_properties=ep,
    trained_learner=best_dqn_4,
    eval_env=cyberbattlechain_10,
    eval_epsilon=0.0,  # exploit Q-matrix only
    eval_episode_count=eval_episode_count,
    iteration_count=iteration_count,
    benchmark_policy=rca.CredentialCacheExploiter(),
    benchmark_training_args={"epsilon": 0.90, "epsilon_exponential_decay": 10000, "epsilon_minimum": 0.10, "title": "Credential lookups (ϵ-greedy)"},
)

# %%
# Transfer learning evaluation: agent trained on chain-4, evaluated on chain-20
learner.transfer_learning_evaluation(
    environment_properties=ep,
    trained_learner=best_dqn_4,
    eval_env=cyberbattlechain_20,
    eval_epsilon=0.0,  # exploit Q-matrix only
    eval_episode_count=eval_episode_count,
    iteration_count=iteration_count,
    benchmark_policy=rca.CredentialCacheExploiter(),
    benchmark_training_args={"epsilon": 0.90, "epsilon_exponential_decay": 10000, "epsilon_minimum": 0.10, "title": "Credential lookups (ϵ-greedy)"},
)
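
# %%
# Optional: besides the VSCode export command mentioned in the module docstring,
# a paired .ipynb can be regenerated with the jupytext package declared in the
# header above. A minimal sketch, assuming it is run from the notebooks/ folder
# and that jupytext is installed; the output file name is illustrative.
# import jupytext
# nb = jupytext.read("notebook_dql_transfer.py")
# jupytext.write(nb, "notebook_dql_transfer.ipynb")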