Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
microsoft
GitHub Repository: microsoft/CyberBattleSim
Path: blob/main/scripts/run.py
597 views
1
# Copyright (c) Microsoft Corporation.
2
# Licensed under the MIT License.
3
4
"""A sample run of the CyberBattle simulation"""
5
6
from typing import cast
7
import gymnasium as gym
8
import logging
9
import sys
10
from cyberbattle._env.cyberbattle_env import CyberBattleEnv
11
12
13
def main() -> int:
    """Entry point if called as an executable.

    Configures the root logger to write DEBUG-level records to stdout,
    creates the ``CyberBattleToyCtf-v0`` gym environment and plays a single
    episode of at most 500 steps, printing every chosen action and the
    running total reward.

    Returns:
        0 once the episode loop has completed and the environment is closed.
    """
    # Route every log record (DEBUG and above) to stdout.
    root = logging.getLogger()
    root.setLevel(logging.DEBUG)

    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    handler.setFormatter(formatter)
    root.addHandler(handler)

    env = cast(CyberBattleEnv, gym.make("CyberBattleToyCtf-v0"))

    # Debugging aids:
    # logging.info(env.action_space.sample())
    # logging.info(env.observation_space.sample())

    try:
        for _episode in range(1):
            # The initial observation is not used: the first mask is computed
            # directly from the environment, later ones come from `step`.
            _initial_observation, _ = env.reset()
            action_mask = env.compute_action_mask()
            total_reward = 0
            for t in range(500):
                # env.render()

                # Keep sampling until the action is accepted by the current mask.
                action = env.sample_valid_action()
                while not env.apply_mask(action, action_mask):
                    action = env.sample_valid_action()

                print(f"action: {action}")
                observation, reward, done, truncated, _info = env.step(action)
                action_mask = observation["action_mask"]
                total_reward += reward
                # print(observation)
                print(f"total_reward={total_reward}")
                if truncated:
                    print(f"Episode truncated after {t + 1} timesteps")
                    break
                if done:
                    print(f"Episode finished after {t + 1} timesteps")
                    break
    finally:
        # Release the environment even if sampling or stepping raised.
        env.close()

    return 0
57
58
59
if __name__ == "__main__":
    # Propagate main()'s integer status as the process exit code.
    sys.exit(main())
61
62