Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
microsoft
GitHub Repository: microsoft/CyberBattleSim
Path: blob/main/cyberbattle/_env/cyberbattle_env_test.py
597 views
1
# Copyright (c) Microsoft Corporation.
2
# Licensed under the MIT License.
3
4
"""Test the CyberBattle Gym environment"""
5
6
from cyberbattle._env.option_wrapper import ContextWrapper, random_options
7
from cyberbattle._env.cyberbattle_env import AttackerGoal, CyberBattleEnv
8
import pytest
9
import gymnasium as gym
10
import numpy as np
11
from typing import cast
12
13
14
def test_few_gym_iterations() -> None:
    """Run a few short episodes of the CyberBattleToyCtf-v0 environment.

    Each of two episodes is reset, its computed action mask is asserted to be
    truthy, and then it is driven for at most 12 steps with randomly sampled
    valid actions. An episode stops early when it is truncated or done. The
    environment is closed even if an assertion or a step raises.
    """
    env = cast(CyberBattleEnv, gym.make("CyberBattleToyCtf-v0"))
    try:
        for _ in range(2):
            env.reset()
            action_mask = env.compute_action_mask()
            assert action_mask

            for t in range(12):
                # sample a valid action
                action = env.sample_valid_action()

                _, _, done, truncated, _ = env.step(action)
                if truncated:
                    print("Episode truncated after {} timesteps".format(t + 1))
                    break

                if done:
                    print("Episode finished after {} timesteps".format(t + 1))
                    break
    finally:
        # Release environment resources whether or not the test passed.
        env.close()
