GitHub Repository: microsoft/CyberBattleSim
Path: blob/main/cyberbattle/agents/baseline/agent_dql.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# Function DeepQLearnerPolicy.optimize_model:
# Copyright (c) 2017, Pytorch contributors
# All rights reserved.
# https://github.com/pytorch/tutorials/blob/master/LICENSE

"""Deep Q-learning agent applied to chain network (notebook)

This notebook can be run directly from VSCode. To generate a traditional
Jupyter notebook that you can open in your browser, run the VSCode command
`Export Current Python File As Jupyter Notebook`.

Requirements:
    Nvidia CUDA drivers for WSL2: https://docs.nvidia.com/cuda/wsl-user-guide/index.html
    PyTorch
"""

# pylint: disable=invalid-name

# %% [markdown]
# # Chain network CyberBattle Gym played by a Deep Q-learning agent

# %%
from numpy import ndarray
from cyberbattle._env import cyberbattle_env
import numpy as np
from typing import List, NamedTuple, Optional, Tuple, Union
import random

# deep learning packages
from torch import Tensor
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
import torch
import torch.cuda
from torch.nn.utils.clip_grad import clip_grad_norm_

from .learner import Learner
from .agent_wrapper import EnvironmentBounds
import cyberbattle.agents.baseline.agent_wrapper as w
from .agent_randomcredlookup import CredentialCacheExploiter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class CyberBattleStateActionModel:
    """Define an abstraction of the state and action space
    for a CyberBattle environment, to be used to train a Q-function.
    """

    def __init__(self, ep: EnvironmentBounds):
        self.ep = ep

        self.global_features = w.ConcatFeatures(
            ep,
            [
                # w.Feature_discovered_node_count(ep),
                # w.Feature_owned_node_count(ep),
                w.Feature_discovered_notowned_node_count(ep, None)
                # w.Feature_discovered_ports(ep),
                # w.Feature_discovered_ports_counts(ep),
                # w.Feature_discovered_ports_sliding(ep),
                # w.Feature_discovered_credential_count(ep),
                # w.Feature_discovered_nodeproperties_sliding(ep),
            ],
        )

        self.node_specific_features = w.ConcatFeatures(
            ep,
            [
                # w.Feature_actions_tried_at_node(ep),
                w.Feature_success_actions_at_node(ep),
                w.Feature_failed_actions_at_node(ep),
                w.Feature_active_node_properties(ep),
                w.Feature_active_node_age(ep),
                # w.Feature_active_node_id(ep)
            ],
        )

        self.state_space = w.ConcatFeatures(
            ep,
            self.global_features.feature_selection + self.node_specific_features.feature_selection,
        )

        self.action_space = w.AbstractAction(ep)

    def get_state_astensor(self, state: w.StateAugmentation):
        state_vector = self.state_space.get(state, node=None)
        state_vector_float = np.array(state_vector, dtype=np.float32)
        state_tensor = torch.from_numpy(state_vector_float).unsqueeze(0)
        return state_tensor

    def implement_action(
        self,
        wrapped_env: w.AgentWrapper,
        actor_features: ndarray,
        abstract_action: np.int32,
    ) -> Tuple[str, Optional[cyberbattle_env.Action], Optional[int]]:
        """Specialize an abstract model action into a CyberBattle gym action.

        actor_features -- the desired features of the actor to use (source CyberBattle node)
        abstract_action -- the desired type of attack (connect, local, remote).

        Returns a gym action implementing the desired attack from a node with the desired feature encoding.
        """

        observation = wrapped_env.state.observation

        # Pick a source node at random (owned and with the desired feature encoding)
        potential_source_nodes = [from_node for from_node in w.owned_nodes(observation) if np.all(actor_features == self.node_specific_features.get(wrapped_env.state, from_node))]

        if len(potential_source_nodes) > 0:
            source_node = np.random.choice(potential_source_nodes)

            gym_action = self.action_space.specialize_to_gymaction(source_node, observation, np.int32(abstract_action))

            if not gym_action:
                return "exploit[undefined]->explore", None, None

            elif wrapped_env.env.is_action_valid(gym_action, observation["action_mask"]):
                return "exploit", gym_action, source_node
            else:
                return "exploit[invalid]->explore", None, None
        else:
            return "exploit[no_actor]->explore", None, None


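# Note: the DQN defined further below derives its layer dimensions from this
# abstraction: its input size is the number of state-space dimensions
# (len(model.state_space.dim_sizes)) and its output size is the number of
# abstract actions (model.action_space.flat_size()).
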
# %%

# Deep Q-learning


class Transition(NamedTuple):
    """One taken transition and its outcome"""

    state: Union[Tuple[Tensor], List[Tensor]]
    action: Union[Tuple[Tensor], List[Tensor]]
    next_state: Union[Tuple[Tensor], List[Tensor]]
    reward: Union[Tuple[Tensor], List[Tensor]]


class ReplayMemory(object):
    """Transition replay memory"""

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)


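# Illustrative sketch (not used anywhere in this module): how the replay memory
# is exercised by the learner below. The tensor shapes here are hypothetical;
# in the agent they match the actor state vector and the abstract action index.
def _example_replay_memory_usage():
    memory = ReplayMemory(capacity=1000)
    state = torch.zeros(1, 4)  # hypothetical 4-dimensional actor state
    next_state = torch.ones(1, 4)
    action = torch.tensor([[2]], dtype=torch.long)  # abstract action index
    reward = torch.tensor([1.0])
    memory.push(state, action, next_state, reward)
    # convert a batch of Transitions into a Transition of batches, as in optimize_model
    batch = Transition(*zip(*memory.sample(1)))
    return torch.cat(batch.state), torch.cat(batch.action), torch.cat(batch.reward)

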
class DQN(nn.Module):
    """The Deep Neural Network used to estimate the Q function"""

    def __init__(self, ep: EnvironmentBounds):
        super(DQN, self).__init__()

        model = CyberBattleStateActionModel(ep)
        linear_input_size = len(model.state_space.dim_sizes)
        output_size = model.action_space.flat_size()

        self.hidden_layer1 = nn.Linear(linear_input_size, 1024)
        # self.bn1 = nn.BatchNorm1d(256)
        self.hidden_layer2 = nn.Linear(1024, 512)
        self.hidden_layer3 = nn.Linear(512, 128)
        # self.hidden_layer4 = nn.Linear(128, 64)
        self.head = nn.Linear(128, output_size)

    # Called with either a single state to determine the next action, or a batch
    # of states during optimization. Returns a batch of Q-value vectors with one
    # entry per abstract action.
    def forward(self, x):
        x = F.relu(self.hidden_layer1(x))
        # x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.hidden_layer2(x))
        # x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.hidden_layer3(x))
        # x = F.relu(self.hidden_layer4(x))
        return self.head(x.view(x.size(0), -1))


def random_argmax(array):
    """Just like `argmax` but if there are multiple elements with the max
    return a random index to break ties instead of returning the first one."""
    max_value = np.max(array)
    max_index = np.where(array == max_value)[0]

    if max_index.shape[0] > 1:
        max_index = int(np.random.choice(max_index, size=1))
    else:
        max_index = int(max_index)

    return max_value, max_index


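# Illustrative example (hypothetical values): for array=[1.0, 3.0, 3.0],
# np.argmax always returns index 1, whereas random_argmax returns
# (3.0, 1) or (3.0, 2), choosing among the tied indices uniformly at random.

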
class ChosenActionMetadata(NamedTuple):
    """Additional info about the action chosen by the DQN-induced policy"""

    abstract_action: np.int32
    actor_node: int
    actor_features: ndarray
    actor_state: ndarray

    def __repr__(self) -> str:
        return f"[abstract_action={self.abstract_action}, actor={self.actor_node}, state={self.actor_state}]"


class DeepQLearnerPolicy(Learner):
    """Deep Q-Learning on CyberBattle environments

    Parameters
    ==========
    ep -- global parameters (bounds) of the environment
    gamma -- Q discount factor
    replay_memory_size -- size of the replay memory
    batch_size -- Deep Q-learning batch size
    target_update -- target network update frequency (in number of episodes)
    learning_rate -- the learning rate

    Parameters from the DeepDoubleQ paper:
    - learning_rate = 0.00025
    - linear epsilon decay
    - gamma = 0.99

    PyTorch code adapted from the tutorial at
    https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html
    """

    def __init__(
        self,
        ep: EnvironmentBounds,
        gamma: float,
        replay_memory_size: int,
        target_update: int,
        batch_size: int,
        learning_rate: float,
    ):
        self.stateaction_model = CyberBattleStateActionModel(ep)
        self.batch_size = batch_size
        self.gamma = gamma
        self.learning_rate = learning_rate

        self.policy_net = DQN(ep).to(device)
        self.target_net = DQN(ep).to(device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.target_update = target_update

        self.optimizer = optim.RMSprop(self.policy_net.parameters(), lr=learning_rate)  # type: ignore
        self.memory = ReplayMemory(replay_memory_size)

        self.credcache_policy = CredentialCacheExploiter()

    def parameters_as_string(self):
        return f"γ={self.gamma}, lr={self.learning_rate}, replaymemory={self.memory.capacity},\n" f"batch={self.batch_size}, target_update={self.target_update}"

    def all_parameters_as_string(self) -> str:
        model = self.stateaction_model
        return (
            f"{self.parameters_as_string()}\n"
            f"dimension={model.state_space.flat_size()}x{model.action_space.flat_size()}, "
            f"Q={[f.name() for f in model.state_space.feature_selection]} "
            f"-> 'abstract_action'"
        )

    def optimize_model(self, norm_clipping=False):
        if len(self.memory) < self.batch_size:
            return

        transitions = self.memory.sample(self.batch_size)
        # converts batch-array of Transitions to Transition of batch-arrays.
        batch = Transition(*zip(*transitions))

        # Compute a mask of non-final states and concatenate the batch elements
        # (a final state would've been the one after which simulation ended)
        non_final_mask = torch.tensor(
            tuple(map((lambda s: s is not None), batch.next_state)),
            device=device,
            dtype=torch.bool,
        )
        non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)

        # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
        # columns of actions taken. These are the actions which would've been taken
        # for each batch state according to policy_net
        # print(f'state_batch={state_batch.shape} input={len(self.stateaction_model.state_space.dim_sizes)}')
        output = self.policy_net(state_batch)

        # print(f'output={output.shape} batch.action={transitions[0].action.shape} action_batch={action_batch.shape}')
        state_action_values = output.gather(1, action_batch)

        # Compute V(s_{t+1}) for all next states.
        # Expected values of actions for non_final_next_states are computed based
        # on the "older" target_net; selecting their best reward with max(1)[0].
        # This is merged based on the mask, such that we'll have either the expected
        # state value or 0 in case the state was final.
        next_state_values = torch.zeros(self.batch_size, device=device)
        next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1)[0].detach()
        # Compute the expected Q values
        expected_state_action_values = (next_state_values * self.gamma) + reward_batch

        # Compute Huber loss
        loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))

        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()

        # Gradient clipping
        if norm_clipping:
            clip_grad_norm_(self.policy_net.parameters(), 1.0)
        else:
            for param in self.policy_net.parameters():
                if param.grad is not None:
                    param.grad.data.clamp_(-1, 1)
        self.optimizer.step()

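    # In summary, optimize_model fits the policy network to the Q-learning target
    #   y = r + gamma * max_a' Q_target(s', a')   (y = r for terminal transitions)
    # using the Huber (smooth L1) loss and either norm- or value-based gradient clipping.
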
    def get_actor_state_vector(self, global_state: ndarray, actor_features: ndarray) -> ndarray:
        return np.concatenate(
            (
                np.array(global_state, dtype=np.float32),
                np.array(actor_features, dtype=np.float32),
            )
        )

    def update_q_function(
        self,
        reward: float,
        actor_state: ndarray,
        abstract_action: np.int32,
        next_actor_state: Optional[ndarray],
    ):
        # store the transition in memory
        reward_tensor = torch.tensor([reward], device=device, dtype=torch.float)
        action_tensor = torch.tensor([[np.int_(abstract_action)]], device=device, dtype=torch.long)
        current_state_tensor = torch.as_tensor(actor_state, dtype=torch.float, device=device).unsqueeze(0)
        if next_actor_state is None:
            next_state_tensor = None
        else:
            next_state_tensor = torch.as_tensor(next_actor_state, dtype=torch.float, device=device).unsqueeze(0)
        self.memory.push(current_state_tensor, action_tensor, next_state_tensor, reward_tensor)

        # optimize the policy network
        self.optimize_model()

    def on_step(
        self,
        wrapped_env: w.AgentWrapper,
        observation,
        reward: float,
        done: bool,
        truncated: bool,
        info,
        action_metadata,
    ):
        agent_state = wrapped_env.state
        if done:
            self.update_q_function(
                reward,
                actor_state=action_metadata.actor_state,
                abstract_action=action_metadata.abstract_action,
                next_actor_state=None,
            )
        else:
            next_global_state = self.stateaction_model.global_features.get(agent_state, node=None)
            next_actor_features = self.stateaction_model.node_specific_features.get(agent_state, action_metadata.actor_node)
            next_actor_state = self.get_actor_state_vector(next_global_state, next_actor_features)

            self.update_q_function(
                reward,
                actor_state=action_metadata.actor_state,
                abstract_action=action_metadata.abstract_action,
                next_actor_state=next_actor_state,
            )

    def end_of_episode(self, i_episode, t):
        # Update the target network, copying all weights and biases in DQN
        if i_episode % self.target_update == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

    def lookup_dqn(self, states_to_consider: List[ndarray]) -> Tuple[List[np.int32], List[np.int32]]:
        """Given a list of candidate states, return two parallel lists:
        - for each state, the best action to take according to the policy network
        - the expected Q-value of taking that action in that state"""
        with torch.no_grad():
            # t.max(1) returns the largest column value of each row.
            # The second element of the max result is the index at which the max
            # element was found, i.e. the action with the larger expected reward.
            # action: np.int32 = self.policy_net(states_to_consider).max(1)[1].view(1, 1).item()

            state_batch = torch.tensor(states_to_consider).to(device)
            dnn_output = self.policy_net(state_batch).max(1)
            action_lookups = dnn_output[1].tolist()
            expectedq_lookups = dnn_output[0].tolist()

        return action_lookups, expectedq_lookups

    def metadata_from_gymaction(self, wrapped_env, gym_action):
        current_global_state = self.stateaction_model.global_features.get(wrapped_env.state, node=None)
        actor_node = cyberbattle_env.sourcenode_of_action(gym_action)
        actor_features = self.stateaction_model.node_specific_features.get(wrapped_env.state, actor_node)
        abstract_action = self.stateaction_model.action_space.abstract_from_gymaction(gym_action)
        return ChosenActionMetadata(
            abstract_action=abstract_action,
            actor_node=actor_node,
            actor_features=actor_features,
            actor_state=self.get_actor_state_vector(current_global_state, actor_features),
        )

    def explore(self, wrapped_env: w.AgentWrapper) -> Tuple[str, cyberbattle_env.Action, object]:
        """Random exploration that avoids repeating actions previously taken in the same state"""
        # sample local and remote actions only (excludes connect action)
        gym_action = wrapped_env.env.sample_valid_action(kinds=[0, 1, 2])
        metadata = self.metadata_from_gymaction(wrapped_env, gym_action)
        return "explore", gym_action, metadata

    def try_exploit_at_candidate_actor_states(self, wrapped_env, current_global_state, actor_features, abstract_action):
        actor_state = self.get_actor_state_vector(current_global_state, actor_features)

        action_style, gym_action, actor_node = self.stateaction_model.implement_action(wrapped_env, actor_features, abstract_action)

        if gym_action:
            assert actor_node is not None, "actor_node should be set together with gym_action"

            return (
                action_style,
                gym_action,
                ChosenActionMetadata(
                    abstract_action=abstract_action,
                    actor_node=actor_node,
                    actor_features=actor_features,
                    actor_state=actor_state,
                ),
            )
        else:
            # learn the failed exploit attempt in the current state
            self.update_q_function(
                reward=0.0,
                actor_state=actor_state,
                next_actor_state=actor_state,
                abstract_action=abstract_action,
            )

            return "exploit[undefined]->explore", None, None

    def exploit(self, wrapped_env, observation) -> Tuple[str, Optional[cyberbattle_env.Action], object]:
        # first, attempt to exploit the credential cache
        # using the credcache_policy
        # action_style, gym_action, _ = self.credcache_policy.exploit(wrapped_env, observation)
        # if gym_action:
        #     return action_style, gym_action, self.metadata_from_gymaction(wrapped_env, gym_action)

        # Otherwise, exploit the learnt Q-function

        current_global_state = self.stateaction_model.global_features.get(wrapped_env.state, node=None)

        # Gather the features of all the current active actors (i.e. owned nodes)
        active_actors_features: List[ndarray] = [self.stateaction_model.node_specific_features.get(wrapped_env.state, from_node) for from_node in w.owned_nodes(observation)]

        unique_active_actors_features: List[ndarray] = list(np.unique(active_actors_features, axis=0))

        # array of actor state vectors for every possible set of node features
        candidate_actor_state_vector: List[ndarray] = [self.get_actor_state_vector(current_global_state, node_features) for node_features in unique_active_actors_features]

        remaining_action_lookups, remaining_expectedq_lookups = self.lookup_dqn(candidate_actor_state_vector)
        remaining_candidate_indices = list(range(len(candidate_actor_state_vector)))

        # Try candidate actors in decreasing order of expected Q-value until one
        # of them yields a valid gym action.
        while remaining_candidate_indices:
            _, remaining_candidate_index = random_argmax(remaining_expectedq_lookups)
            actor_index = remaining_candidate_indices[remaining_candidate_index]
            abstract_action = remaining_action_lookups[remaining_candidate_index]

            actor_features = unique_active_actors_features[actor_index]

            action_style, gym_action, metadata = self.try_exploit_at_candidate_actor_states(wrapped_env, current_global_state, actor_features, abstract_action)

            if gym_action:
                return action_style, gym_action, metadata

            remaining_candidate_indices.pop(remaining_candidate_index)
            remaining_expectedq_lookups.pop(remaining_candidate_index)
            remaining_action_lookups.pop(remaining_candidate_index)

        return "exploit[undefined]->explore", None, None

    def stateaction_as_string(self, action_metadata) -> str:
        return ""
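
# %%
# Illustrative usage sketch (commented out; not executed as part of this module).
# The `EnvironmentBounds.of_identifiers` helper and the hyper-parameter values
# below are assumptions for illustration; see the baseline training notebooks
# in this repository for the actual driver code.
#
#   ep = EnvironmentBounds.of_identifiers(
#       maximum_node_count=12,
#       maximum_total_credentials=10,
#       identifiers=gym_env.identifiers,
#   )
#   dqn_policy = DeepQLearnerPolicy(
#       ep=ep,
#       gamma=0.985,
#       replay_memory_size=10000,
#       target_update=10,
#       batch_size=512,
#       learning_rate=0.01,
#   )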