# Path: blob/main/cyberbattle/simulation/actions.py
# 597 views
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""
actions.py
This file contains the class and associated methods for the AgentActions
class which interacts directly with the environment. It is the class
through which both the user and RL agents should manipulate the environment.
"""

from dataclasses import dataclass
import dataclasses
import datetime
from boolean import boolean
from collections import OrderedDict
import logging
from enum import Enum
from typing import (
    Iterator,
    List,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    Dict,
    TypedDict,
    cast,
)
from IPython.display import display
import pandas as pd

from cyberbattle.simulation.model import (
    FirewallRule,
    MachineStatus,
    PrivilegeLevel,
    PropertyName,
    VulnerabilityID,
    VulnerabilityType,
)
from . import model


logger = logging.getLogger(__name__)
Reward = float

DiscoveredNodeInfo = TypedDict("DiscoveredNodeInfo", {"id": model.NodeID, "status": str})


class Penalty:
    """Penalties (=negative reward) returned for some actions taken in the simulation"""

    # penalty for generic suspiciousness
    SUPSPICIOUSNESS = -5.0

    # penalty for attempting a connection to a port that was not open
    SCANNING_UNOPEN_PORT = -10.0

    # penalty for repeating the same exploit attempt
    REPEAT = -1

    LOCAL_EXPLOIT_FAILED = -20
    FAILED_REMOTE_EXPLOIT = -50

    # penalty for attempting to connect or execute an action on a node that's not in running state
    MACHINE_NOT_RUNNING = 0

    # penalty for attempting a connection with an invalid password
    WRONG_PASSWORD = -10

    # traffic blocked by outgoing rule in a local firewall
    BLOCKED_BY_LOCAL_FIREWALL = -10

    # traffic blocked by incoming rule in a remote firewall
    BLOCKED_BY_REMOTE_FIREWALL = -10

    # invalid action (e.g., running an attack from a node that's not owned)
    # (Used only if `throws_on_invalid_actions` is set to False)
    INVALID_ACTION = -1


# Reward for the first time a local or remote attack
# gets successfully executed since the last time the target node was imaged.
# NOTE: the attack cost gets substracted from this reward.
NEW_SUCCESSFULL_ATTACK_REWARD = 7

# Fixed reward for discovering a new node
NODE_DISCOVERED_REWARD = 5

# Fixed reward for discovering a new credential
CREDENTIAL_DISCOVERED_REWARD = 3

# Fixed reward for discovering a new node property
PROPERTY_DISCOVERED_REWARD = 2


class EdgeAnnotation(Enum):
    """Annotation added to the network edges created as the simulation is played"""

    KNOWS = 0
    REMOTE_EXPLOIT = 1
    LATERAL_MOVE = 2


class ActionResult(NamedTuple):
    """Result from executing an action"""

    reward: Reward
    outcome: Optional[model.VulnerabilityOutcome]


ALGEBRA = boolean.BooleanAlgebra()

# Boolean constants used to evaluate vulnerability preconditions.
# Parsed once here rather than on every precondition check.
_TRUE_VALUE = ALGEBRA.parse("true")
_FALSE_VALUE = ALGEBRA.parse("false")


@dataclass
class NodeTrackingInformation:
    """Track information about nodes gathered throughout the simulation"""

    # Map (vulnid, local_or_remote) to time of last attack.
    # local_or_remote is true for local, false for remote
    last_attack: Dict[Tuple[model.VulnerabilityID, bool], datetime.datetime] = dataclasses.field(default_factory=dict)
    # Last time the node got owned by the attacker agent
    last_owned_at: Optional[datetime.datetime] = None
    # All node properties discovered so far (indices into environment.identifiers.properties)
    discovered_properties: Set[int] = dataclasses.field(default_factory=set)


class AgentActions:
    """
    This is the AgentActions class. It interacts with and makes changes to the environment.
    """

    def __init__(self, environment: model.Environment, throws_on_invalid_actions=True):
        """
        AgentActions Constructor

        environment - CyberBattleSim environment parameters
        throws_on_invalid_actions - whether to raise an exception when executing an invalid action
            (e.g., running an attack from a node that's not owned);
            if set to False a negative reward is returned instead.
        """
        self._environment = environment
        # Credential IDs gathered so far by the attacker
        self._gathered_credentials: Set[model.CredentialID] = set()
        # Nodes discovered so far, in discovery order, with their tracking information
        self._discovered_nodes: "OrderedDict[model.NodeID, NodeTrackingInformation]" = OrderedDict()
        self._throws_on_invalid_actions = throws_on_invalid_actions

        # List of all special tags indicating a privilege level reached on a node
        self.privilege_tags = [model.PrivilegeEscalation(p).tag for p in list(PrivilegeLevel)]

        # Nodes with the agent pre-installed start out owned (and hence discovered)
        for i, node in environment.nodes():
            if node.agent_installed:
                self.__mark_node_as_owned(i, PrivilegeLevel.LocalUser)

    def discovered_nodes(self) -> Iterator[Tuple[model.NodeID, model.NodeInfo]]:
        """Yield (node_id, node_info) for every discovered node, in discovery order."""
        for node_id in self._discovered_nodes:
            yield (node_id, self._environment.get_node(node_id))

    def _check_prerequisites(self, target: model.NodeID, vulnerability: model.VulnerabilityInfo) -> bool:
        """
        Return True if the precondition of `vulnerability` holds on node `target`.

        Every symbol in the precondition expression is substituted with `true` when a
        property of the same name is set on the node and `false` otherwise; the
        resulting expression is then simplified and compared with `true`.
        """
        node: model.NodeInfo = self._environment.network.nodes[target]["data"]
        node_flags = node.properties
        expr = vulnerability.precondition.expression

        mapping = {i: _TRUE_VALUE if str(i) in node_flags else _FALSE_VALUE for i in expr.get_symbols()}
        is_true: bool = cast(boolean.Expression, expr.subs(mapping)).simplify() == _TRUE_VALUE
        return is_true

    def list_vulnerabilities_in_target(
        self,
        target: model.NodeID,
        type_filter: Optional[model.VulnerabilityType] = None,
    ) -> List[model.VulnerabilityID]:
        """
        Return the IDs of the vulnerabilities whose precondition holds on node `target`.

        Both the global vulnerability library and the node's own vulnerabilities are
        checked. If `type_filter` is given, only vulnerabilities of that type are returned.

        Raises ValueError if `target` is not a node of the network.
        """
        if not self._environment.network.has_node(target):
            raise ValueError(f"invalid node id '{target}'")

        target_node_data: model.NodeInfo = self._environment.get_node(target)

        global_vuln: Set[model.VulnerabilityID] = {
            vuln_id
            for vuln_id, vulnerability in self._environment.vulnerability_library.items()
            if (type_filter is None or vulnerability.type == type_filter) and self._check_prerequisites(target, vulnerability)
        }

        local_vuln: Set[model.VulnerabilityID] = {
            vuln_id
            for vuln_id, vulnerability in target_node_data.vulnerabilities.items()
            if (type_filter is None or vulnerability.type == type_filter) and self._check_prerequisites(target, vulnerability)
        }

        return list(global_vuln.union(local_vuln))

    def __annotate_edge(
        self,
        source_node_id: model.NodeID,
        target_node_id: model.NodeID,
        new_annotation: EdgeAnnotation,
    ) -> None:
        """Create the edge if it does not already exist, and annotate with the maximum
        of the existing annotation and a specified new annotation"""
        edge_annotation = self._environment.network.get_edge_data(source_node_id, target_node_id)
        if edge_annotation is not None and "kind" in edge_annotation:
            # Never downgrade an existing annotation (KNOWS < REMOTE_EXPLOIT < LATERAL_MOVE)
            new_annotation = EdgeAnnotation(max(edge_annotation["kind"].value, new_annotation.value))
        self._environment.network.add_edge(
            source_node_id,
            target_node_id,
            kind=new_annotation,
            kind_as_float=float(new_annotation.value),
        )

    def get_discovered_properties(self, node_id: model.NodeID) -> Set[int]:
        """Return the indices of the properties discovered so far on `node_id`.

        Raises KeyError if the node has not been discovered.
        """
        return self._discovered_nodes[node_id].discovered_properties

    def __mark_node_as_discovered(self, node_id: model.NodeID) -> bool:
        """Record `node_id` as discovered; return True if it was not discovered before."""
        logger.info("discovered node: " + node_id)
        newly_discovered = node_id not in self._discovered_nodes
        if newly_discovered:
            self._discovered_nodes[node_id] = NodeTrackingInformation()
        return newly_discovered

    def __mark_nodeproperties_as_discovered(self, node_id: model.NodeID, properties: List[PropertyName]):
        """Record `properties` as discovered on `node_id` (discovering the node if needed).

        Privilege tags are ignored. Returns the number of properties that were not
        already recorded as discovered.
        """
        properties_indices = [self._environment.identifiers.properties.index(p) for p in properties if p not in self.privilege_tags]

        if node_id in self._discovered_nodes:
            before_count = len(self._discovered_nodes[node_id].discovered_properties)
            self._discovered_nodes[node_id].discovered_properties = self._discovered_nodes[node_id].discovered_properties.union(properties_indices)
        else:
            before_count = 0
            self._discovered_nodes[node_id] = NodeTrackingInformation(discovered_properties=set(properties_indices))

        newly_discovered_properties = len(self._discovered_nodes[node_id].discovered_properties) - before_count
        return newly_discovered_properties

    def __mark_allnodeproperties_as_discovered(self, node_id: model.NodeID):
        """Record every property of `node_id` as discovered; return the number newly discovered."""
        node_info: model.NodeInfo = self._environment.network.nodes[node_id]["data"]
        return self.__mark_nodeproperties_as_discovered(node_id, node_info.properties)

    def __mark_node_as_owned(
        self,
        node_id: model.NodeID,
        privilege: PrivilegeLevel = model.PrivilegeLevel.LocalUser,
    ) -> Tuple[Optional[datetime.datetime], bool]:
        """Mark a node as owned and raise its privilege level to at least `privilege`.

        Return the time it was previously owned (or None) and whether it was already
        (i.e. currently) owned.
        """
        node_info = self._environment.get_node(node_id)

        last_owned_at, is_currently_owned = self.__is_node_owned_history(node_id, node_info)

        if not is_currently_owned:
            if node_id not in self._discovered_nodes:
                self._discovered_nodes[node_id] = NodeTrackingInformation()
            node_info.agent_installed = True

        # The privilege level is raised even if the node is already owned:
        # privilege escalation exploits are local and therefore always target
        # an owned node. `escalate` never lowers the current level.
        node_info.privilege_level = model.escalate(node_info.privilege_level, privilege)
        self._environment.network.nodes[node_id].update({"data": node_info})

        if not is_currently_owned:
            self.__mark_allnodeproperties_as_discovered(node_id)

            # Record that the node just got owned at the current time
            self._discovered_nodes[node_id].last_owned_at = datetime.datetime.now()

        return last_owned_at, is_currently_owned

    def __mark_discovered_entities(self, reference_node: model.NodeID, outcome: model.VulnerabilityOutcome) -> Tuple[int, float, int]:
        """Mark discovered entities as such and return
        the number of newly discovered nodes, their total value and the number of newly discovered credentials.

        Also adds a KNOWS edge from `reference_node` to every node revealed by the outcome.
        """
        newly_discovered_nodes = 0
        newly_discovered_nodes_value = 0
        newly_discovered_credentials = 0

        if isinstance(outcome, model.LeakedCredentials):
            for credential in outcome.credentials:
                if self.__mark_node_as_discovered(credential.node):
                    newly_discovered_nodes += 1
                    newly_discovered_nodes_value += self._environment.get_node(credential.node).value

                if credential.credential not in self._gathered_credentials:
                    newly_discovered_credentials += 1
                    self._gathered_credentials.add(credential.credential)

                logger.info("discovered credential: " + str(credential))
                self.__annotate_edge(reference_node, credential.node, EdgeAnnotation.KNOWS)

        elif isinstance(outcome, model.LeakedNodesId):
            for node_id in outcome.nodes:
                if self.__mark_node_as_discovered(node_id):
                    newly_discovered_nodes += 1
                    newly_discovered_nodes_value += self._environment.get_node(node_id).value

                self.__annotate_edge(reference_node, node_id, EdgeAnnotation.KNOWS)

        return (
            newly_discovered_nodes,
            newly_discovered_nodes_value,
            newly_discovered_credentials,
        )

    def get_node_privilegelevel(self, node_id: model.NodeID) -> model.PrivilegeLevel:
        """Return the last recorded privilege level of the specified node"""
        node_info = self._environment.get_node(node_id)
        return node_info.privilege_level

    def get_nodes_with_atleast_privilegelevel(self, level: PrivilegeLevel) -> List[model.NodeID]:
        """Return all nodes with at least the specified privilege level"""
        return [n for n, info in self._environment.nodes() if info.privilege_level >= level]

    def is_node_discovered(self, node_id: model.NodeID) -> bool:
        """Returns true if previous actions have revealed the specified node ID"""
        return node_id in self._discovered_nodes

    def __process_outcome(
        self,
        expected_type: VulnerabilityType,
        vulnerability_id: VulnerabilityID,
        node_id: model.NodeID,
        node_info: model.NodeInfo,
        local_or_remote: bool,
        failed_penalty: float,
        throw_if_vulnerability_not_present: bool,
    ) -> Tuple[bool, ActionResult]:
        """Apply the outcome of exploiting `vulnerability_id` on `node_id` and compute the reward.

        expected_type - vulnerability type the caller expects (local or remote)
        vulnerability_id - looked up first in the global vulnerability library,
            then in the node's own vulnerabilities
        node_id, node_info - the target node and its data
        local_or_remote - True for a local attack, False for a remote one; part of the
            key used to detect repeated attacks
        failed_penalty - reward returned when the vulnerability precondition does not hold
        throw_if_vulnerability_not_present - raise ValueError instead of returning a
            SUPSPICIOUSNESS penalty when the vulnerability is unknown for this node

        Returns (succeeded, result).
        Raises ValueError if the vulnerability type differs from `expected_type`.
        """
        if node_info.status != model.MachineStatus.Running:
            logger.info("target machine not in running state")
            return False, ActionResult(reward=Penalty.MACHINE_NOT_RUNNING, outcome=None)

        is_global_vulnerability = vulnerability_id in self._environment.vulnerability_library
        is_inplace_vulnerability = vulnerability_id in node_info.vulnerabilities

        if is_global_vulnerability:
            vulnerabilities = self._environment.vulnerability_library
        elif is_inplace_vulnerability:
            vulnerabilities = node_info.vulnerabilities
        elif throw_if_vulnerability_not_present:
            raise ValueError(f"Vulnerability '{vulnerability_id}' not supported by node='{node_id}'")
        else:
            logger.info(f"Vulnerability '{vulnerability_id}' not supported by node '{node_id}'")
            return False, ActionResult(reward=Penalty.SUPSPICIOUSNESS, outcome=None)

        vulnerability = vulnerabilities[vulnerability_id]

        outcome = vulnerability.outcome

        if vulnerability.type != expected_type:
            raise ValueError(f"vulnerability id '{vulnerability_id}' is for an attack of type {vulnerability.type}, expecting: {expected_type}")

        # check vulnerability prerequisites
        if not self._check_prerequisites(node_id, vulnerability):
            return False, ActionResult(reward=failed_penalty, outcome=model.ExploitFailed())

        reward = 0

        # if the vulnerability type is a privilege escalation
        # and if the escalation level is not already reached on that node,
        # then add the escalation tag to the node properties
        if isinstance(outcome, model.PrivilegeEscalation):
            if outcome.tag in node_info.properties:
                return False, ActionResult(reward=Penalty.REPEAT, outcome=outcome)

            last_owned_at, _ = self.__mark_node_as_owned(node_id, outcome.level)

            # The node value is only earned the first time the node gets owned
            if not last_owned_at:
                reward += float(node_info.value)

            node_info.properties.append(outcome.tag)

        elif isinstance(outcome, model.LateralMove):
            last_owned_at, _ = self.__mark_node_as_owned(node_id)

            if not last_owned_at:
                reward += float(node_info.value)

        elif isinstance(outcome, model.ProbeSucceeded):
            for p in outcome.discovered_properties:
                assert p in node_info.properties, f"Discovered property {p} must belong to the set of properties associated with the node."

            newly_discovered_properties = self.__mark_nodeproperties_as_discovered(node_id, outcome.discovered_properties)
            reward += newly_discovered_properties * PROPERTY_DISCOVERED_REWARD

        if node_id not in self._discovered_nodes:
            self._discovered_nodes[node_id] = NodeTrackingInformation()

        lookup_key = (vulnerability_id, local_or_remote)
        tracking = self._discovered_nodes[node_id]

        if lookup_key in tracking.last_attack:
            # Repeating an attack is penalized, unless the node was re-imaged since
            last_time = tracking.last_attack[lookup_key]
            if node_info.last_reimaging is None or last_time >= node_info.last_reimaging:
                reward += Penalty.REPEAT
        else:
            reward += NEW_SUCCESSFULL_ATTACK_REWARD

        tracking.last_attack[lookup_key] = datetime.datetime.now()

        newly_discovered_nodes, _, newly_discovered_credentials = self.__mark_discovered_entities(node_id, outcome)

        # Note: the value of the discovered nodes should not be added to the reward
        # unless the discovered nodes got owned, but this case is already covered above
        reward += newly_discovered_nodes * NODE_DISCOVERED_REWARD
        reward += newly_discovered_credentials * CREDENTIAL_DISCOVERED_REWARD

        reward -= vulnerability.cost

        logger.info("GOT REWARD: " + vulnerability.reward_string)
        return True, ActionResult(reward=reward, outcome=outcome)

    def exploit_remote_vulnerability(
        self,
        node_id: model.NodeID,
        target_node_id: model.NodeID,
        vulnerability_id: model.VulnerabilityID,
    ) -> ActionResult:
        """
        Attempt to exploit a remote vulnerability
        from a source node to another node using the specified
        vulnerability.

        The source node must be owned and the target node discovered; otherwise
        ValueError is raised (or INVALID_ACTION returned if throws_on_invalid_actions
        is False). Raises ValueError for unknown node IDs.
        """
        if node_id not in self._environment.network.nodes:
            raise ValueError(f"invalid node id '{node_id}'")
        if target_node_id not in self._environment.network.nodes:
            raise ValueError(f"invalid target node id '{target_node_id}'")

        source_node_info: model.NodeInfo = self._environment.get_node(node_id)
        target_node_info: model.NodeInfo = self._environment.get_node(target_node_id)

        if not source_node_info.agent_installed:
            if self._throws_on_invalid_actions:
                raise ValueError("Agent does not owned the source node '" + node_id + "'")
            return ActionResult(reward=Penalty.INVALID_ACTION, outcome=None)

        if target_node_id not in self._discovered_nodes:
            if self._throws_on_invalid_actions:
                raise ValueError("Agent has not discovered the target node '" + target_node_id + "'")
            return ActionResult(reward=Penalty.INVALID_ACTION, outcome=None)

        succeeded, result = self.__process_outcome(
            model.VulnerabilityType.REMOTE,
            vulnerability_id,
            target_node_id,
            target_node_info,
            local_or_remote=False,
            failed_penalty=Penalty.FAILED_REMOTE_EXPLOIT,
            # We do not throw if the vulnerability is missing in order to
            # allow agent attempts to explore potential remote vulnerabilities
            throw_if_vulnerability_not_present=False,
        )

        if succeeded:
            self.__annotate_edge(node_id, target_node_id, EdgeAnnotation.REMOTE_EXPLOIT)

        return result

    def exploit_local_vulnerability(self, node_id: model.NodeID, vulnerability_id: model.VulnerabilityID) -> ActionResult:
        """
        Exploit a local vulnerability on node `node_id`, which must be owned by the agent.

        Returns the ActionResult (reward and outcome, the outcome possibly None).
        Raises ValueError for an unknown node, or for a non-owned node when
        throws_on_invalid_actions is True (INVALID_ACTION is returned otherwise).
        """
        graph = self._environment.network
        if node_id not in graph.nodes:
            raise ValueError(f"invalid node id '{node_id}'")

        node_info = self._environment.get_node(node_id)

        if not node_info.agent_installed:
            if self._throws_on_invalid_actions:
                raise ValueError(f"Agent does not owned the node '{node_id}'")
            return ActionResult(reward=Penalty.INVALID_ACTION, outcome=None)

        _, result = self.__process_outcome(
            model.VulnerabilityType.LOCAL,
            vulnerability_id,
            node_id,
            node_info,
            local_or_remote=True,
            failed_penalty=Penalty.LOCAL_EXPLOIT_FAILED,
            throw_if_vulnerability_not_present=False,
        )

        return result

    def __is_passing_firewall_rules(self, rules: List[model.FirewallRule], port_name: model.PortName) -> bool:
        """Determine if traffic on the specified port is permitted by the specified sets of firewall rules.

        The first rule matching the port decides; traffic on a port with no rule is blocked.
        """
        for rule in rules:
            if rule.port == port_name:
                if rule.permission == model.RulePermission.ALLOW:
                    return True
                logger.debug(f"BLOCKED TRAFFIC - PORT '{port_name}' Reason: " + rule.reason)
                return False

        logger.debug(f"BLOCKED TRAFFIC - PORT '{port_name}' - Reason: no rule defined for this port.")
        return False

    def __is_node_owned_history(self, target_node_id, target_node_data):
        """Returns the last time the node got owned and whether it is still currently owned.

        A node stops being currently owned once it gets re-imaged after it was last owned.
        """
        last_previously_owned_at = self._discovered_nodes[target_node_id].last_owned_at if target_node_id in self._discovered_nodes else None

        is_currently_owned = last_previously_owned_at is not None and (target_node_data.last_reimaging is None or last_previously_owned_at >= target_node_data.last_reimaging)
        return last_previously_owned_at, is_currently_owned

    def connect_to_remote_machine(
        self,
        source_node_id: model.NodeID,
        target_node_id: model.NodeID,
        port_name: model.PortName,
        credential: model.CredentialID,
    ) -> ActionResult:
        """
        Connect to a remote machine with a credential (as opposed to via an exploit).

        source_node_id - node the connection originates from (must be owned)
        target_node_id - node to connect to (must be discovered)
        port_name - port/service to connect through
        credential - credential ID to authenticate with (must have been gathered)

        On success the target node becomes owned and a LateralMove outcome is returned.
        Raises ValueError for unknown node IDs, and for invalid actions when
        throws_on_invalid_actions is True (INVALID_ACTION is returned otherwise).
        """
        graph = self._environment.network
        if source_node_id not in graph.nodes:
            raise ValueError(f"invalid node id '{source_node_id}'")
        if target_node_id not in graph.nodes:
            raise ValueError(f"invalid node id '{target_node_id}'")

        target_node = self._environment.get_node(target_node_id)
        source_node = self._environment.get_node(source_node_id)

        # ensure that the source node is owned by the agent,
        # that the target node is discovered and that the credential is known
        if not source_node.agent_installed:
            if self._throws_on_invalid_actions:
                raise ValueError(f"Agent does not owned the source node '{source_node_id}'")
            return ActionResult(reward=Penalty.INVALID_ACTION, outcome=None)

        if target_node_id not in self._discovered_nodes:
            if self._throws_on_invalid_actions:
                raise ValueError(f"Agent has not discovered the target node '{target_node_id}'")
            return ActionResult(reward=Penalty.INVALID_ACTION, outcome=None)

        if credential not in self._gathered_credentials:
            if self._throws_on_invalid_actions:
                raise ValueError(f"Agent has not discovered credential '{credential}'")
            return ActionResult(reward=Penalty.INVALID_ACTION, outcome=None)

        if not self.__is_passing_firewall_rules(source_node.firewall.outgoing, port_name):
            logger.info(f"BLOCKED TRAFFIC: source node '{source_node_id}'" + f" is blocking outgoing traffic on port '{port_name}'")
            return ActionResult(reward=Penalty.BLOCKED_BY_LOCAL_FIREWALL, outcome=None)

        if not self.__is_passing_firewall_rules(target_node.firewall.incoming, port_name):
            logger.info(f"BLOCKED TRAFFIC: target node '{target_node_id}'" + f" is blocking incoming traffic on port '{port_name}'")
            return ActionResult(reward=Penalty.BLOCKED_BY_REMOTE_FIREWALL, outcome=None)

        if not any(service.name == port_name for service in target_node.services):
            logger.info(f"target node '{target_node_id}' not listening on port '{port_name}'")
            return ActionResult(reward=Penalty.SCANNING_UNOPEN_PORT, outcome=None)

        if target_node.status != model.MachineStatus.Running:
            logger.info("target machine not in running state")
            return ActionResult(reward=Penalty.MACHINE_NOT_RUNNING, outcome=None)

        # check the credentials before connecting
        if not self._check_service_running_and_authorized(target_node, port_name, credential):
            logger.info("invalid credentials supplied")
            return ActionResult(reward=Penalty.WRONG_PASSWORD, outcome=None)

        last_owned_at, is_already_owned = self.__mark_node_as_owned(target_node_id)

        if is_already_owned:
            return ActionResult(reward=Penalty.REPEAT, outcome=model.LateralMove())

        self.__annotate_edge(source_node_id, target_node_id, EdgeAnnotation.LATERAL_MOVE)

        logger.info(f"Infected node '{target_node_id}' from '{source_node_id}'" + f" via {port_name} with credential '{credential}'")
        if target_node.owned_string:
            logger.info("Owned message: " + target_node.owned_string)

        # The node value is only earned the first time the node gets owned
        return ActionResult(
            reward=float(target_node.value) if last_owned_at is None else 0.0,
            outcome=model.LateralMove(),
        )

    def _check_service_running_and_authorized(
        self,
        target_node_data: model.NodeInfo,
        port_name: model.PortName,
        credential: model.CredentialID,
    ) -> bool:
        """
        Return True if the node has a running service on `port_name`
        that accepts `credential`.
        """
        return any(service.running and service.name == port_name and credential in service.allowedCredentials for service in target_node_data.services)

    def list_nodes(self) -> List[DiscoveredNodeInfo]:
        """Returns the list of nodes ID that were discovered or owned by the attacker."""
        return [
            cast(
                DiscoveredNodeInfo,
                {
                    "id": node_id,
                    "status": "owned" if node_info.agent_installed else "discovered",
                },
            )
            for node_id, node_info in self.discovered_nodes()
        ]

    def list_remote_attacks(self, node_id: model.NodeID) -> List[model.VulnerabilityID]:
        """Return list of all remote attacks that may be executed onto the specified node."""
        attacks: List[model.VulnerabilityID] = self.list_vulnerabilities_in_target(node_id, model.VulnerabilityType.REMOTE)
        return attacks

    def list_local_attacks(self, node_id: model.NodeID) -> List[model.VulnerabilityID]:
        """Return list of all local attacks that may be executed onto the specified node."""
        attacks: List[model.VulnerabilityID] = self.list_vulnerabilities_in_target(node_id, model.VulnerabilityType.LOCAL)
        return attacks

    def list_attacks(self, node_id: model.NodeID) -> List[model.VulnerabilityID]:
        """Return list of all attacks that may be executed on the specified node."""
        attacks: List[model.VulnerabilityID] = self.list_vulnerabilities_in_target(node_id)
        return attacks

    def list_all_attacks(self) -> List[Dict[str, object]]:
        """List all possible attacks from all the nodes currently owned by the attacker.

        Owned nodes come first (with their properties and local attacks), followed by
        the merely discovered nodes (remote attacks only).
        """
        nodes = self.list_nodes()
        on_owned_nodes: List[Dict[str, object]] = [
            {
                "id": n["id"],
                "status": n["status"],
                "properties": self._environment.get_node(n["id"]).properties,
                "local_attacks": self.list_local_attacks(n["id"]),
                "remote_attacks": self.list_remote_attacks(n["id"]),
            }
            for n in nodes
            if n["status"] == "owned"
        ]
        on_discovered_nodes: List[Dict[str, object]] = [
            {
                "id": n["id"],
                "status": n["status"],
                "local_attacks": None,
                "remote_attacks": self.list_remote_attacks(n["id"]),
            }
            for n in nodes
            if n["status"] != "owned"
        ]
        return on_owned_nodes + on_discovered_nodes

    def print_all_attacks(self) -> None:
        """Pretty print list of all possible attacks from all the nodes currently owned by the attacker"""
        display(pd.DataFrame.from_dict(self.list_all_attacks()).set_index("id"))  # type: ignore


class DefenderAgentActions:
    """Actions reserved to defender agents"""

    # Number of steps it takes to completely reimage a node
    REIMAGING_DURATION = 15

    def __init__(self, environment: model.Environment):
        # map nodes being reimaged to the remaining number of steps to completion
        self.node_reimaging_progress: Dict[model.NodeID, int] = dict()

        # Last calculated availability of the network
        self.__network_availability: float = 1.0

        self._environment = environment

    @property
    def network_availability(self):
        """Network availability computed by the last `on_attacker_step_taken` call (1.0 initially)."""
        return self.__network_availability

    def reimage_node(self, node_id: model.NodeID):
        """Re-image a computer node.

        The node loses the attacker agent and any privilege, and stays in the Imaging
        state for REIMAGING_DURATION steps. Asserts that the node is re-imageable.
        """
        node_info = self._environment.get_node(node_id)
        # Validate before scheduling anything, so a rejected request leaves no trace
        assert node_info.reimagable, f"Node {node_id} is not re-imageable"

        # Mark the node for re-imaging and make it unavailable until re-imaging completes
        self.node_reimaging_progress[node_id] = self.REIMAGING_DURATION

        node_info.agent_installed = False
        node_info.privilege_level = model.PrivilegeLevel.NoAccess
        node_info.status = model.MachineStatus.Imaging
        node_info.last_reimaging = datetime.datetime.now()
        self._environment.network.nodes[node_id].update({"data": node_info})

    def on_attacker_step_taken(self):
        """Function to be called each time a step is taken in the simulation.

        Advances re-imaging progress, then recomputes the network availability: the
        sla_weight-weighted average over nodes of (1 + running service weights) /
        (1 + total service weights) for running nodes, and 0 for non-running nodes.
        """
        for node_id in list(self.node_reimaging_progress.keys()):
            remaining_steps = self.node_reimaging_progress[node_id]
            if remaining_steps > 0:
                self.node_reimaging_progress[node_id] -= 1
            else:
                logger.info(f"Machine re-imaging completed: {node_id}")
                node_data = self._environment.get_node(node_id)
                node_data.status = model.MachineStatus.Running
                self.node_reimaging_progress.pop(node_id)

        # Calculate the network availability metric based on machines
        # and services that are running
        total_node_weights = 0
        network_node_availability = 0
        for node_id, node_info in self._environment.nodes():
            total_service_weights = 0
            running_service_weights = 0
            for service in node_info.services:
                total_service_weights += service.sla_weight
                running_service_weights += service.sla_weight * int(service.running)

            if node_info.status == MachineStatus.Running:
                adjusted_node_availability = (1 + running_service_weights) / (1 + total_service_weights)
            else:
                adjusted_node_availability = 0.0

            total_node_weights += node_info.sla_weight
            network_node_availability += adjusted_node_availability * node_info.sla_weight

        self.__network_availability = network_node_availability / total_node_weights
        assert self.__network_availability <= 1.0 and self.__network_availability >= 0.0

    def override_firewall_rule(
        self,
        node_id: model.NodeID,
        port_name: model.PortName,
        incoming: bool,
        permission: model.RulePermission,
    ):
        """Set `permission` for `port_name` in the incoming or outgoing firewall of a node.

        Existing rules for that port are replaced (dropping their reason); if there is
        none, a new rule is appended.
        """
        node_data = self._environment.get_node(node_id)

        def add_or_patch_rule(rules) -> List[FirewallRule]:
            new_rules = []
            has_matching_rule = False
            for r in rules:
                if r.port == port_name:
                    has_matching_rule = True
                    new_rules.append(FirewallRule(r.port, permission))
                else:
                    new_rules.append(r)

            if not has_matching_rule:
                new_rules.append(model.FirewallRule(port_name, permission))
            return new_rules

        if incoming:
            node_data.firewall.incoming = add_or_patch_rule(node_data.firewall.incoming)
        else:
            node_data.firewall.outgoing = add_or_patch_rule(node_data.firewall.outgoing)

    def block_traffic(self, node_id: model.NodeID, port_name: model.PortName, incoming: bool):
        """Block traffic on `port_name` (incoming or outgoing) at the node's firewall."""
        return self.override_firewall_rule(node_id, port_name, incoming, permission=model.RulePermission.BLOCK)

    def allow_traffic(self, node_id: model.NodeID, port_name: model.PortName, incoming: bool):
        """Allow traffic on `port_name` (incoming or outgoing) at the node's firewall."""
        return self.override_firewall_rule(node_id, port_name, incoming, permission=model.RulePermission.ALLOW)

    def stop_service(self, node_id: model.NodeID, port_name: model.PortName):
        """Stop every service of the node listening on `port_name` (node must be running)."""
        node_data = self._environment.get_node(node_id)
        assert node_data.status == model.MachineStatus.Running, "Machine must be running to stop a service"
        for service in node_data.services:
            if service.name == port_name:
                service.running = False

    def start_service(self, node_id: model.NodeID, port_name: model.PortName):
        """Start every service of the node listening on `port_name` (node must be running)."""
        node_data = self._environment.get_node(node_id)
        assert node_data.status == model.MachineStatus.Running, "Machine must be running to start a service"
        for service in node_data.services:
            if service.name == port_name:
                service.running = True