# ---
# jupyter:
#   jupytext:
#     cell_metadata_filter: tags,title,-all
#     cell_metadata_json: true
#     formats: py:percent,ipynb
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#     jupytext_version: 1.16.4
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---

# %% {"tags": []}
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Notebook used for debugging purposes: train the
DQL agent and then run it one step at a time.
"""

# pylint: disable=invalid-name
# %matplotlib inline

# %% {"tags": []}
import sys
import logging
from typing import cast

import gymnasium as gym

import cyberbattle.agents.baseline.learner as learner
import cyberbattle.agents.baseline.agent_wrapper as w
import cyberbattle.agents.baseline.agent_dql as dqla
from cyberbattle.agents.baseline.agent_wrapper import ActionTrackingStateAugmentation, AgentWrapper, Verbosity
from cyberbattle._env.cyberbattle_env import CyberBattleEnv

logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")

# %% {"tags": ["parameters"]}
# Parameters cell (papermill convention): these values can be overridden when
# the notebook is executed programmatically.
gymid = "CyberBattleTiny-v0"
iteration_count = 150
training_episode_count = 10

# %% {"tags": []}
# Load the gym environment
_gym_env = gym.make(gymid)

# Cast to CyberBattleEnv so type checkers see the CyberBattleSim-specific API
ctf_env = cast(CyberBattleEnv, _gym_env)

ep = w.EnvironmentBounds.of_identifiers(maximum_node_count=12, maximum_total_credentials=10, identifiers=ctf_env.identifiers)
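
# %% {"tags": []}
# Optional sanity check, not in the original notebook: EnvironmentBounds is a
# dataclass, so printing it shows the size caps the agent's feature encoding
# will assume. The counts below assume the environment's Identifiers named
# tuple exposes `ports` and `properties` lists, as in cyberbattle.simulation.model.
print(ep)
print(f"ports: {len(ctf_env.identifiers.ports)}, properties: {len(ctf_env.identifiers.properties)}")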
# %% {"tags": []}
59
# Evaluate the Deep Q-learning agent
60
dqn_learning_run = learner.epsilon_greedy_search(
61
cyberbattle_gym_env=ctf_env,
62
environment_properties=ep,
63
learner=dqla.DeepQLearnerPolicy(
64
ep=ep,
65
gamma=0.015,
66
replay_memory_size=10000,
67
target_update=5,
68
batch_size=512,
69
learning_rate=0.01, # torch default learning rate is 1e-2
70
),
71
episode_count=training_episode_count,
72
iteration_count=iteration_count,
73
epsilon=0.90,
74
epsilon_exponential_decay=5000,
75
epsilon_minimum=0.10,
76
verbosity=Verbosity.Quiet,
77
render=False,
78
plot_episodes_length=False,
79
title="DQL",
80
)
81
82
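
# %% {"tags": []}
# Optional, not in the original notebook: plot the cumulative reward obtained
# in each training episode. This assumes the dict returned by
# epsilon_greedy_search also carries per-step reward lists under
# 'all_episodes_rewards', as in learner.TrainedLearner at the time of writing.
import matplotlib.pyplot as plt

plt.plot([sum(rewards) for rewards in dqn_learning_run["all_episodes_rewards"]])
plt.xlabel("episode")
plt.ylabel("cumulative reward")
plt.title("DQL training reward per episode")
plt.show()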
# %% {"tags": []}
83
# initialize the environment
84
85
current_o, _ = ctf_env.reset()
86
wrapped_env = AgentWrapper(ctf_env, ActionTrackingStateAugmentation(ep, current_o))
87
_l = dqn_learning_run["learner"]
88
89
# %% {"tags": []}
90
# Use the trained agent to run the steps one by one
91
92
max_steps = 10
93
94
# next action suggested by DQL agent
95
h = []
96
for i in range(max_steps):
97
# run the suggested action
98
_, next_action, _ = _l.exploit(wrapped_env, current_o)
99
h.append((ctf_env.get_explored_network_node_properties_bitmap_as_numpy(current_o), next_action))
100
print(h[-1])
101
if next_action is None:
102
break
103
current_o, _, _, _, _ = wrapped_env.step(next_action)
104
105
print(f"len: {len(h)}")
106
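
# %% {"tags": []}
# Optional, not in the original notebook: replay just the suggested actions
# from the (bitmap, action) pairs collected in `h` above, which is easier to
# scan than the full tuples printed during the loop.
for _bitmap, action in h:
    print(action)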

# %% {"tags": []}
ctf_env.render()