Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
microsoft
GitHub Repository: microsoft/CyberBattleSim
Path: blob/main/cyberbattle/_env/flatten_wrapper.py
597 views
1
"""Wrappers used to flatten action and observation spaces
2
for CyberBattleEnv gym environment.
3
"""
4
5
from collections import OrderedDict
6
from sqlite3 import NotSupportedError
7
from gymnasium import Env, spaces
8
import numpy as np
9
from gymnasium.core import ObservationWrapper, ActionWrapper
10
11
from cyberbattle._env.cyberbattle_env import Action, CyberBattleEnv
12
13
14
class FlattenObservationWrapper(ObservationWrapper):
    """
    Flatten all nested dictionaries and tuples from the
    observation space of a CyberBattleSim environment.

    Top-level ``Dict`` and ``Tuple`` subspaces are expanded into separate
    entries named ``{key}_{subkey}`` and ``{key}_{index}`` respectively, and
    multi-dimensional ``MultiBinary`` / ``MultiDiscrete`` spaces are reshaped
    to a single dimension.
    The resulting observation space is a dictionary intended to contain only
    subspaces of types: `Discrete`, `MultiBinary`, and `MultiDiscrete`.

    Note: entries nested inside a ``Dict`` or ``Tuple`` only get the
    ``MultiBinary`` flattening applied; any other nested subspace is kept
    unchanged.
    """

    def flatten_multibinary_space(self, space: spaces.Space):
        """Return a one-dimensional equivalent of a ``MultiBinary`` space.

        A space that is not a ``MultiBinary``, or whose ``n`` is already a
        scalar, is returned unchanged.
        """
        if not isinstance(space, spaces.MultiBinary):
            return space

        if isinstance(space.n, (tuple, list, np.ndarray)):
            # Total number of bits is the product of all dimensions.
            flatten_dim = np.multiply.reduce(space.n)
            flatten_space = spaces.MultiBinary(flatten_dim)
            print(f"// MultiBinary flattened from {space.n} -> {flatten_space.n} - dtype: {space.dtype} -> {flatten_space.dtype}")
            return flatten_space

        print(f"// MultiBinary already flat: {space.n}")
        return space

    def flatten_multidiscrete_space(self, space: spaces.Space):
        """Return a one-dimensional equivalent of a ``MultiDiscrete`` space.

        A space that is not a ``MultiDiscrete`` is returned unchanged
        (previously this case implicitly returned ``None``).
        """
        if not isinstance(space, spaces.MultiDiscrete):
            return space

        if isinstance(space.nvec, (tuple, list, np.ndarray)):
            flatten_space = spaces.MultiDiscrete(space.nvec.flatten())
            print(f"// MultiDiscrete flattened from {space.nvec} -> {flatten_space.nvec}")
            return flatten_space

        print(f"// MultiDiscrete already flat: {space.nvec}")
        return space

    def __init__(self, env: Env, ignore_fields=None):
        """Wrap ``env`` and build the flattened observation space.

        Args:
            env: the environment to wrap. The observation space is only
                rewritten when ``env.observation_space`` is a ``spaces.Dict``.
            ignore_fields: top-level observation keys to drop from the
                flattened space and observations. Defaults to
                ``["action_mask"]``.

        Raises:
            NotImplementedError: if a top-level subspace has a type that is
                not handled.
        """
        ObservationWrapper.__init__(self, env)
        self.env = env
        # A fresh list per instance: a list literal used as the default
        # argument would be shared (and mutable) across all wrappers.
        self.ignore_fields = ["action_mask"] if ignore_fields is None else ignore_fields

        if isinstance(env.observation_space, spaces.Dict):
            space_dict = OrderedDict({})
            for key, space in env.observation_space.spaces.items():
                if key in self.ignore_fields:
                    print("Filtering out field", key)
                elif isinstance(space, spaces.Dict):
                    for subkey, subspace in space.items():
                        space_dict[f"{key}_{subkey}"] = self.flatten_multibinary_space(subspace)
                elif isinstance(space, spaces.Tuple):
                    for i, subspace in enumerate(space.spaces):
                        space_dict[f"{key}_{i}"] = self.flatten_multibinary_space(subspace)
                elif isinstance(space, spaces.MultiBinary):
                    space_dict[key] = self.flatten_multibinary_space(space)
                elif isinstance(space, spaces.Discrete):
                    space_dict[key] = space
                elif isinstance(space, spaces.MultiDiscrete):
                    space_dict[key] = self.flatten_multidiscrete_space(space)
                else:
                    raise NotImplementedError(f"Case not handled: {key} - type {type(space)}")

            self.observation_space = spaces.Dict(space_dict)

    def flatten_multibinary_observation(self, space, o):
        """Reshape observation ``o`` to one dimension when ``space`` is a
        multi-dimensional ``MultiBinary``; otherwise return ``o`` unchanged."""
        if isinstance(space, spaces.MultiBinary) and isinstance(space.n, tuple) and len(space.n) > 1:
            flatten_dim = np.multiply.reduce(space.n)
            return o.reshape(flatten_dim)
        return o

    def flatten_multidiscrete_observation(self, space, o):
        """Flatten observation ``o`` when ``space`` is a ``MultiDiscrete``;
        otherwise return ``o`` unchanged."""
        if isinstance(space, spaces.MultiDiscrete):
            return o.flatten()
        return o

    def observation(self, observation):
        """Convert an observation of the wrapped environment to the flat layout.

        The keys produced mirror the ones built in ``__init__``. When the
        wrapped observation space is not a ``spaces.Dict`` the observation is
        returned as is.

        Raises:
            NotImplementedError: if a top-level subspace has a type that is
                not handled.
        """
        if not isinstance(self.env.observation_space, spaces.Dict):
            return observation

        o = OrderedDict({})
        for key, space in self.env.observation_space.spaces.items():
            if key in self.ignore_fields:
                # Skip before looking the value up, so an ignored field does
                # not have to be present in the observation.
                continue

            value = observation[key]
            if isinstance(space, spaces.Dict):
                for subkey, subspace in space.items():
                    o[f"{key}_{subkey}"] = self.flatten_multibinary_observation(subspace, value[subkey])
            elif isinstance(space, spaces.Tuple):
                for i, subspace in enumerate(space.spaces):
                    o[f"{key}_{i}"] = self.flatten_multibinary_observation(subspace, value[i])
            elif isinstance(space, spaces.MultiBinary):
                o[key] = self.flatten_multibinary_observation(space, value)
            elif isinstance(space, spaces.Discrete):
                o[key] = value
            elif isinstance(space, spaces.MultiDiscrete):
                o[key] = self.flatten_multidiscrete_observation(space, value)
            else:
                raise NotImplementedError(f"Case not handled: {key} - type {type(space)}")

        return o

    def step(self, action):
        """Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`."""
        observation, reward, terminated, truncated, info = self.env.step(action)
        return self.observation(observation), reward, terminated, truncated, info
