Path: blob/main/notebooks/notebook_dql_debug_tiny.py
# ---
# jupyter:
#   jupytext:
#     cell_metadata_filter: tags,title,-all
#     cell_metadata_json: true
#     formats: py:percent,ipynb
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#     jupytext_version: 1.16.4
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---

# %% {"tags": []}
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Notebook used for debugging purposes: trains the DQL agent
and then runs it one step at a time.
"""

# pylint: disable=invalid-name
# %matplotlib inline

# %% {"tags": []}
import sys
import logging
from typing import cast

import gymnasium as gym

import cyberbattle.agents.baseline.learner as learner
import cyberbattle.agents.baseline.agent_wrapper as w
import cyberbattle.agents.baseline.agent_dql as dqla
from cyberbattle.agents.baseline.agent_wrapper import ActionTrackingStateAugmentation, AgentWrapper, Verbosity
from cyberbattle._env.cyberbattle_env import CyberBattleEnv

logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")

# %% {"tags": ["parameters"]}
gymid = "CyberBattleTiny-v0"
iteration_count = 150
training_episode_count = 10

# %% {"tags": []}
# Load the gym environment
_gym_env = gym.make(gymid)

# gym.make returns a generic Env; cast to the concrete CyberBattleEnv type
# to access CyberBattle-specific members such as `identifiers`.
ctf_env = cast(CyberBattleEnv, _gym_env)

ep = w.EnvironmentBounds.of_identifiers(maximum_node_count=12, maximum_total_credentials=10, identifiers=ctf_env.identifiers)

# %% {"tags": []}
# Train the Deep Q-learning agent with an epsilon-greedy exploration schedule
dqn_learning_run = learner.epsilon_greedy_search(
    cyberbattle_gym_env=ctf_env,
    environment_properties=ep,
    learner=dqla.DeepQLearnerPolicy(
        ep=ep,
        gamma=0.015,
        replay_memory_size=10000,
        target_update=5,
        batch_size=512,
        learning_rate=0.01,  # torch default learning rate is 1e-2
    ),
    episode_count=training_episode_count,
    iteration_count=iteration_count,
    epsilon=0.90,
    epsilon_exponential_decay=5000,
    epsilon_minimum=0.10,
    verbosity=Verbosity.Quiet,
    render=False,
    plot_episodes_length=False,
    title="DQL",
)

# %% {"tags": []}
# Initialize the environment and wrap it so the agent's state is augmented
# with the history of actions taken so far

current_o, _ = ctf_env.reset()
wrapped_env = AgentWrapper(ctf_env, ActionTrackingStateAugmentation(ep, current_o))
_l = dqn_learning_run["learner"]

# %% {"tags": []}
# Use the trained agent to run the steps one by one

max_steps = 10

# history of (explored-network bitmap, action suggested by the DQL agent) pairs
h = []
for i in range(max_steps):
    # ask the trained policy for the next action to take
    _, next_action, _ = _l.exploit(wrapped_env, current_o)
    h.append((ctf_env.get_explored_network_node_properties_bitmap_as_numpy(current_o), next_action))
    print(h[-1])
    if next_action is None:
        # the policy has no action left to suggest
        break
    # execute the suggested action in the wrapped environment
    current_o, _, _, _, _ = wrapped_env.step(next_action)

print(f"len: {len(h)}")

# %% {"tags": []}
ctf_env.render()
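
# %% {"tags": []}
# Optional sanity check: a minimal sketch, assuming the dict returned by
# epsilon_greedy_search exposes per-episode reward lists under the
# "all_episodes_rewards" key (verify the key name in your version of
# cyberbattle.agents.baseline.learner). It plots the cumulative reward of
# each training episode to eyeball whether learning made progress.
import matplotlib.pyplot as plt

all_episodes_rewards = dqn_learning_run.get("all_episodes_rewards", [])
if all_episodes_rewards:
    plt.plot([sum(episode_rewards) for episode_rewards in all_episodes_rewards])
    plt.xlabel("episode")
    plt.ylabel("cumulative reward")
    plt.title("DQL training rewards")
    plt.show()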
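
# %% {"tags": []}
# Persisting the trained policy: a minimal sketch, assuming DeepQLearnerPolicy
# stores its Q-network in a `policy_net` torch module as in the baseline agent
# implementation; check the attribute name before relying on it.
import torch

torch.save(_l.policy_net.state_dict(), "dql_tiny_policy.pt")
print("saved policy network weights to dql_tiny_policy.pt")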
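
# %% {"tags": []}
# Baseline comparison: a minimal sketch contrasting the trained policy with
# uniformly sampled actions over the same step budget. It assumes CyberBattleEnv
# exposes sample_valid_action(); if your version does not, substitute
# ctf_env.action_space.sample(). Note this resets the environment, so run it
# after the render cell above.
random_o, _ = ctf_env.reset()
random_rewards = []
for _step in range(max_steps):
    random_action = ctf_env.sample_valid_action()
    random_o, reward, terminated, truncated, _info = ctf_env.step(random_action)
    random_rewards.append(float(reward))
    if terminated or truncated:
        break
print(f"random baseline: total reward {sum(random_rewards)} over {len(random_rewards)} steps")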