Path: blob/main/cyberbattle/_env/option_wrapper.py
597 views
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Wrapper that builds CyberBattleEnv actions from a set of "option" policies.

Instead of choosing a full action at once, ``ContextWrapper`` asks a chain of
small option callables (action kind, source node, vulnerability, target node,
credential) and assembles their answers into one CyberBattle action. The
``pi_*`` functions and ``random_options`` provide uniformly random options.
"""

from typing import NamedTuple

import gymnasium as gym
from gymnasium.spaces import Space, Discrete, Tuple
import numpy as onp
from cyberbattle._env.cyberbattle_env import Action, CyberBattleEnv


# Action kinds, indexed by the integer returned by the "kind" option:
# 0 -> local_vulnerability, 1 -> remote_vulnerability, 2 (or anything else) -> connect.
_ACTION_KINDS = ("local_vulnerability", "remote_vulnerability", "connect")

# Exact set of option names ContextWrapper requires (and context_spaces describes).
_OPTION_NAMES = frozenset(
    {"kind", "local_node_id", "local_vuln_id", "remote_node_id", "remote_vuln_id", "cred_id"}
)


class Env(NamedTuple):
    """Observation space and action space of a single option's decision context."""

    observation_space: Space
    action_space: Space


def context_spaces(observation_space, action_space):
    """Return the input/output spaces of every option used by ``ContextWrapper``.

    Args:
        observation_space: observation space of the wrapped CyberBattleEnv.
        action_space: action space of the wrapped CyberBattleEnv. Its
            ``spaces["local_vulnerability"]``, ``spaces["remote_vulnerability"]``
            and ``spaces["connect"]`` must each have an ``nvec`` of length 2, 3
            and 4 respectively.

    Returns:
        dict mapping each option name to an ``Env``. The observation space is
        what the option receives as input (see ``ContextWrapper.step``). The
        action space is the ``Discrete`` range of the integer it returns.
    """
    # Number of action kinds.
    K = len(_ACTION_KINDS)  # noqa: N806
    # local_vulnerability = [node, local vuln]
    N, L = action_space.spaces["local_vulnerability"].nvec  # noqa: N806
    # remote_vulnerability = [source node, target node, remote vuln]
    N, N, R = action_space.spaces["remote_vulnerability"].nvec  # noqa: N806
    # connect = [source node, target node, port, cached credential]
    N, N, P, C = action_space.spaces["connect"].nvec  # noqa: N806
    return {
        "kind": Env(observation_space, Discrete(K)),
        "local_node_id": Env(Tuple((observation_space, Discrete(K))), Discrete(N)),
        "local_vuln_id": Env(Tuple((observation_space, Discrete(N))), Discrete(L)),
        "remote_node_id": Env(Tuple((observation_space, Discrete(K), Discrete(N))), Discrete(N)),
        "remote_vuln_id": Env(Tuple((observation_space, Discrete(N), Discrete(N))), Discrete(R)),
        "cred_id": Env(observation_space, Discrete(C)),
    }


class ContextWrapper(gym.Wrapper):
    """Gym wrapper whose ``step`` assembles each action from option policies.

    The ``action`` passed to ``step`` is ignored. Instead, the option callables
    given at construction time are queried in this order:

        kind -> local_node_id -> local_vuln_id                      (kind == 0)
        kind -> local_node_id -> remote_node_id -> remote_vuln_id   (kind == 1)
        kind -> local_node_id -> remote_node_id -> cred_id          (otherwise)

    The resulting action is forwarded to the wrapped environment.
    """

    __kinds = _ACTION_KINDS

    def __init__(self, env: CyberBattleEnv, options):
        """Wrap ``env`` and store the option policies.

        Args:
            env: the environment to wrap; must expose a ``bounds`` attribute.
            options: dict with exactly the keys "kind", "local_node_id",
                "local_vuln_id", "remote_node_id", "remote_vuln_id" and
                "cred_id". Each value is a callable that receives the option's
                context (see ``context_spaces``) and returns an integer.

        Raises:
            AssertionError: if ``options`` is not a dict with exactly those keys.
        """
        super().__init__(env)
        self.env = env
        assert isinstance(options, dict) and set(options) == _OPTION_NAMES, (
            f"options must be a dict with exactly the keys {sorted(_OPTION_NAMES)}"
        )
        self._options = options
        self._bounds = env.bounds
        # NOTE(review): not read anywhere in this module; reset() replaces it with
        # an int32 array of five -1 values.
        self._action_context = []
        # Latest observation from the wrapped env; set by reset() and step().
        # None means reset() has not been called yet.
        self._observation = None

    def reset(self, **kwargs):
        """Reset the wrapped env and remember its initial observation.

        Returns:
            ``(observation, info)`` exactly as returned by the wrapped env's ``reset``.
        """
        self._action_context = onp.full(5, -1, dtype=onp.int32)
        self._observation, info = self.env.reset(**kwargs)
        return self._observation, info

    def step(self, action=None):
        """Build one action from the option policies and step the wrapped env.

        Args:
            action: ignored; accepted only for compatibility with the gym API.

        Returns:
            ``(observation, reward, done, truncated, info)`` from the wrapped env.
            The executed action is added to a copy of ``info`` under ``"action"``.

        Raises:
            RuntimeError: if called before ``reset()``.
            AssertionError: if the "cred_id" option returns an index outside
                ``[0, credential_cache_length)``.
        """
        if self._observation is None:
            raise RuntimeError("ContextWrapper.step() called before reset()")
        obs = self._observation
        kind = self._options["kind"](obs)
        local_node_id = self._options["local_node_id"]((obs, kind))
        if kind == 0:
            local_vuln_id = self._options["local_vuln_id"]((obs, local_node_id))
            a: Action = {"local_vulnerability": onp.array([local_node_id, local_vuln_id])}
        else:
            # NOTE(review): for connect actions the remote node chosen here is not used.
            # The target node and port come from the selected cached credential instead.
            # The call is kept so that option policies are queried in the same order.
            remote_node_id = self._options["remote_node_id"]((obs, kind, local_node_id))
            if kind == 1:
                remote_vuln_id = self._options["remote_vuln_id"]((obs, local_node_id, remote_node_id))
                a = {"remote_vulnerability": onp.array([local_node_id, remote_node_id, remote_vuln_id])}
            else:
                cred_id = self._options["cred_id"](obs)
                # Reject negative indices too: numpy would silently wrap them around and
                # pick a credential from the end of the cache matrix.
                assert 0 <= cred_id < obs["credential_cache_length"], (
                    f"cred_id {cred_id} out of range [0, {obs['credential_cache_length']})"
                )
                # Each cached credential row holds (target node id, port id).
                node_id, port_id = obs["credential_cache_matrix"][cred_id].astype("int32")
                a = {"connect": onp.array([local_node_id, node_id, port_id, cred_id])}

        self._observation, reward, done, truncated, info = self.env.step(a)
        return self._observation, reward, done, truncated, {**info, "action": a}


