# Path: blob/main/cyberbattle/_env/cyberbattle_env_test.py
# 597 views
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Test the CyberBattle Gym environment"""

from typing import cast

import gymnasium as gym
import numpy as np
import pytest

from cyberbattle._env.cyberbattle_env import AttackerGoal, CyberBattleEnv
from cyberbattle._env.option_wrapper import ContextWrapper, random_options


def test_few_gym_iterations() -> None:
    """Run a few short episodes of the toy CTF environment with random valid actions.

    Two episodes are played. Each is capped at 12 steps and stops early
    when the environment reports truncation or termination.
    """
    env = cast(CyberBattleEnv, gym.make("CyberBattleToyCtf-v0"))
    try:
        for _ in range(2):
            env.reset()
            # compute_action_mask() must return a truthy (non-empty) mask.
            action_mask = env.compute_action_mask()
            assert action_mask
            for t in range(12):
                # sample a valid action
                action = env.sample_valid_action()

                observation, reward, done, truncated, info = env.step(action)
                if truncated:
                    print(f"Episode truncated after {t + 1} timesteps")
                    break

                if done:
                    print(f"Episode finished after {t + 1} timesteps")
                    break
    finally:
        # Release the environment even if an assertion above fails.
        env.close()


def test_step_after_done() -> None:
    """Check that stepping a finished episode raises until ``reset()`` is called.

    Replays a scripted trajectory on ``CyberBattleChain-v0`` (``size=10``)
    whose last scripted action ends the episode. It then issues one extra
    action and expects a ``RuntimeError`` asking for ``env.reset()``.

    The trailing comment on each action records the ``done`` flag and reward
    observed when the trajectory was recorded. The comments match the format
    of the ``print`` below.
    """
    actions = [
        {"local_vulnerability": np.array([0, 1])},  # done=False r=9.0
        {"remote_vulnerability": np.array([0, 1, 0])},  # done=False r=4.0
        {"connect": np.array([0, 1, 2, 0])},  # done=False r=100.0
        {"local_vulnerability": np.array([1, 3])},  # done=False r=9.0
        {"connect": np.array([0, 2, 3, 1])},  # done=False r=100.0
        {"remote_vulnerability": np.array([1, 2, 1])},  # done=False r=6.0
        {"remote_vulnerability": np.array([1, 2, 0])},  # done=False r=6.0
        {"remote_vulnerability": np.array([2, 1, 1])},  # done=False r=2.0
        {"local_vulnerability": np.array([1, 0])},  # done=False r=6.0
        {"local_vulnerability": np.array([1, 1])},  # done=False r=0.0
        {"local_vulnerability": np.array([2, 1])},  # done=False r=6.0
        {"remote_vulnerability": np.array([2, 3, 0])},  # done=False r=4.0
        {"local_vulnerability": np.array([2, 4])},  # done=False r=9.0
        {"connect": np.array([0, 3, 2, 2])},  # done=False r=100.0
        {"local_vulnerability": np.array([3, 3])},  # done=False r=9.0
        {"local_vulnerability": np.array([3, 0])},  # done=False r=6.0
        {"remote_vulnerability": np.array([0, 4, 1])},  # done=False r=8.0
        {"local_vulnerability": np.array([3, 1])},  # done=False r=0.0
        {"connect": np.array([2, 4, 3, 3])},  # done=False r=100.0
        {"remote_vulnerability": np.array([1, 3, 1])},  # done=False r=2.0
        {"remote_vulnerability": np.array([1, 4, 0])},  # done=False r=6.0
        {"local_vulnerability": np.array([4, 1])},  # done=False r=6.0
        {"remote_vulnerability": np.array([0, 5, 0])},  # done=False r=4.0
        {"local_vulnerability": np.array([4, 4])},  # done=False r=9.0
        {"connect": np.array([3, 5, 2, 4])},  # done=False r=100.0
        {"remote_vulnerability": np.array([2, 5, 1])},  # done=False r=2.0
        {"local_vulnerability": np.array([5, 3])},  # done=False r=9.0
        {"connect": np.array([2, 6, 3, 5])},  # done=False r=100.0
        {"remote_vulnerability": np.array([4, 6, 1])},  # done=False r=6.0
        {"local_vulnerability": np.array([5, 0])},  # done=False r=6.0
        {"remote_vulnerability": np.array([4, 6, 0])},  # done=False r=6.0
        {"local_vulnerability": np.array([5, 1])},  # done=False r=0.0
        {"local_vulnerability": np.array([6, 1])},  # done=False r=6.0
        {"remote_vulnerability": np.array([6, 7, 0])},  # done=False r=4.0
        {"remote_vulnerability": np.array([0, 7, 1])},  # done=False r=2.0
        {"local_vulnerability": np.array([6, 4])},  # done=False r=9.0
        {"connect": np.array([4, 7, 2, 6])},  # done=False r=100.0
        {"local_vulnerability": np.array([7, 3])},  # done=False r=9.0
        {"connect": np.array([0, 8, 3, 7])},  # done=False r=100.0
        {"remote_vulnerability": np.array([0, 8, 0])},  # done=False r=6.0
        {"local_vulnerability": np.array([7, 0])},  # done=False r=6.0
        {"local_vulnerability": np.array([8, 4])},  # done=False r=9.0
        {"remote_vulnerability": np.array([3, 9, 1])},  # done=False r=2.0
        {"connect": np.array([3, 9, 2, 8])},  # done=False r=100.0
        {"remote_vulnerability": np.array([4, 9, 0])},  # done=False r=2.0
        {"local_vulnerability": np.array([9, 0])},  # done=False r=6.0
        {"remote_vulnerability": np.array([3, 8, 1])},  # done=False r=6.0
        {"remote_vulnerability": np.array([6, 10, 0])},  # done=False r=6.0
        {"local_vulnerability": np.array([9, 1])},  # done=False r=0.0
        {"local_vulnerability": np.array([9, 3])},  # done=False r=9.0
        {"remote_vulnerability": np.array([8, 10, 1])},  # done=False r=8.0
        {"local_vulnerability": np.array([7, 1])},  # done=False r=0.0
        {"connect": np.array([8, 10, 3, 9])},  # done=False r=100.0
        {"local_vulnerability": np.array([10, 4])},  # done=False r=9.0
        {"local_vulnerability": np.array([8, 1])},  # done=False r=6.0
        {"connect": np.array([7, 11, 2, 10])},  # done=True r=5000.0
        # this is one too many (after done)
        {"connect": np.array([10, 5, 2, 4])},
    ]

    env = gym.make(
        "CyberBattleChain-v0",
        size=10,
        # Presumably the episode ends only once the attacker owns every node;
        # TODO(review): confirm against AttackerGoal semantics.
        attacker_goal=AttackerGoal(own_atleast_percent=1.0),
    )
    try:
        env.reset()
        done = False
        for a in actions[:-1]:
            observation, reward, done, truncated, info = env.step(a)
            print(f"{a}, # done={done} truncated={truncated} r={reward}")

        # The rest of the test only makes sense if the scripted trajectory
        # really finished the episode. Otherwise pytest.raises would fail with
        # an unhelpful "DID NOT RAISE".
        assert done, "scripted trajectory did not end the episode"

        with pytest.raises(RuntimeError, match=r"new episode must be started with env\.reset\(\)"):
            env.step(actions[-1])
    finally:
        env.close()


def test_option_wrapper() -> None:
    """Smoke-test ``ContextWrapper`` driven by ``random_options``.

    Runs up to four argument-less ``step()`` calls on a size-10 chain
    environment. Steps with a positive reward are printed, and the loop
    stops early if the episode ends.
    """
    base_env = gym.make(
        "CyberBattleChain-v0",
        size=10,
        attacker_goal=AttackerGoal(reward=4000),
    )
    env = ContextWrapper(cast(CyberBattleEnv, base_env), options=random_options)
    try:
        env.reset()
        for _ in range(4):
            # The wrapper's step() takes no action argument. Presumably it
            # picks one from the options and reports it in info["action"].
            _, r, done, truncated, info = env.step()
            if r > 0:
                print(r, done, info["action"])
            if done or truncated:
                break
    finally:
        # Close the underlying environment directly so cleanup does not
        # depend on what ContextWrapper exposes.
        base_env.close()