GitHub Repository: microsoft/CyberBattleSim
Path: blob/main/cyberbattle/agents/baseline/run.py
#!/usr/bin/python3.10

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# -*- coding: utf-8 -*-
"""CLI to run the baseline Deep Q-learning and Random agents
on a sample CyberBattle gym environment and plot the respective
cumulative rewards in the terminal.

Example usage:

    python -m run --training_episode_count 50 --iteration_count 9000 --rewardplot_width 80 --chain_size=20 --ownership_goal 1.0

"""
import torch
import gymnasium as gym
import logging
import sys
import asciichartpy
import argparse
import cyberbattle._env.cyberbattle_env as cyberbattle_env
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
import cyberbattle.agents.baseline.agent_dql as dqla
import cyberbattle.agents.baseline.agent_wrapper as w
import cyberbattle.agents.baseline.plotting as p
import cyberbattle.agents.baseline.learner as learner
from typing import cast
from cyberbattle._env.cyberbattle_env import CyberBattleEnv
parser = argparse.ArgumentParser(description="Run simulation with DQL baseline agent.")

parser.add_argument("--training_episode_count", default=50, type=int, help="number of training episodes")

parser.add_argument("--eval_episode_count", default=10, type=int, help="number of evaluation episodes")

parser.add_argument("--iteration_count", default=9000, type=int, help="maximum number of simulation iterations per episode")

parser.add_argument("--reward_goal", default=2180, type=int, help="minimum cumulative reward the attacker must collect to reach its goal")

parser.add_argument("--ownership_goal", default=1.0, type=float, help="fraction of network nodes to own (1.0 = all) for the attacker to reach its goal")

parser.add_argument("--rewardplot_width", default=80, type=int, help="width of the reward plot (values are averaged across iterations to fit in the desired width)")

parser.add_argument("--chain_size", default=4, type=int, help="size of the chain of the CyberBattleChain sample environment")

parser.add_argument("--random_agent", dest="run_random_agent", action="store_true", help="run the random agent as a baseline for comparison")
parser.add_argument("--no-random_agent", dest="run_random_agent", action="store_false", help="do not run the random agent as a baseline for comparison")
parser.set_defaults(run_random_agent=True)

args = parser.parse_args()
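
# The random baseline runs by default (run_random_agent defaults to True);
# pass --no-random_agent to train and evaluate the DQL agent only, for example:
#
#   python -m run --training_episode_count 20 --iteration_count 1000 --no-random_agent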

logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")

print(f"torch cuda available={torch.cuda.is_available()}")

cyberbattlechain = cast(
    CyberBattleEnv, gym.make("CyberBattleChain-v0", size=args.chain_size, attacker_goal=cyberbattle_env.AttackerGoal(own_atleast_percent=args.ownership_goal, reward=args.reward_goal))
)
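# The AttackerGoal built above defines the attacker's win condition from the
# command-line arguments: the fraction of nodes to own (--ownership_goal) and
# the target cumulative reward (--reward_goal).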

ep = w.EnvironmentBounds.of_identifiers(maximum_total_credentials=22, maximum_node_count=22, identifiers=cyberbattlechain.identifiers)
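# EnvironmentBounds caps the node and credential counts (and reuses the
# environment's identifiers) so the agent wrapper can build fixed-size
# observation and action encodings.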

all_runs = []

# Run Deep Q-learning
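# Hyperparameters below: gamma is the discount factor; replay_memory_size,
# batch_size and target_update configure experience replay and how often the
# target network is refreshed; epsilon starts at 0.90 and decays exponentially
# (epsilon_exponential_decay) down to epsilon_minimum=0.10.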
dqn_learning_run = learner.epsilon_greedy_search(
    cyberbattle_gym_env=cyberbattlechain,
    environment_properties=ep,
    learner=dqla.DeepQLearnerPolicy(ep=ep, gamma=0.015, replay_memory_size=10000, target_update=10, batch_size=512, learning_rate=0.01),  # torch default is 1e-2
    episode_count=args.training_episode_count,
    iteration_count=args.iteration_count,
    epsilon=0.90,
    render=True,
    # epsilon_multdecay=0.75,  # 0.999,
    epsilon_exponential_decay=5000,  # 10000
    epsilon_minimum=0.10,
    verbosity=Verbosity.Quiet,
    title="DQL",
)

all_runs.append(dqn_learning_run)

if args.run_random_agent:
    random_run = learner.epsilon_greedy_search(
        cyberbattlechain,
        ep,
        learner=learner.RandomPolicy(),
        episode_count=args.eval_episode_count,
        iteration_count=args.iteration_count,
        epsilon=1.0,  # purely random
        render=False,
        verbosity=Verbosity.Quiet,
        title="Random search",
    )
    all_runs.append(random_run)

colors = [asciichartpy.red, asciichartpy.green, asciichartpy.yellow, asciichartpy.blue]

print("Episode duration -- DQN=Red, Random=Green")
print(asciichartpy.plot(p.episodes_lengths_for_all_runs(all_runs), {"height": 30, "colors": colors}))

print("Cumulative rewards -- DQN=Red, Random=Green")
c = p.averaged_cummulative_rewards(all_runs, args.rewardplot_width)
print(asciichartpy.plot(c, {"height": 10, "colors": colors}))
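
# A minimal optional sketch: persist the averaged cumulative-reward series for
# offline analysis, assuming `c` holds numeric reward series (a single series
# or one per run) as consumed by asciichartpy.plot above; the output filename
# is an arbitrary choice.
import json

first = c[0] if len(c) else None
series_list = [list(s) for s in c] if hasattr(first, "__len__") else [list(c)]
with open("cumulative_rewards.json", "w") as f:
    json.dump([[float(x) for x in s] for s in series_list], f)
print(f"Saved {len(series_list)} averaged reward series to cumulative_rewards.json")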