Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
microsoft
GitHub Repository: microsoft/CyberBattleSim
Path: blob/main/cyberbattle/agents/baseline/test_stablebaseline3.py
597 views
1
"""Smoke test for a stable-baselines3 A2C agent on the CyberBattle ToyCTF Gym environment."""
2
3
import os
import tempfile
from typing import cast

from stable_baselines3.a2c.a2c import A2C
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.type_aliases import GymEnv

from cyberbattle._env.cyberbattle_toyctf import CyberBattleToyCtf
from cyberbattle._env.flatten_wrapper import (
    FlattenObservationWrapper,
    FlattenActionWrapper,
)
13
14
# Set at import time, before any model is built. Per the CUDA documentation,
# CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous; presumably set here
# so device errors surface at the offending call while debugging (confirm intent).
os.environ.update({"CUDA_LAUNCH_BLOCKING": "1"})
15
16
def test_stablebaseline3(training_steps=3, eval_steps=10):
    """Train, save, reload and run an A2C agent on the ToyCTF environment.

    The ``CyberBattleToyCtf`` environment is wrapped with
    ``FlattenActionWrapper`` and ``FlattenObservationWrapper``. The wrapped
    environment is then passed to stable-baselines3's ``check_env``.

    An A2C model with ``MultiInputPolicy`` is trained for ``training_steps``
    timesteps, saved to a temporary directory and loaded back. The loaded
    model then drives the environment for ``eval_steps`` steps using
    deterministic predictions.

    Args:
        training_steps: Number of timesteps passed to ``A2C.learn``.
        eval_steps: Number of environment steps taken with the reloaded model.
    """
    cyberbattle_env = CyberBattleToyCtf(
        maximum_node_count=12,
        maximum_total_credentials=10,
        observation_padding=True,
        throws_on_invalid_actions=False,
    )

    flatten_action_env = FlattenActionWrapper(cyberbattle_env)

    # These observation fields are left out of the flattened observation
    # (presumably because they do not map to fixed-size spaces -- TODO confirm).
    flatten_obs_env = FlattenObservationWrapper(
        flatten_action_env,
        ignore_fields=[
            "_credential_cache",
            "_discovered_nodes",
            "_explored_network",
        ],
    )

    env_as_gym = cast(GymEnv, flatten_obs_env)

    try:
        check_env(flatten_obs_env)

        model_a2c = A2C("MultiInputPolicy", env_as_gym).learn(training_steps)

        # Save to a temporary directory so the test leaves no model archive
        # behind in the current working directory.
        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path = os.path.join(tmp_dir, "a2c_trained_toyctf")
            model_a2c.save(model_path)
            # ``load`` is a classmethod. Calling it on the class avoids
            # building (and discarding) a throwaway A2C model just to reach it.
            model = A2C.load(model_path)

        obs, _ = env_as_gym.reset()
        for _ in range(eval_steps):
            assert isinstance(obs, dict)
            action, _ = model.predict(obs, deterministic=True)
            obs, _reward, done, truncated, _info = flatten_obs_env.step(action)
            if done or truncated:
                # Do not keep stepping an episode that has finished; start a new one.
                obs, _ = env_as_gym.reset()

        flatten_obs_env.render()
    finally:
        # Release the environment even if training or evaluation fails.
        flatten_obs_env.close()
49
50