Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
microsoft
GitHub Repository: microsoft/CyberBattleSim
Path: blob/main/cyberbattle/agents/baseline/baseline_test.py
597 views
1
# Copyright (c) Microsoft Corporation.
2
# Licensed under the MIT License.
3
4
# -*- coding: utf-8 -*-
5
"""Test training of baseline agents."""
6
7
import torch
8
import gymnasium as gym
9
import logging
10
import sys
11
import cyberbattle._env.cyberbattle_env as cyberbattle_env
12
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
13
import cyberbattle.agents.baseline.agent_dql as dqla
14
import cyberbattle.agents.baseline.agent_wrapper as w
15
import cyberbattle.agents.baseline.learner as learner
16
import cyberbattle.agents.baseline.agent_tabularqlearning as tqa
17
from typing import cast
18
19
# Route log output to stdout; only errors are surfaced while the tests run.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.ERROR,
    format="%(levelname)s: %(message)s",
)

print(f"torch cuda available={torch.cuda.is_available()}")

# Small chain environment: the attacker wins by owning 100% of the nodes,
# collecting a terminal reward of 100.
cyberbattlechain = cast(
    cyberbattle_env.CyberBattleEnv,
    gym.make(
        "CyberBattleChain-v0",
        size=4,
        attacker_goal=cyberbattle_env.AttackerGoal(own_atleast_percent=1.0, reward=100),
    ),
)

# Bounds the agents use to size their observation/action spaces.
ep = w.EnvironmentBounds.of_identifiers(
    maximum_total_credentials=10,
    maximum_node_count=10,
    identifiers=cyberbattlechain.identifiers,
)

# %% {"tags": ["parameters"]}
# Deliberately tiny budgets so the test suite stays fast.
training_episode_count = 2
iteration_count = 5
31
32
def test_agent_training() -> None:
    """Smoke-test training: a DQL agent and a random baseline on the chain env."""
    # Deep Q-learning policy with a small replay buffer and frequent target sync.
    dql_policy = dqla.DeepQLearnerPolicy(
        ep=ep,
        gamma=0.015,
        replay_memory_size=10000,
        target_update=10,
        batch_size=512,
        learning_rate=0.01,  # torch default is 1e-2
    )
    dqn_learning_run = learner.epsilon_greedy_search(
        cyberbattle_gym_env=cyberbattlechain,
        environment_properties=ep,
        learner=dql_policy,
        episode_count=training_episode_count,
        iteration_count=iteration_count,
        epsilon=0.90,
        render=False,
        # epsilon_multdecay=0.75,  # 0.999,
        epsilon_exponential_decay=5000,  # 10000
        epsilon_minimum=0.10,
        verbosity=Verbosity.Quiet,
        title="DQL",
    )
    assert dqn_learning_run

    # Purely random action selection (epsilon fixed at 1.0) as a sanity baseline.
    random_run = learner.epsilon_greedy_search(
        cyberbattlechain,
        ep,
        learner=learner.RandomPolicy(),
        episode_count=training_episode_count,
        iteration_count=iteration_count,
        epsilon=1.0,  # purely random
        render=False,
        verbosity=Verbosity.Quiet,
        title="Random search",
    )
    assert random_run
63
64
def test_tabularq_agent_training() -> None:
    """Smoke-test training of the tabular Q-learning agent on the chain env."""
    # Tabular Q-learner; exploit_percentile=100 means always pick the best Q value.
    q_policy = tqa.QTabularLearner(
        ep,
        gamma=0.015,
        learning_rate=0.01,
        exploit_percentile=100,
    )
    tabularq_run = learner.epsilon_greedy_search(
        cyberbattlechain,
        ep,
        learner=q_policy,
        episode_count=training_episode_count,
        iteration_count=iteration_count,
        epsilon=0.90,
        epsilon_exponential_decay=5000,
        epsilon_minimum=0.01,
        verbosity=Verbosity.Quiet,
        render=False,
        plot_episodes_length=False,
        title="Tabular Q-learning",
    )
    assert tabularq_run