39
40
41
def test_step_after_done() -> None:
    """Stepping the environment after its episode has ended must raise.

    Replays a fixed action sequence on a CyberBattleChain of size 10 whose
    attacker goal is ``own_atleast_percent=1.0``. The inline comments record
    the expected outcome: every step is done=False except the second-to-last
    one, which is done=True with r=5000. The final action is then sent
    without an intervening ``env.reset()`` and must raise RuntimeError.
    """
    actions = [
        {"local_vulnerability": np.array([0, 1])},  # done=False r=9.0
        {"remote_vulnerability": np.array([0, 1, 0])},  # done=False r=4.0
        {"connect": np.array([0, 1, 2, 0])},  # done=False r=100.0
        {"local_vulnerability": np.array([1, 3])},  # done=False r=9.0
        {"connect": np.array([0, 2, 3, 1])},  # done=False r=100.0
        {"remote_vulnerability": np.array([1, 2, 1])},  # done=False r=6.0
        {"remote_vulnerability": np.array([1, 2, 0])},  # done=False r=6.0
        {"remote_vulnerability": np.array([2, 1, 1])},  # done=False r=2.0
        {"local_vulnerability": np.array([1, 0])},  # done=False r=6.0
        {"local_vulnerability": np.array([1, 1])},  # done=False r=0.0
        {"local_vulnerability": np.array([2, 1])},  # done=False r=6.0
        {"remote_vulnerability": np.array([2, 3, 0])},  # done=False r=4.0
        {"local_vulnerability": np.array([2, 4])},  # done=False r=9.0
        {"connect": np.array([0, 3, 2, 2])},  # done=False r=100.0
        {"local_vulnerability": np.array([3, 3])},  # done=False r=9.0
        {"local_vulnerability": np.array([3, 0])},  # done=False r=6.0
        {"remote_vulnerability": np.array([0, 4, 1])},  # done=False r=8.0
        {"local_vulnerability": np.array([3, 1])},  # done=False r=0.0
        {"connect": np.array([2, 4, 3, 3])},  # done=False r=100.0
        {"remote_vulnerability": np.array([1, 3, 1])},  # done=False r=2.0
        {"remote_vulnerability": np.array([1, 4, 0])},  # done=False r=6.0
        {"local_vulnerability": np.array([4, 1])},  # done=False r=6.0
        {"remote_vulnerability": np.array([0, 5, 0])},  # done=False r=4.0
        {"local_vulnerability": np.array([4, 4])},  # done=False r=9.0
        {"connect": np.array([3, 5, 2, 4])},  # done=False r=100.0
        {"remote_vulnerability": np.array([2, 5, 1])},  # done=False r=2.0
        {"local_vulnerability": np.array([5, 3])},  # done=False r=9.0
        {"connect": np.array([2, 6, 3, 5])},  # done=False r=100.0
        {"remote_vulnerability": np.array([4, 6, 1])},  # done=False r=6.0
        {"local_vulnerability": np.array([5, 0])},  # done=False r=6.0
        {"remote_vulnerability": np.array([4, 6, 0])},  # done=False r=6.0
        {"local_vulnerability": np.array([5, 1])},  # done=False r=0.0
        {"local_vulnerability": np.array([6, 1])},  # done=False r=6.0
        {"remote_vulnerability": np.array([6, 7, 0])},  # done=False r=4.0
        {"remote_vulnerability": np.array([0, 7, 1])},  # done=False r=2.0
        {"local_vulnerability": np.array([6, 4])},  # done=False r=9.0
        {"connect": np.array([4, 7, 2, 6])},  # done=False r=100.0
        {"local_vulnerability": np.array([7, 3])},  # done=False r=9.0
        {"connect": np.array([0, 8, 3, 7])},  # done=False r=100.0
        {"remote_vulnerability": np.array([0, 8, 0])},  # done=False r=6.0
        {"local_vulnerability": np.array([7, 0])},  # done=False r=6.0
        {"local_vulnerability": np.array([8, 4])},  # done=False r=9.0
        {"remote_vulnerability": np.array([3, 9, 1])},  # done=False r=2.0
        {"connect": np.array([3, 9, 2, 8])},  # done=False r=100.0
        {"remote_vulnerability": np.array([4, 9, 0])},  # done=False r=2.0
        {"local_vulnerability": np.array([9, 0])},  # done=False r=6.0
        {"remote_vulnerability": np.array([3, 8, 1])},  # done=False r=6.0
        {"remote_vulnerability": np.array([6, 10, 0])},  # done=False r=6.0
        {"local_vulnerability": np.array([9, 1])},  # done=False r=0.0
        {"local_vulnerability": np.array([9, 3])},  # done=False r=9.0
        {"remote_vulnerability": np.array([8, 10, 1])},  # done=False r=8.0
        {"local_vulnerability": np.array([7, 1])},  # done=False r=0.0
        {"connect": np.array([8, 10, 3, 9])},  # done=False r=100.0
        {"local_vulnerability": np.array([10, 4])},  # done=False r=9.0
        {"local_vulnerability": np.array([8, 1])},  # done=False r=6.0
        {"connect": np.array([7, 11, 2, 10])},  # done=True r=5000.0
        # this is one too many (after done)
        {"connect": np.array([10, 5, 2, 4])},
    ]

    env = gym.make(
        "CyberBattleChain-v0",
        size=10,
        attacker_goal=AttackerGoal(own_atleast_percent=1.0),
    )
    try:
        env.reset()
        done = False
        for a in actions[:-1]:
            _, reward, done, truncated, _ = env.step(a)
            print(f"{a}, # done={done} truncated={truncated} r={reward}")

        # Fail with an explicit message if the scripted run did not end the
        # episode, instead of only surfacing it through the pytest.raises below.
        assert done, "scripted actions were expected to end the episode"

        with pytest.raises(RuntimeError, match=r"new episode must be started with env\.reset\(\)"):
            env.step(actions[-1])
    finally:
        # Release environment resources whether or not the test passed.
        env.close()
115
116
117
def test_option_wrapper() -> None:
    """Smoke-test ContextWrapper with random_options on CyberBattleChain-v0.

    Wraps a size-10 chain environment, resets it, and takes up to four steps
    by calling ``step()`` with no action. Presumably the wrapper selects the
    action from its options; verify this against ContextWrapper. Any positive
    reward is printed with ``info["action"]``, and the loop stops early once
    the episode is done.
    """
    base_env = gym.make("CyberBattleChain-v0", size=10, attacker_goal=AttackerGoal(reward=4000))
    env = ContextWrapper(cast(CyberBattleEnv, base_env), options=random_options)
    try:
        # Unpacking keeps the implicit check that reset() returns a 2-tuple.
        _observation, _info = env.reset()
        for _ in range(4):
            _, r, done, _, info = env.step()
            if r > 0:
                print(r, done, info["action"])
            if done:
                break
    finally:
        # Close the underlying gym environment; it is known to provide close().
        base_env.close()
128
129