Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
microsoft
GitHub Repository: microsoft/CyberBattleSim
Path: blob/main/notebooks/stable-baselines-agent.py
597 views
1
'''Stable-baselines agent for CyberBattle Gym environment'''
2
3
# %%
4
from typing import cast
5
from cyberbattle._env.cyberbattle_toyctf import CyberBattleToyCtf
6
import logging
7
import sys
8
from stable_baselines3.a2c.a2c import A2C
9
from stable_baselines3.ppo.ppo import PPO
10
from cyberbattle._env.flatten_wrapper import (
11
FlattenObservationWrapper,
12
FlattenActionWrapper,
13
)
14
import os
15
from stable_baselines3.common.type_aliases import GymEnv
16
from stable_baselines3.common.env_checker import check_env
17
18
# Make CUDA kernel errors surface synchronously at the call site (debug aid).
os.environ.update({"CUDA_LAUNCH_BLOCKING": "1"})

# Which agents to (re)train from scratch in this run; remove an entry to
# skip its training cell and reuse a previously saved model instead.
retrain = ["a2c", "ppo"]


# Keep notebook output compact: only errors, printed to stdout.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.ERROR,
    format="%(levelname)s: %(message)s",
)
23
24
# %%
25
# Toy capture-the-flag environment. Padding and capped node/credential
# counts give every observation a fixed size, which the flattening
# wrappers below require; invalid actions yield penalties, not exceptions.
_toyctf_kwargs = dict(
    maximum_node_count=12,
    maximum_total_credentials=10,
    observation_padding=True,
    throws_on_invalid_actions=False,
)
env = CyberBattleToyCtf(**_toyctf_kwargs)
31
32
# %%
33
# Flatten the composite action space so stable-baselines3 policies can act on it.
flatten_action_env = FlattenActionWrapper(env)

# %%
# Flatten the observation space as well, skipping fields that have no
# fixed-size numeric encoding (DummySpace entries).
_unflattenable_fields = [
    "_credential_cache",
    "_discovered_nodes",
    "_explored_network",
]
flatten_obs_env = FlattenObservationWrapper(
    flatten_action_env,
    ignore_fields=_unflattenable_fields,
)
42
43
#%%
44
# Present the fully-wrapped environment to SB3 under its expected GymEnv type.
env_as_gym = cast(GymEnv, flatten_obs_env)

# %%
# Smoke test: reset once and inspect the initial observation.
initial_obs, _info = env_as_gym.reset()
print(initial_obs)

# %%
# Verify the wrapped environment conforms to the Gym API contract.
check_env(flatten_obs_env)
52
53
54
# %%
55
if "a2c" in retrain:
    # Train an A2C agent from scratch and persist it for later evaluation.
    a2c_agent = A2C("MultiInputPolicy", env_as_gym)
    model_a2c = a2c_agent.learn(10000)
    model_a2c.save("a2c_trained_toyctf")


# %%
if "ppo" in retrain:
    # Train a PPO agent (very short run) and persist it.
    ppo_agent = PPO("MultiInputPolicy", env_as_gym)
    model_ppo = ppo_agent.learn(100)
    model_ppo.save("ppo_trained_toyctf")
64
65
66
# %%
67
# Load the trained A2C model from disk.
# NOTE: `load` is a classmethod in stable-baselines3. The original code
# built a fresh A2C instance and called `.load` on it, which silently
# discarded that instance (and its env binding) and returned a model with
# no environment attached. Call `load` on the class and pass the env.
model = A2C.load("a2c_trained_toyctf", env=env_as_gym)
# model = PPO.load("ppo_trained_toyctf", env=env_as_gym)
69
70
71
# %%
72
obs, _ = env_as_gym.reset()


# %%
# Roll out the trained policy for up to 1000 steps.
# FIX: the original loop ignored `done`/`truncated` and kept stepping a
# finished episode; per the Gymnasium API, calling step() after termination
# is undefined behavior. Reset whenever an episode ends instead.
for i in range(1000):
    assert isinstance(obs, dict)
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, truncated, info = flatten_obs_env.step(action)
    if done or truncated:
        obs, _ = env_as_gym.reset()

flatten_obs_env.render()
flatten_obs_env.close()
83
84
# %%
85
86