# Path: blob/main/notebooks/stable-baselines-agent.py
# (597 views)
'''Stable-baselines agent for CyberBattle Gym environment.

Trains A2C and PPO agents on the CyberBattle "toy CTF" scenario, saves the
trained policies, then reloads the A2C model and runs a greedy rollout.
Intended to be executed cell-by-cell (``# %%`` markers) or as a script.
'''

# %%
# stdlib
import logging
import os
import sys
from typing import cast

# third-party: stable-baselines3
from stable_baselines3.a2c.a2c import A2C
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.type_aliases import GymEnv
from stable_baselines3.ppo.ppo import PPO

# project: CyberBattle
from cyberbattle._env.cyberbattle_toyctf import CyberBattleToyCtf
from cyberbattle._env.flatten_wrapper import (
    FlattenActionWrapper,
    FlattenObservationWrapper,
)

# Make CUDA errors surface at the failing call site instead of asynchronously.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Which algorithms to (re)train in this run; remove entries to skip training
# and reuse a previously saved model instead.
retrain = ["a2c", "ppo"]


logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")

# %%
# Fixed-size padding (maximum_node_count / maximum_total_credentials +
# observation_padding) keeps the observation space shape constant, which
# stable-baselines requires. Invalid actions yield a penalty instead of raising.
env = CyberBattleToyCtf(
    maximum_node_count=12,
    maximum_total_credentials=10,
    observation_padding=True,
    throws_on_invalid_actions=False,
)

# %%
# Flatten the multi-discrete action space into a form SB3 policies accept.
flatten_action_env = FlattenActionWrapper(env)

# %%
# Flatten the dict observation space, dropping fields with unsupported
# (DummySpace) types that SB3's MultiInputPolicy cannot consume.
flatten_obs_env = FlattenObservationWrapper(
    flatten_action_env,
    ignore_fields=[
        # DummySpace
        "_credential_cache",
        "_discovered_nodes",
        "_explored_network",
    ],
)

# %%
# Static cast for type-checkers only; no runtime conversion happens here.
env_as_gym = cast(GymEnv, flatten_obs_env)

# %%
o, _ = env_as_gym.reset()
print(o)

# %%
# Sanity-check that the wrapped env conforms to the Gym API before training.
check_env(flatten_obs_env)


# %%
if "a2c" in retrain:
    model_a2c = A2C("MultiInputPolicy", env_as_gym).learn(10000)
    model_a2c.save("a2c_trained_toyctf")


# %%
if "ppo" in retrain:
    model_ppo = PPO("MultiInputPolicy", env_as_gym).learn(100)
    model_ppo.save("ppo_trained_toyctf")


# %%
# NOTE: `load` is a classmethod; the original code built a throwaway model
# (`A2C("MultiInputPolicy", env_as_gym).load(...)`) whose policy was then
# discarded. Calling it on the class avoids the wasted construction.
model = A2C.load("a2c_trained_toyctf")
# model = PPO.load('ppo_trained_toyctf')


# %%
obs, _ = env_as_gym.reset()


# %%
# Greedy (deterministic) rollout of the loaded policy for a fixed budget of
# 1000 steps; the episode is not reset on termination, matching the original.
for _ in range(1000):
    assert isinstance(obs, dict)
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, truncated, info = flatten_obs_env.step(action)

flatten_obs_env.render()
flatten_obs_env.close()

# %%