Path: blob/main/notebooks/dql_active_directory.py
# ---
# jupyter:
#   jupytext:
#     formats: py:percent,ipynb
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.16.4
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---

# %% [markdown]
# # DQL agent running on the Active Directory sample environment

# %%
import logging
import sys
from typing import cast

import gymnasium as gym

import cyberbattle.agents.baseline.learner as learner
import cyberbattle.agents.baseline.agent_wrapper as w
import cyberbattle.agents.baseline.agent_dql as dqla
from cyberbattle._env.cyberbattle_env import CyberBattleEnv
from cyberbattle.agents.baseline.agent_wrapper import ActionTrackingStateAugmentation, AgentWrapper, Verbosity

logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")
# %matplotlib inline

# %% tags=["parameters"]
ngyms = 9
iteration_count = 1000

# %%
gymids = [f"ActiveDirectory-v{i}" for i in range(ngyms)]

# %%
envs = [cast(CyberBattleEnv, gym.make(gymid).unwrapped) for gymid in gymids]
# `map` is lazy, so iterate explicitly to make sure every environment is actually reset
for env in envs:
    env.reset(seed=1)
ep = w.EnvironmentBounds.of_identifiers(maximum_node_count=30, maximum_total_credentials=50, identifiers=envs[0].identifiers)

# %%
# Train the Deep Q-learning agent on each environment in turn,
# transferring the learned policy from one environment to the next
_l = dqla.DeepQLearnerPolicy(
    ep=ep,
    gamma=0.015,
    replay_memory_size=10000,
    target_update=5,
    batch_size=512,
    learning_rate=0.01,  # torch default learning rate is 1e-2
)
for i, env in enumerate(envs):
    # anneal the starting exploration rate from 1.0 down to 0.2 across the envs
    epsilon = (10 - i) / 10
    # shrink the episode budget from 10 for the first env down to 2 for the last
    # (the schedule is printed by the sketch cell at the end of this notebook)
    training_episode_count = 1 + (9 - i)
    dqn_learning_run = learner.epsilon_greedy_search(
        cyberbattle_gym_env=env,
        environment_properties=ep,
        learner=_l,
        episode_count=training_episode_count,
        iteration_count=iteration_count,
        epsilon=epsilon,
        epsilon_exponential_decay=50000,
        epsilon_minimum=0.1,
        verbosity=Verbosity.Quiet,
        render=False,
        plot_episodes_length=False,
        title=f"DQL {i}",
    )
    # carry the trained policy over to the next environment
    _l = dqn_learning_run["learner"]

# %%
# Evaluate the trained agent on a held-out environment it was not trained on
tiny = cast(CyberBattleEnv, gym.make(f"ActiveDirectory-v{ngyms}").unwrapped)
# a single seeded reset, keeping the resulting observation
current_o, _ = tiny.reset(seed=1)
wrapped_env = AgentWrapper(tiny, ActionTrackingStateAugmentation(ep, current_o))
# Use the trained agent to run the steps one by one
max_steps = 1000
# h = []
for i in range(max_steps):
    # next action suggested by the DQL agent
    _, next_action, _ = _l.exploit(wrapped_env, current_o)
    # h.append((tiny.get_explored_network_node_properties_bitmap_as_numpy(current_o), next_action))
    if next_action is None:
        print("No more learned moves")
        break
    # run the suggested action
    current_o, _, is_done, _, _ = wrapped_env.step(next_action)
    if is_done:
        print("Finished simulation")
        break
tiny.render()
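# %% [markdown]
# The training loop above anneals two quantities at once: the starting
# exploration rate `epsilon` and the per-environment episode budget
# `training_episode_count`. The next cell is a small illustrative sketch
# (not part of training) that prints the resulting schedule, reusing only
# `ngyms` from the parameters cell.

# %%
for i in range(ngyms):
    # same arithmetic as the training loop: epsilon 1.0 -> 0.2, episodes 10 -> 2
    print(f"env {i}: start epsilon = {(10 - i) / 10:.1f}, training episodes = {1 + (9 - i)}")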
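# %% [markdown]
# Within each call to `learner.epsilon_greedy_search`, `epsilon` also decays
# towards `epsilon_minimum` as iterations accumulate. The exact schedule is
# internal to the learner; the cell below sketches the common DQN-style
# exponential form (an assumption, not the library's documented behavior)
# to give a feel for what `epsilon_exponential_decay=50000` means.

# %%
import math


def epsilon_at(step: int, start: float = 1.0, minimum: float = 0.1, decay: float = 50000.0) -> float:
    """Assumed DQN-style schedule: decay from `start` towards `minimum` with time constant `decay`."""
    return minimum + (start - minimum) * math.exp(-step / decay)


for step in (0, 10_000, 50_000, 100_000, 250_000):
    print(f"step {step:>7}: epsilon ~ {epsilon_at(step):.3f}")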
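# %% [markdown]
# As a variation on the step-by-step run above, the sketch below replays the
# trained policy from a fresh seeded reset and accumulates the reward, using
# only calls already seen in this notebook (`reset`, `AgentWrapper`,
# `exploit`, `step`). The second element of the `step` tuple is the step
# reward, per the gymnasium step signature.

# %%
obs, _ = tiny.reset(seed=1)
replay_env = AgentWrapper(tiny, ActionTrackingStateAugmentation(ep, obs))
total_reward = 0.0
steps_taken = 0
for _ in range(max_steps):
    # ask the trained agent for its next exploit action
    _, action, _ = _l.exploit(replay_env, obs)
    if action is None:
        break  # the agent has no further learned move to suggest
    obs, reward, done, _, _ = replay_env.step(action)
    total_reward += reward
    steps_taken += 1
    if done:
        break
print(f"replay ended after {steps_taken} steps with cumulative reward {total_reward}")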