# Source: cyberbattle/agents/baseline/test_stablebaseline3.py (main branch)
"""Smoke test for a Stable-Baselines3 A2C agent on the CyberBattle ToyCTF environment.

The ToyCTF simulation is wrapped with the flattening wrappers, validated with
``check_env``, trained for a few steps with A2C, round-tripped through
``save``/``load`` and finally driven for a short deterministic evaluation.
"""

import os
import tempfile
from typing import cast

from cyberbattle._env.cyberbattle_toyctf import CyberBattleToyCtf
from cyberbattle._env.flatten_wrapper import (
    FlattenObservationWrapper,
    FlattenActionWrapper,
)
from stable_baselines3.common.type_aliases import GymEnv
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.a2c.a2c import A2C

# Make CUDA kernel launches synchronous so device errors surface at the call
# that caused them (a debugging aid; presumably only matters when a GPU is used).
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


def test_stablebaseline3(training_steps=3, eval_steps=10):
    """Train, save, reload and briefly evaluate an A2C agent on ToyCTF.

    Args:
        training_steps: Number of timesteps passed to ``A2C.learn``.
        eval_steps: Number of environment steps to run with the reloaded
            model. The environment is reset whenever an episode terminates or
            is truncated, so any non-negative value is valid.
    """
    cyberbattle_env = CyberBattleToyCtf(
        maximum_node_count=12,
        maximum_total_credentials=10,
        observation_padding=True,
        throws_on_invalid_actions=False,
    )

    # Flatten the action space first, then the observation space, skipping the
    # listed observation fields.
    flatten_action_env = FlattenActionWrapper(cyberbattle_env)
    flatten_obs_env = FlattenObservationWrapper(
        flatten_action_env,
        ignore_fields=[
            "_credential_cache",
            "_discovered_nodes",
            "_explored_network",
        ],
    )

    # Same object; the cast only informs the type checker that it satisfies
    # Stable-Baselines3's GymEnv alias.
    env_as_gym = cast(GymEnv, flatten_obs_env)

    try:
        check_env(flatten_obs_env)

        model_a2c = A2C("MultiInputPolicy", env_as_gym).learn(training_steps)

        # Save into a throwaway directory so the test leaves no model archive
        # behind in the current working directory.
        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path = os.path.join(tmp_dir, "a2c_trained_toyctf")
            model_a2c.save(model_path)
            # ``load`` is a classmethod that builds a new model from the archive;
            # it must be called on the class, not on a freshly built instance.
            model = A2C.load(model_path, env=env_as_gym)

        obs, _ = env_as_gym.reset()
        for _step in range(eval_steps):
            assert isinstance(obs, dict)
            action, _states = model.predict(obs, deterministic=True)
            obs, _reward, terminated, truncated, _info = flatten_obs_env.step(action)
            # Stepping past the end of an episode is undefined under the Gym
            # API, so start a new episode instead.
            if terminated or truncated:
                obs, _ = env_as_gym.reset()

        flatten_obs_env.render()
    finally:
        # Always release environment resources, even if training/evaluation fails.
        flatten_obs_env.close()