117
118
119
class FlattenActionWrapper(ActionWrapper):
    """
    Expose the action space of a CyberBattleSim environment as a single
    flat `MultiDiscrete` vector.

    The vector has five components:
    ``[action type, source node, target node, target port, credential]``.
    Action type ``0`` selects a connect action; the following
    ``local_attacks_count`` values select a local vulnerability and the next
    ``remote_attacks_count`` values select a remote vulnerability.
    """

    def __init__(self, env: CyberBattleEnv):
        """Wrap ``env`` and derive the flat action space from ``env.bounds``."""
        super().__init__(env)
        self.env = env

        bounds = env.bounds
        dimensions = [
            # connect, local vulnerabilities, remote vulnerabilities
            1 + bounds.local_attacks_count + bounds.remote_attacks_count,
            # source node
            bounds.maximum_node_count,
            # target node
            bounds.maximum_node_count,
            # target port (for connect action only)
            bounds.port_count,
            # credential used (for connect action only)
            bounds.maximum_total_credentials,
        ]
        self.action_space = spaces.MultiDiscrete(np.array(dimensions, dtype=np.int32))

    def action(self, action: np.ndarray) -> Action:
        """Translate a flat action vector into the wrapped environment's
        dictionary action.

        Returns one of ``{"connect": ...}``, ``{"local_vulnerability": ...}``
        or ``{"remote_vulnerability": ...}``.

        Raises:
            NotSupportedError: if the action type component falls outside the
                connect / local / remote ranges.
        """
        bounds = self.env.bounds

        selector = action[0]
        if selector == 0:
            # source node, target node, target port, credential
            return {"connect": action[1:5]}

        # Shift past the connect slot: index into the local vulnerabilities.
        selector -= 1
        if selector < bounds.local_attacks_count:
            return {"local_vulnerability": np.array([action[1], selector])}

        # Shift past the local slots: index into the remote vulnerabilities.
        selector -= bounds.local_attacks_count
        if selector < bounds.remote_attacks_count:
            return {"remote_vulnerability": np.array([action[1], action[2], selector])}

        raise NotSupportedError(f"Unsupported action: {action}")

    def reverse_action(self, action):
        """Not supported: mapping a dictionary action back to the flat vector."""
        raise NotImplementedError
163
164