Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
microsoft
GitHub Repository: microsoft/CyberBattleSim
Path: blob/main/cyberbattle/agents/baseline/agent_randomcredlookup.py
597 views
1
# Copyright (c) Microsoft Corporation.
2
# Licensed under the MIT License.
3
4
"""Random agent with credential lookup (notebook)"""
5
6
# pylint: disable=invalid-name
7
8
from .agent_wrapper import AgentWrapper
9
from .learner import Learner
10
from typing import Optional
11
import cyberbattle._env.cyberbattle_env as cyberbattle_env
12
import numpy as np
13
import logging
14
import cyberbattle.agents.baseline.agent_wrapper as w
15
16
17
def exploit_credentialcache(observation) -> Optional[cyberbattle_env.Action]:
    """Build a ``connect`` action from the credential cache, if one exists.

    A source node is drawn at random among the owned nodes. A credential is
    then drawn at random among the cached entries whose first column is a
    discovered node that is not owned yet; the entry's first column is used
    as the connection target and its second column as the port.

    Returns:
        ``{"connect": [source_node, target, port, credential_index]}`` as an
        ``int32`` array, or ``None`` when there is no owned node, the cache
        is empty, or no cached credential targets a not-yet-owned node.
    """
    owned = w.owned_nodes(observation)
    if len(owned) == 0:
        return None

    # The source node is drawn up front, before the cache is inspected.
    source_node = np.random.choice(owned)

    cache = np.array(observation["credential_cache_matrix"])
    cache_size = len(cache)
    if cache_size <= 0:
        # Empty credential cache: no valid connect action can be produced.
        return None

    unowned_targets = w.discovered_nodes_notowned(observation)

    # Indices of the cached credentials whose target node is not owned yet.
    usable_creds = []
    for cred_index in range(cache_size):
        if cache[cred_index, 0] in unowned_targets:
            usable_creds.append(cred_index)

    if not usable_creds:
        return None

    logging.debug("found matching cred in the credential cache")
    cred = np.int32(np.random.choice(usable_creds))
    target = np.int32(cache[cred, 0])
    port = np.int32(cache[cred, 1])
    return {"connect": np.array([source_node, target, port, cred], dtype=np.int32)}
46
47
48
class CredentialCacheExploiter(Learner):
    """Learner whose exploitation step relies only on the credential cache.

    Nothing is learned: every training hook is a no-op and every
    string-reporting method returns an empty string.
    """

    def parameters_as_string(self):
        """Return an empty string (this learner has no parameters to report)."""
        return ""

    def explore(self, wrapped_env: AgentWrapper):
        """Sample a valid action from the environment.

        NOTE(review): ``[0, 1]`` presumably restricts sampling to action
        kinds 0 and 1 -- confirm against ``sample_valid_action``.

        Returns:
            The tuple ``("explore", action, None)``.
        """
        sampled_action = wrapped_env.env.sample_valid_action([0, 1])
        return "explore", sampled_action, None

    def exploit(self, wrapped_env: AgentWrapper, observation):
        """Try to connect to a not-yet-owned node with a cached credential.

        Returns:
            ``("exploit", action, None)`` when a candidate action exists and
            the environment deems it valid under the observation's action
            mask. Otherwise the action slot is ``None`` and the label tells
            whether no candidate could be built (``exploit[undefined]``) or
            the candidate was rejected (``exploit[invalid]``); the
            ``->explore`` suffix suggests the caller falls back on random
            exploration.
        """
        gym_action = exploit_credentialcache(observation)
        if not gym_action:
            return "exploit[undefined]->explore", None, None

        if not wrapped_env.env.is_action_valid(gym_action, observation["action_mask"]):
            # Candidate rejected by the action mask: fall back on exploration.
            return "exploit[invalid]->explore", None, None

        return "exploit", gym_action, None

    def stateaction_as_string(self, action_metadata):
        """Return an empty string (no state/action description is kept)."""
        return ""

    def on_step(
        self,
        wrapped_env: AgentWrapper,
        observation,
        reward,
        done,
        truncated,
        info,
        action_metadata,
    ):
        """Do nothing on an environment step; returns ``None``."""

    def end_of_iteration(self, t, done):
        """Do nothing at the end of an iteration; returns ``None``."""

    def end_of_episode(self, i_episode, t):
        """Do nothing at the end of an episode; returns ``None``."""

    def loss_as_string(self):
        """Return an empty string (no loss is tracked)."""
        return ""

    def new_episode(self):
        """Do nothing at the start of an episode; returns ``None``."""
94
95