Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
microsoft
GitHub Repository: microsoft/CyberBattleSim
Path: blob/main/cyberbattle/_env/option_wrapper.py
597 views
1
# Copyright (c) Microsoft Corporation.
2
# Licensed under the MIT License.
3
4
from typing import NamedTuple
5
6
import gymnasium as gym
7
from gymnasium.spaces import Space, Discrete, Tuple
8
import numpy as onp
9
from cyberbattle._env.cyberbattle_env import Action, CyberBattleEnv
10
11
12
class Env(NamedTuple):
    """Pair of spaces attached to a single option policy.

    ``context_spaces`` builds one instance per option name.
    """

    # Space of the input handed to the option policy.
    observation_space: Space
    # Space of the value the option policy is expected to return.
    action_space: Space
15
16
17
def context_spaces(observation_space, action_space):
    """Build the ``Env`` (observation space, action space) pair of every option.

    Args:
        observation_space: Observation space of the wrapped environment.
        action_space: Action space of the wrapped environment; its ``spaces``
            mapping must hold ``"local_vulnerability"``, ``"remote_vulnerability"``
            and ``"connect"`` entries, each exposing an ``nvec`` attribute.

    Returns:
        A dict keyed by option name (``"kind"``, ``"local_node_id"``,
        ``"local_vuln_id"``, ``"remote_node_id"``, ``"remote_vuln_id"``,
        ``"cred_id"``) whose values are ``Env`` instances.
    """
    subspaces = action_space.spaces
    kind_count = 3  # local vulnerability, remote vulnerability, connect

    # Only the last component of the two vulnerability spaces is needed here.
    _, local_vuln_count = subspaces["local_vulnerability"].nvec
    _, _, remote_vuln_count = subspaces["remote_vulnerability"].nvec
    # The node count is taken from the second component of the "connect" space.
    _, node_count, _, cred_count = subspaces["connect"].nvec

    spaces_by_option = {}
    spaces_by_option["kind"] = Env(observation_space, Discrete(kind_count))
    spaces_by_option["local_node_id"] = Env(
        Tuple((observation_space, Discrete(kind_count))),
        Discrete(node_count),
    )
    spaces_by_option["local_vuln_id"] = Env(
        Tuple((observation_space, Discrete(node_count))),
        Discrete(local_vuln_count),
    )
    spaces_by_option["remote_node_id"] = Env(
        Tuple((observation_space, Discrete(kind_count), Discrete(node_count))),
        Discrete(node_count),
    )
    spaces_by_option["remote_vuln_id"] = Env(
        Tuple((observation_space, Discrete(node_count), Discrete(node_count))),
        Discrete(remote_vuln_count),
    )
    spaces_by_option["cred_id"] = Env(observation_space, Discrete(cred_count))
    return spaces_by_option
30
31
32
class ContextWrapper(gym.Wrapper):
    """Wrapper whose ``step`` assembles the environment action from option policies.

    Each option is a callable stored under a fixed name in the ``options`` dict.
    On every step the options are queried in sequence (action kind, source node,
    then the kind-specific parameters) and the resulting action dict is passed to
    the wrapped environment.
    """

    # Action keys, in the order of the integer ``kind`` used by ``step``.
    __kinds = ("local_vulnerability", "remote_vulnerability", "connect")

    def __init__(self, env: CyberBattleEnv, options):
        """Wrap ``env`` and register the option policies.

        Args:
            env: The environment to wrap.
            options: Dict mapping exactly the names ``"kind"``,
                ``"local_node_id"``, ``"local_vuln_id"``, ``"remote_node_id"``,
                ``"remote_vuln_id"`` and ``"cred_id"`` to callables.
        """
        super().__init__(env)
        self.env = env
        assert isinstance(options, dict) and set(options) == {
            "kind",
            "local_node_id",
            "local_vuln_id",
            "remote_node_id",
            "remote_vuln_id",
            "cred_id",
        }
        self._options = options
        self._bounds = env.bounds
        # Same layout as the one installed by reset(), so the attribute keeps a
        # single type for the whole lifetime of the wrapper.
        self._action_context = onp.full(5, -1, dtype=onp.int32)
        # No observation is available until reset() has been called.
        self._observation = None

    def reset(self, **kwargs):
        """Reset the wrapped environment and remember its first observation.

        Args:
            **kwargs: Forwarded unchanged to the wrapped environment's ``reset``.

        Returns:
            The ``(observation, info)`` pair returned by the wrapped environment.
        """
        self._action_context = onp.full(5, -1, dtype=onp.int32)
        self._observation, info = self.env.reset(**kwargs)
        return self._observation, info

    def step(self, action=None):
        """Build an action from the option policies and apply it.

        Args:
            action: Ignored; the action is produced by the registered options.

        Returns:
            ``(observation, reward, done, truncated, info)`` from the wrapped
            environment, where ``info`` additionally carries the action that was
            executed under the ``"action"`` key.

        Raises:
            RuntimeError: If called before ``reset``.
        """
        obs = self._observation
        if obs is None:
            raise RuntimeError("ContextWrapper.step() called before reset()")

        kind = self._options["kind"](obs)
        local_node_id = self._options["local_node_id"]((obs, kind))
        if kind == 0:
            local_vuln_id = self._options["local_vuln_id"]((obs, local_node_id))
            a: Action = {"local_vulnerability": onp.array([local_node_id, local_vuln_id])}
        else:
            # Queried for both remaining kinds; only the remote-vulnerability
            # action uses the result, the connect target comes from the
            # credential cache below.
            remote_node_id = self._options["remote_node_id"]((obs, kind, local_node_id))
            if kind == 1:
                remote_vuln_id = self._options["remote_vuln_id"]((obs, local_node_id, remote_node_id))
                a = {"remote_vulnerability": onp.array([local_node_id, remote_node_id, remote_vuln_id])}
            else:
                cred_id = self._options["cred_id"](obs)
                assert cred_id < obs["credential_cache_length"]
                # The selected cache row is unpacked as (target node, port).
                node_id, port_id = obs["credential_cache_matrix"][cred_id].astype("int32")
                a = {"connect": onp.array([local_node_id, node_id, port_id, cred_id])}

        self._observation, reward, done, truncated, info = self.env.step(a)
        return self._observation, reward, done, truncated, {**info, "action": a}
