Path: blob/main/cyberbattle/agents/baseline/agent_randomcredlookup.py
597 views
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Random agent with credential lookup (notebook)"""

# pylint: disable=invalid-name

import logging
from typing import Optional

import numpy as np

from .agent_wrapper import AgentWrapper
from .learner import Learner
import cyberbattle._env.cyberbattle_env as cyberbattle_env
import cyberbattle.agents.baseline.agent_wrapper as w


def exploit_credentialcache(observation) -> Optional[cyberbattle_env.Action]:
    """Build a `connect` action from the credential cache, if one is possible.

    A source node is drawn at random from ``w.owned_nodes(observation)``, and
    a credential is drawn at random among the cached credentials whose target
    node (column 0 of ``observation["credential_cache_matrix"]``) appears in
    ``w.discovered_nodes_notowned(observation)``.

    Args:
        observation: Environment observation; must provide the
            ``"credential_cache_matrix"`` entry plus whatever the
            ``agent_wrapper`` helpers read from it.

    Returns:
        ``{"connect": [source_node, target, port, cred]}`` as an ``int32``
        array, or ``None`` when there is no owned node, the credential cache
        is empty, or no cached credential targets a not-yet-owned node.
    """
    # Pick the source node at random among the owned nodes.
    potential_source_nodes = w.owned_nodes(observation)
    # Use len() rather than truthiness: the helper may return a NumPy array,
    # whose truth value is ambiguous -- TODO confirm the return type.
    if len(potential_source_nodes) == 0:
        return None

    source_node = np.random.choice(potential_source_nodes)

    discovered_credentials = np.array(observation["credential_cache_matrix"])
    n_discovered_creds = len(discovered_credentials)
    if n_discovered_creds == 0:
        # No credential available in the cache: cannot produce a valid
        # connect action.
        return None

    nodes_not_owned = w.discovered_nodes_notowned(observation)

    # Indices of the cached credentials whose target node is not owned yet.
    match_port__target_notowned = [
        c
        for c in range(n_discovered_creds)
        if discovered_credentials[c, 0] in nodes_not_owned
    ]
    if not match_port__target_notowned:
        return None

    logging.debug("found matching cred in the credential cache")
    cred = np.int32(np.random.choice(match_port__target_notowned))
    target = np.int32(discovered_credentials[cred, 0])
    port = np.int32(discovered_credentials[cred, 1])
    return {"connect": np.array([source_node, target, port, cred], dtype=np.int32)}


class CredentialCacheExploiter(Learner):
    """A learner that just exploits the credential cache"""

    def parameters_as_string(self):
        """Return an empty string: there are no parameters to report."""
        return ""

    def explore(self, wrapped_env: AgentWrapper):
        """Return the label ``"explore"``, a sampled valid action and ``None``.

        The action comes from ``wrapped_env.env.sample_valid_action([0, 1])``.
        NOTE(review): ``[0, 1]`` presumably restricts sampling to two action
        kinds (excluding `connect`) -- confirm against the environment.
        """
        return "explore", wrapped_env.env.sample_valid_action([0, 1]), None

    def exploit(self, wrapped_env: AgentWrapper, observation):
        """Try to connect to a not-yet-owned node with a cached credential.

        Returns a ``(label, action, None)`` triple. The action is ``None``
        when no connect action could be built or when the environment rejects
        it against ``observation["action_mask"]``; the label then indicates
        that the caller is expected to fall back on exploration.
        """
        gym_action = exploit_credentialcache(observation)
        if gym_action is None:
            return "exploit[undefined]->explore", None, None

        if wrapped_env.env.is_action_valid(gym_action, observation["action_mask"]):
            return "exploit", gym_action, None

        # Fallback on random exploration.
        return "exploit[invalid]->explore", None, None

    def stateaction_as_string(self, action_metadata):
        """Return an empty string: no state/action details are reported."""
        return ""

    def on_step(
        self,
        wrapped_env: AgentWrapper,
        observation,
        reward,
        done,
        truncated,
        info,
        action_metadata,
    ):
        """Do nothing: this learner does not use the step outcome."""
        return None

    def end_of_iteration(self, t, done):
        """Do nothing at the end of an iteration."""
        return None

    def end_of_episode(self, i_episode, t):
        """Do nothing at the end of an episode."""
        return None

    def loss_as_string(self):
        """Return an empty string: no loss is tracked."""
        return ""

    def new_episode(self):
        """Do nothing at the start of a new episode."""
        return None