Path: blob/main/cyberbattle/_env/flatten_wrapper.py
597 views
"""Wrappers used to flatten action and observation spaces1for CyberBattleEnv gym environment.2"""34from collections import OrderedDict5from sqlite3 import NotSupportedError6from gymnasium import Env, spaces7import numpy as np8from gymnasium.core import ObservationWrapper, ActionWrapper910from cyberbattle._env.cyberbattle_env import Action, CyberBattleEnv111213class FlattenObservationWrapper(ObservationWrapper):14"""15Flatten all nested dictionaries and tuples from the16observation space of a CyberBattleSim environment.17The resulting observation space is a dictionary containing only18subspaces of types: `Discrete`, `MultiBinary`, and `MultiDiscrete`.19"""2021def flatten_multibinary_space(self, space: spaces.Space):22if isinstance(space, spaces.MultiBinary):23if type(space.n) in [tuple, list, np.ndarray]:24flatten_dim = np.multiply.reduce(space.n)25flatten_space = spaces.MultiBinary(flatten_dim)26print(f"// MultiBinary flattened from {space.n} -> {flatten_space.n} - dtype: {space.dtype} -> {flatten_space.dtype}")27return flatten_space28else:29print(f"// MultiBinary already flat: {space.n}")30return space31else:32return space3334def flatten_multidiscrete_space(self, space: spaces.Space):35if isinstance(space, spaces.MultiDiscrete):36if type(space.nvec) in [tuple, list, np.ndarray]:37flatten_space = spaces.MultiDiscrete(space.nvec.flatten())38print(f"// MultiDiscrete flattened from {space.nvec} -> {flatten_space.nvec}")39return flatten_space40else:41print(f"// MultiDiscrete already flat: {space.nvec}")42return space4344def __init__(self, env: Env, ignore_fields=["action_mask"]):45ObservationWrapper.__init__(self, env)46self.env = env47self.ignore_fields = ignore_fields48if isinstance(env.observation_space, spaces.Dict):49space_dict = OrderedDict({})50for key, space in env.observation_space.spaces.items():51if key in ignore_fields:52print("Filtering out field", key)53elif isinstance(space, spaces.Dict):54for subkey, subspace in space.items():55space_dict[f"{key}_{subkey}"] = self.flatten_multibinary_space(subspace)56elif isinstance(space, spaces.Tuple):57for i, subspace in enumerate(space.spaces):58space_dict[f"{key}_{i}"] = self.flatten_multibinary_space(subspace)59elif isinstance(space, spaces.MultiBinary):60space_dict[key] = self.flatten_multibinary_space(space)61elif isinstance(space, spaces.Discrete):62space_dict[key] = space63elif isinstance(space, spaces.MultiDiscrete):64space_dict[key] = self.flatten_multidiscrete_space(space)65else:66raise NotImplementedError(f"Case not handled: {key} - type {type(space)}")6768self.observation_space = spaces.Dict(space_dict)6970def flatten_multibinary_observation(self, space, o):71if isinstance(space, spaces.MultiBinary) and isinstance(space.n, tuple) and len(space.n) > 1:72flatten_dim = np.multiply.reduce(space.n)73# print(f"dtype: {o.dtype} shape: {o.shape} -> {flatten_dim}")74reshaped = o.reshape(flatten_dim)75# print(f"reshaped: {reshaped.dtype} shape: {reshaped.shape}")76return reshaped77else:78return o7980def flatten_multidiscrete_observation(self, space, o):81if isinstance(space, spaces.MultiDiscrete):82return o.flatten()83else:84return o8586def observation(self, observation):87if isinstance(self.env.observation_space, spaces.Dict):88o = OrderedDict({})89for key, space in self.env.observation_space.spaces.items():90value = observation[key]91if key in self.ignore_fields:92continue93elif isinstance(space, spaces.Dict):94for subkey, subspace in space.items():95o[f"{key}_{subkey}"] = self.flatten_multibinary_observation(subspace, value[subkey])96elif isinstance(space, spaces.Tuple):97for i, subspace in enumerate(space.spaces):98o[f"{key}_{i}"] = self.flatten_multibinary_observation(subspace, value[i])99elif isinstance(space, spaces.MultiBinary):100o[key] = self.flatten_multibinary_observation(space, value)101elif isinstance(space, spaces.Discrete):102o[key] = value103elif isinstance(space, spaces.MultiDiscrete):104o[key] = self.flatten_multidiscrete_observation(space, value)105else:106raise NotImplementedError(f"Case not handled: {key} - type {type(space)}")107108return o109else:110return observation111112def step(self, action):113"""Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`."""114observation, reward, terminated, truncated, info = self.env.step(action)115return self.observation(observation), reward, terminated, truncated, info116117118class FlattenActionWrapper(ActionWrapper):119"""120Flatten all nested dictionaries and tuples from the121action space of a CyberBattleSim environment.122The resulting action space is a dictionary containing only123subspaces of types: `Discrete`, `MultiBinary`, and `MultiDiscrete`.124"""125126def __init__(self, env: CyberBattleEnv):127ActionWrapper.__init__(self, env)128self.env = env129130self.action_space = spaces.MultiDiscrete(131np.array([132# connect, local vulnerabilities, remote vulnerabilities1331 + env.bounds.local_attacks_count + env.bounds.remote_attacks_count,134# source node135env.bounds.maximum_node_count,136# target node137env.bounds.maximum_node_count,138# target port (for connect action only)139env.bounds.port_count,140# target port (credentials used, for connect action only)141env.bounds.maximum_total_credentials,142], dtype=np.int32)143)144145def action(self, action: np.ndarray) -> Action:146action_type = action[0]147if action_type == 0:148return {"connect": action[1:5]}149150action_type -= 1151if action_type < self.env.bounds.local_attacks_count:152return {"local_vulnerability": np.array([action[1], action_type])}153154action_type -= self.env.bounds.local_attacks_count155if action_type < self.env.bounds.remote_attacks_count:156return {"remote_vulnerability": np.array([action[1], action[2], action_type])}157158raise NotSupportedError(f"Unsupported action: {action}")159160def reverse_action(self, action):161raise NotImplementedError162163164