75
76
77
# --- random option policies --------------------------------------------------------------------- #
78
79
80
def pi_kind(s):
    """Randomly pick an action kind that currently has at least one valid action.

    Args:
        s: Observation with an ``"action_mask"`` entry mapping each action name
            to its mask array.

    Returns:
        ``0`` for local vulnerability, ``1`` for remote vulnerability or ``2``
        for connect, chosen uniformly among the kinds whose mask has any
        non-zero entry.
    """
    action_mask = s["action_mask"]
    kind_names = ("local_vulnerability", "remote_vulnerability", "connect")
    available = [index for index, name in enumerate(kind_names) if onp.any(action_mask[name])]
    return onp.random.choice(onp.array(available))
84
85
86
def pi_local_node_id(s):
    """Randomly pick the source node for an action of the given kind.

    Args:
        s: Pair ``(observation, kind)`` where ``kind`` is ``0`` (local
            vulnerability), ``1`` (remote vulnerability) or anything else
            (connect).

    Returns:
        The first coordinate of one valid entry of the matching action mask.
        Every valid entry is equally likely, so a node appearing in more valid
        entries is picked more often.
    """
    observation, kind = s
    action_mask = observation["action_mask"]
    if kind == 0:
        source_ids, _ = onp.argwhere(action_mask["local_vulnerability"]).T
    elif kind == 1:
        source_ids, _, _ = onp.argwhere(action_mask["remote_vulnerability"]).T
    else:
        source_ids, _, _, _ = onp.argwhere(action_mask["connect"]).T
    return onp.random.choice(source_ids)
95
96
97
def pi_local_vuln_id(s):
    """Randomly pick a valid local vulnerability for the selected node.

    Args:
        s: Pair ``(observation, local_node_id)``.

    Returns:
        A vulnerability index drawn among the valid entries of the
        ``"local_vulnerability"`` mask whose node coordinate is ``local_node_id``.
    """
    observation, local_node_id = s
    node_ids, vuln_ids = onp.argwhere(observation["action_mask"]["local_vulnerability"]).T
    on_selected_node = node_ids == local_node_id
    return onp.random.choice(vuln_ids[on_selected_node])
102
103
104
def pi_remote_node_id(s):
    """Randomly pick a target node reachable from the selected source node.

    Args:
        s: Triple ``(observation, kind, local_node_id)``; ``kind`` must not be
            ``0`` because local actions have no target node.

    Returns:
        The second coordinate of one valid entry of the
        ``"remote_vulnerability"`` mask (``kind == 1``) or of the ``"connect"``
        mask (any other kind), restricted to entries whose source is
        ``local_node_id``.
    """
    observation, kind, local_node_id = s
    assert kind != 0
    action_mask = observation["action_mask"]
    if kind == 1:
        source_ids, target_ids, _ = onp.argwhere(action_mask["remote_vulnerability"]).T
    else:
        source_ids, target_ids, _, _ = onp.argwhere(action_mask["connect"]).T
    from_selected_source = source_ids == local_node_id
    return onp.random.choice(target_ids[from_selected_source])
112
113
114
def pi_remote_vuln_id(s):
    """Randomly pick a valid remote vulnerability for a (source, target) pair.

    Args:
        s: Triple ``(observation, local_node_id, remote_node_id)``.

    Returns:
        A vulnerability index drawn among the valid entries of the
        ``"remote_vulnerability"`` mask that match both node ids.
    """
    observation, local_node_id, remote_node_id = s
    valid_entries = onp.argwhere(observation["action_mask"]["remote_vulnerability"])
    source_ids, target_ids, vuln_ids = valid_entries.T
    same_source = source_ids == local_node_id
    same_target = target_ids == remote_node_id
    return onp.random.choice(vuln_ids[same_source & same_target])
119
120
121
def pi_cred_id(s):
    """Randomly pick the index of a cached credential.

    Args:
        s: Observation with a ``"credential_cache_length"`` entry.

    Returns:
        An index drawn uniformly from ``0 .. credential_cache_length - 1``.
    """
    cache_length = s["credential_cache_length"]
    return onp.random.choice(cache_length)
123
124
125
# Ready-made option set in which every option is one of the random policies
# defined in this module; the keys are the option names ContextWrapper expects.
random_options = dict(
    kind=pi_kind,
    local_node_id=pi_local_node_id,
    local_vuln_id=pi_local_vuln_id,
    remote_node_id=pi_remote_node_id,
    remote_vuln_id=pi_remote_vuln_id,
    cred_id=pi_cred_id,
)
133
134