# --- random option policies --------------------------------------------------------------------- #


def pi_kind(s):
    """Pick uniformly among action kinds whose action mask has at least one valid entry.

    Args:
        s: observation dict with an ``"action_mask"`` entry keyed by action kind.

    Returns:
        Index into ("local_vulnerability", "remote_vulnerability", "connect").

    Raises:
        ValueError: from ``onp.random.choice`` if no action kind is available.
    """
    masked = onp.array([i for i, k in enumerate(_ACTION_KINDS) if onp.any(s["action_mask"][k])])
    return onp.random.choice(masked)


def pi_local_node_id(s):
    """Pick uniformly a source node that appears in the mask of the chosen action kind.

    Args:
        s: ``(observation, kind)``.

    Returns:
        A node index. Nodes are weighted by how many valid actions they appear in,
        because every valid mask entry contributes one candidate.
    """
    s, k = s
    if k == 0:
        local_node_ids, _ = onp.argwhere(s["action_mask"]["local_vulnerability"]).T
    elif k == 1:
        local_node_ids, _, _ = onp.argwhere(s["action_mask"]["remote_vulnerability"]).T
    else:
        local_node_ids, _, _, _ = onp.argwhere(s["action_mask"]["connect"]).T
    return onp.random.choice(local_node_ids)


def pi_local_vuln_id(s):
    """Pick uniformly a local vulnerability valid for the given source node.

    Args:
        s: ``(observation, local_node_id)``.
    """
    s, local_node_id = s
    local_node_ids, local_vuln_ids = onp.argwhere(s["action_mask"]["local_vulnerability"]).T
    masked = local_vuln_ids[local_node_ids == local_node_id]
    return onp.random.choice(masked)


def pi_remote_node_id(s):
    """Pick uniformly a target node valid for the given source node and action kind.

    Args:
        s: ``(observation, kind, local_node_id)``. ``kind`` must not be 0 (local).
    """
    s, k, local_node_id = s
    assert k != 0
    if k == 1:
        local_node_ids, remote_node_ids, _ = onp.argwhere(s["action_mask"]["remote_vulnerability"]).T
    else:
        local_node_ids, remote_node_ids, _, _ = onp.argwhere(s["action_mask"]["connect"]).T
    return onp.random.choice(remote_node_ids[local_node_ids == local_node_id])


def pi_remote_vuln_id(s):
    """Pick uniformly a remote vulnerability valid for the given (source, target) pair.

    Args:
        s: ``(observation, local_node_id, remote_node_id)``.
    """
    s, local_node_id, remote_node_id = s
    local_node_ids, remote_node_ids, remote_vuln_ids = onp.argwhere(s["action_mask"]["remote_vulnerability"]).T
    mask = (local_node_ids == local_node_id) & (remote_node_ids == remote_node_id)
    return onp.random.choice(remote_vuln_ids[mask])


def pi_cred_id(s):
    """Pick uniformly an index in ``[0, credential_cache_length)``.

    Raises:
        ValueError: from ``onp.random.choice`` if the credential cache is empty.
    """
    return onp.random.choice(s["credential_cache_length"])


# Uniformly random option policies, suitable as the ``options`` of ContextWrapper.
random_options = {
    "kind": pi_kind,
    "local_node_id": pi_local_node_id,
    "local_vuln_id": pi_local_vuln_id,
    "remote_node_id": pi_remote_node_id,
    "remote_vuln_id": pi_remote_vuln_id,
    "cred_id": pi_cred_id,
}