# Path: blob/main/cyberbattle/agents/baseline/baseline_test.py
# (scraped page residue: 597 views)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# -*- coding: utf-8 -*-
"""Test training of baseline agents."""

import torch
import gymnasium as gym
import logging
import sys
import cyberbattle._env.cyberbattle_env as cyberbattle_env
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
import cyberbattle.agents.baseline.agent_dql as dqla
import cyberbattle.agents.baseline.agent_wrapper as w
import cyberbattle.agents.baseline.learner as learner
import cyberbattle.agents.baseline.agent_tabularqlearning as tqa
from typing import cast

# Only errors are logged, so test output stays quiet.
logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")

print(f"torch cuda available={torch.cuda.is_available()}")

# One chain environment of size 4, shared by every test in this module.
# The attacker goal requires owning 100% of the nodes, with reward 100.
cyberbattlechain = cast(
    cyberbattle_env.CyberBattleEnv,
    gym.make(
        "CyberBattleChain-v0",
        size=4,
        attacker_goal=cyberbattle_env.AttackerGoal(own_atleast_percent=1.0, reward=100),
    ),
)

# Observation/action space bounds derived from the environment's identifiers.
ep = w.EnvironmentBounds.of_identifiers(
    maximum_total_credentials=10,
    maximum_node_count=10,
    identifiers=cyberbattlechain.identifiers,
)

# %% {"tags": ["parameters"]}
# Small episode and iteration counts keep these tests short.
training_episode_count = 2
iteration_count = 5


def test_agent_training() -> None:
    """Smoke-test Deep Q-learning training, then a purely random baseline, on the chain env.

    Each run only has to produce a truthy result from
    ``learner.epsilon_greedy_search``.
    """
    dqn_learning_run = learner.epsilon_greedy_search(
        cyberbattle_gym_env=cyberbattlechain,
        environment_properties=ep,
        learner=dqla.DeepQLearnerPolicy(
            ep=ep,
            gamma=0.015,
            replay_memory_size=10000,
            target_update=10,
            batch_size=512,
            learning_rate=0.01,  # torch default is 1e-2
        ),
        episode_count=training_episode_count,
        iteration_count=iteration_count,
        epsilon=0.90,
        render=False,
        epsilon_exponential_decay=5000,
        epsilon_minimum=0.10,
        verbosity=Verbosity.Quiet,
        # Keep test runs headless, as in the tabular Q-learning test.
        plot_episodes_length=False,
        title="DQL",
    )
    assert dqn_learning_run

    random_run = learner.epsilon_greedy_search(
        cyberbattle_gym_env=cyberbattlechain,
        environment_properties=ep,
        learner=learner.RandomPolicy(),
        episode_count=training_episode_count,
        iteration_count=iteration_count,
        epsilon=1.0,  # purely random
        render=False,
        verbosity=Verbosity.Quiet,
        plot_episodes_length=False,
        title="Random search",
    )
    assert random_run


def test_tabularq_agent_training() -> None:
    """Smoke-test tabular Q-learning training on the shared chain environment.

    The run only has to produce a truthy result from
    ``learner.epsilon_greedy_search``.
    """
    tabularq_run = learner.epsilon_greedy_search(
        cyberbattle_gym_env=cyberbattlechain,
        environment_properties=ep,
        learner=tqa.QTabularLearner(ep, gamma=0.015, learning_rate=0.01, exploit_percentile=100),
        episode_count=training_episode_count,
        iteration_count=iteration_count,
        epsilon=0.90,
        epsilon_exponential_decay=5000,
        epsilon_minimum=0.01,
        verbosity=Verbosity.Quiet,
        render=False,
        plot_episodes_length=False,
        title="Tabular Q-learning",
    )
    assert tabularq_run