Path: blob/main/cyberbattle/agents/baseline/run.py
#!/usr/bin/python3.10
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# -*- coding: utf-8 -*-
"""CLI to run the baseline Deep Q-learning and Random agents
on a sample CyberBattle gym environment and plot the respective
cumulative rewards in the terminal.

Example usage:

    python -m run --training_episode_count 50 --iteration_count 9000 --rewardplot_width 80 --chain_size=20 --ownership_goal 1.0

"""

import torch
import gymnasium as gym
import logging
import sys
import asciichartpy
import argparse
import cyberbattle._env.cyberbattle_env as cyberbattle_env
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
import cyberbattle.agents.baseline.agent_dql as dqla
import cyberbattle.agents.baseline.agent_wrapper as w
import cyberbattle.agents.baseline.plotting as p
import cyberbattle.agents.baseline.learner as learner
from typing import cast
from cyberbattle._env.cyberbattle_env import CyberBattleEnv

# Command-line options
parser = argparse.ArgumentParser(description="Run simulation with DQL baseline agent.")

parser.add_argument("--training_episode_count", default=50, type=int, help="number of training epochs")

parser.add_argument("--eval_episode_count", default=10, type=int, help="number of evaluation epochs")

parser.add_argument("--iteration_count", default=9000, type=int, help="number of simulation iterations for each epoch")

parser.add_argument("--reward_goal", default=2180, type=int, help="minimum target rewards to reach for the attacker to reach its goal")

parser.add_argument("--ownership_goal", default=1.0, type=float, help="percentage of network nodes to own for the attacker to reach its goal")

parser.add_argument("--rewardplot_width", default=80, type=int, help="width of the reward plot (values are averaged across iterations to fit in the desired width)")

parser.add_argument("--chain_size", default=4, type=int, help="size of the chain of the CyberBattleChain sample environment")

parser.add_argument("--random_agent", dest="run_random_agent", action="store_true", help="run the random agent as a baseline for comparison")
parser.add_argument("--no-random_agent", dest="run_random_agent", action="store_false", help="do not run the random agent as a baseline for comparison")
parser.set_defaults(run_random_agent=True)

args = parser.parse_args()

logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")

print(f"torch cuda available={torch.cuda.is_available()}")

# Instantiate the chain sample environment with the attacker goal taken from the CLI options
cyberbattlechain = cast(
    CyberBattleEnv, gym.make("CyberBattleChain-v0", size=args.chain_size, attacker_goal=cyberbattle_env.AttackerGoal(own_atleast_percent=args.ownership_goal, reward=args.reward_goal))
)

ep = w.EnvironmentBounds.of_identifiers(maximum_total_credentials=22, maximum_node_count=22, identifiers=cyberbattlechain.identifiers)

all_runs = []

# Run Deep Q-learning
dqn_learning_run = learner.epsilon_greedy_search(
    cyberbattle_gym_env=cyberbattlechain,
    environment_properties=ep,
    learner=dqla.DeepQLearnerPolicy(ep=ep, gamma=0.015, replay_memory_size=10000, target_update=10, batch_size=512, learning_rate=0.01),  # torch default is 1e-2
    episode_count=args.training_episode_count,
    iteration_count=args.iteration_count,
    epsilon=0.90,
    render=True,
    # epsilon_multdecay=0.75,  # 0.999,
    epsilon_exponential_decay=5000,  # 10000
    epsilon_minimum=0.10,
    verbosity=Verbosity.Quiet,
    title="DQL",
)

all_runs.append(dqn_learning_run)

# Optionally run the random agent as a baseline for comparison
if args.run_random_agent:
    random_run = learner.epsilon_greedy_search(
        cyberbattlechain,
        ep,
        learner=learner.RandomPolicy(),
        episode_count=args.eval_episode_count,
        iteration_count=args.iteration_count,
        epsilon=1.0,  # purely random
        render=False,
        verbosity=Verbosity.Quiet,
        title="Random search",
    )
    all_runs.append(random_run)

# Plot episode durations and cumulative rewards as ASCII charts in the terminal
colors = [asciichartpy.red, asciichartpy.green, asciichartpy.yellow, asciichartpy.blue]

print("Episode duration -- DQN=Red, Random=Green")
print(asciichartpy.plot(p.episodes_lengths_for_all_runs(all_runs), {"height": 30, "colors": colors}))

print("Cumulative rewards -- DQN=Red, Random=Green")
c = p.averaged_cummulative_rewards(all_runs, args.rewardplot_width)
print(asciichartpy.plot(c, {"height": 10, "colors": colors}))
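# ---------------------------------------------------------------------------
# Illustration only (not part of the original script): the DQL run above
# starts exploring with epsilon=0.90 and anneals toward epsilon_minimum=0.10
# under epsilon_exponential_decay=5000. The helper below is a minimal sketch
# of such an exponential schedule, assuming the decay parameter acts as a
# time constant in steps; the exact formula lives inside
# learner.epsilon_greedy_search and may differ. `assumed_epsilon_schedule` is
# a hypothetical name, not part of the cyberbattle API.
import math


def assumed_epsilon_schedule(step: int, epsilon_start: float = 0.90, epsilon_min: float = 0.10, decay: float = 5000.0) -> float:
    """Exploration rate after `step` global steps, decaying exponentially toward epsilon_min."""
    return epsilon_min + (epsilon_start - epsilon_min) * math.exp(-step / decay)


# Under this assumed schedule, exploration is still roughly 0.39 after 5000
# steps and close to the 0.10 floor after about 20000 steps, e.g.:
# print([round(assumed_epsilon_schedule(s), 2) for s in (0, 5000, 20000)])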