Path: blob/main/cyberbattle/simulation/model.py
597 views
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Data model for the simulation environment.

The simulation environment is given by the directed graph
formally defined by:

  Node := NodeID x ListeningService[] x Value x Vulnerability[] x FirewallConfig
  Edge := NodeID x NodeID x PortName

 where:
  - NodeID: string
  - ListeningService : Name x AllowedCredentials
  - AllowedCredentials : string[] # credential pair represented by just a
                                  string ID
  - Value : [0...100]     # Intrinsic value of reaching this node
  - Vulnerability : VulnerabilityID x Type x Precondition x Outcome x Rates
  - VulnerabilityID : string
  - Rates : ProbingDetectionRate x ExploitDetectionRate x SuccessRate
  - FirewallConfig: {
      outgoing : FirewallRule[]
      incoming : FirewallRule[] }
  - FirewallRule: PortName x { ALLOW, BLOCK }
"""

from datetime import datetime
from typing import NamedTuple, List, Dict, Optional, Union, Tuple, Iterator, Iterable
import dataclasses
from dataclasses import dataclass, field
import matplotlib.pyplot as plt  # type:ignore
from enum import Enum, IntEnum
from boolean import boolean
import networkx as nx
import yaml
import random

import matplotlib  # type: ignore

# Select the non-interactive "Agg" backend for matplotlib.
matplotlib.use("Agg")

VERSION_TAG = "0.1.0"

ALGEBRA = boolean.BooleanAlgebra()

# Type alias for identifiers
NodeID = str

# A unique identifier
ID = str

# a (login,password/token) credential pair is abstracted as just a unique
# string identifier
CredentialID = str

# Intrinsic value of reaching a given node in [0,100]
NodeValue = int


PortName = str


@dataclass
class ListeningService:
    """A service port on a given node accepting connection initiated
    with the specified allowed credentials"""

    # Name of the port the service is listening to
    name: PortName
    # credential allowed to authenticate with the service
    allowedCredentials: List[CredentialID] = dataclasses.field(default_factory=list)
    # whether the service is running or stopped
    running: bool = True
    # Weight used to evaluate the cost of not running the service
    # NOTE(review): there is no type annotation here, so this is a plain class
    # attribute and NOT a dataclass field: it cannot be passed to the
    # constructor and is excluded from __repr__/__eq__. Left unchanged to
    # keep the constructor and serialized form stable -- confirm intent.
    sla_weight = 1.0


# NOTE(review): module-level instance that looks like a debugging leftover;
# kept so that the module's public names are unchanged.
x = ListeningService(name="d")
VulnerabilityID = str

# Probability rate
Probability = float

# The name of a node property indicating the presence of a
# service, component, feature or vulnerability on a given node.
PropertyName = str


class Rates(NamedTuple):
    """Probabilities associated with a given vulnerability"""

    probingDetectionRate: Probability = 0.0
    exploitDetectionRate: Probability = 0.0
    successRate: Probability = 1.0


class VulnerabilityType(Enum):
    """Is the vulnerability exploitable locally or remotely?"""

    LOCAL = 1
    REMOTE = 2


class PrivilegeLevel(IntEnum):
    """Access privilege level on a given node"""

    NoAccess = 0
    LocalUser = 1
    Admin = 2
    System = 3
    # Same value as `System`, hence an alias of that member.
    MAXIMUM = 3


def escalate(current_level, escalation_level: PrivilegeLevel) -> PrivilegeLevel:
    """Return the higher of the two privilege levels.

    Both arguments are converted with `int()` before being compared, and the
    maximum is converted back to a `PrivilegeLevel`.
    """
    return PrivilegeLevel(max(int(current_level), int(escalation_level)))


class VulnerabilityOutcome:
    """Outcome of exploiting a given vulnerability"""


class LateralMove(VulnerabilityOutcome):
    """Lateral movement to the target node"""

    success: bool


class CustomerData(VulnerabilityOutcome):
    """Access customer data on target node"""


class PrivilegeEscalation(VulnerabilityOutcome):
    """Privilege escalation outcome"""

    def __init__(self, level: PrivilegeLevel):
        self.level = level

    @property
    def tag(self):
        """Escalation tag that gets added to node properties when
        the escalation level is reached for that node"""
        return f"privilege_{self.level}"


class SystemEscalation(PrivilegeEscalation):
    """Escalation to SYSTEM privileges"""

    def __init__(self):
        super().__init__(PrivilegeLevel.System)


class AdminEscalation(PrivilegeEscalation):
    """Escalation to local administrator privileges"""

    def __init__(self):
        super().__init__(PrivilegeLevel.Admin)


class ProbeSucceeded(VulnerabilityOutcome):
    """Probing succeeded"""

    def __init__(self, discovered_properties: List[PropertyName]):
        self.discovered_properties = discovered_properties


class ProbeFailed(VulnerabilityOutcome):
    """Probing failed"""


class ExploitFailed(VulnerabilityOutcome):
    """This is for situations where the exploit fails"""


class CachedCredential(NamedTuple):
    """Encodes a machine-port-credential triplet"""

    node: NodeID
    port: PortName
    credential: CredentialID


class LeakedCredentials(VulnerabilityOutcome):
    """A set of credentials obtained by exploiting a vulnerability"""

    credentials: List[CachedCredential]

    def __init__(self, credentials: List[CachedCredential]):
        self.credentials = credentials


class LeakedNodesId(VulnerabilityOutcome):
    """A set of node IDs obtained by exploiting a vulnerability"""

    def __init__(self, nodes: List[NodeID]):
        self.nodes = nodes


VulnerabilityOutcomes = Union[
    LeakedCredentials,
    LeakedNodesId,
    PrivilegeEscalation,
    AdminEscalation,
    SystemEscalation,
    CustomerData,
    LateralMove,
    ExploitFailed,
]


class AttackResult:
    """The result of attempting a specific attack (either local or remote)"""

    success: bool
    expected_outcome: Union[VulnerabilityOutcomes, None]


class Precondition:
    """A predicate logic expression defining the condition under which a given
    feature or vulnerability is present or not.
    The symbols used in the expression refer to properties associated with
    the corresponding node.
    E.g. 'Win7', 'Server', 'IISInstalled', 'SQLServerInstalled',
    'AntivirusInstalled' ...
    """

    expression: boolean.Expression

    def __init__(self, expression: Union[boolean.Expression, str]):
        # Accept either an already-built expression or its textual form,
        # which is parsed with the module-level boolean algebra.
        if isinstance(expression, boolean.Expression):
            self.expression = expression
        else:
            self.expression = ALGEBRA.parse(expression)


class VulnerabilityInfo(NamedTuple):
    """Definition of a known vulnerability"""

    # an optional description of what the vulnerability is
    description: str
    # type of vulnerability
    type: VulnerabilityType
    # what happens when successfully exploiting the vulnerability
    outcome: VulnerabilityOutcome
    # a boolean expression over a node's properties determining if the
    # vulnerability is present or not
    # (the default instance is created once and shared by every
    # VulnerabilityInfo that does not override it)
    precondition: Precondition = Precondition("true")
    # rates of success/failure associated with this vulnerability
    rates: Rates = Rates()
    # points to information about the vulnerability
    URL: str = ""
    # some cost associated with exploiting this vulnerability (e.g.
    # brute force more costly than dumping credentials)
    cost: float = 1.0
    # a string displayed when the vulnerability is successfully exploited
    reward_string: str = ""


# A dictionary storing information about all supported vulnerabilities
# or features supported by the simulation.
# This is to be used as a global dictionary pre-populated before
# starting the simulation and estimated from real-world data.
VulnerabilityLibrary = Dict[VulnerabilityID, VulnerabilityInfo]


class RulePermission(Enum):
    """Determines whether a rule blocks or allows traffic"""

    ALLOW = 0
    BLOCK = 1


@dataclass(frozen=True)
class FirewallRule:
    """A firewall rule"""

    # A port name
    port: PortName
    # permission on this port
    permission: RulePermission
    # An optional reason for the block/allow rule
    reason: str = ""


def _default_allow_rules() -> List[FirewallRule]:
    """Build a fresh list of the default ALLOW rules (RDP, SSH, HTTPS, HTTP)."""
    return [
        FirewallRule("RDP", RulePermission.ALLOW),
        FirewallRule("SSH", RulePermission.ALLOW),
        FirewallRule("HTTPS", RulePermission.ALLOW),
        FirewallRule("HTTP", RulePermission.ALLOW),
    ]


@dataclass
class FirewallConfiguration:
    """Firewall configuration on a given node.
    Determine if traffic should be allowed or specifically blocked
    on a given port for outgoing and incoming traffic.
    The rules are processed in order: the first rule matching a given
    port is applied and the rest are ignored.

    Ports that are not listed in the configuration
    are assumed to be blocked. (Adding an explicit block rule
    can still be useful to give a reason for the block.)
    """

    outgoing: List[FirewallRule] = field(repr=True, default_factory=_default_allow_rules)
    incoming: List[FirewallRule] = field(repr=True, default_factory=_default_allow_rules)


class MachineStatus(Enum):
    """Machine running status"""

    Stopped = 0
    Running = 1
    Imaging = 2


@dataclass
class NodeInfo:
    """A computer node in the enterprise network"""

    # List of port/protocol the node is listening to
    services: List[ListeningService]
    # List of known vulnerabilities for the node
    vulnerabilities: VulnerabilityLibrary = dataclasses.field(default_factory=dict)
    # Intrinsic value of the node (translates into a reward if the node gets owned)
    value: NodeValue = 0
    # Properties of the nodes, some of which can imply further vulnerabilities
    properties: List[PropertyName] = dataclasses.field(default_factory=list)
    # Firewall configuration of the node
    firewall: FirewallConfiguration = dataclasses.field(default_factory=FirewallConfiguration)
    # Attacker agent installed on the node? (aka the node is 'pwned')
    agent_installed: bool = False
    # Escalation level
    privilege_level: PrivilegeLevel = PrivilegeLevel.NoAccess
    # Can the node be re-imaged by a defender agent?
    reimagable: bool = True
    # Last time the node was reimaged
    last_reimaging: Optional[datetime] = None
    # String displayed when the node gets owned
    owned_string: str = ""
    # Machine status: running or stopped
    # NOTE(review): unannotated, so this is a class attribute rather than a
    # dataclass field (not a constructor parameter). Left unchanged because
    # turning it into a field would shift the positional order of
    # `sla_weight` -- confirm intent.
    status = MachineStatus.Running
    # Relative node weight used to calculate the cost of stopping this machine
    # or its services
    sla_weight: float = 1.0


class Identifiers(NamedTuple):
    """Define the global set of identifiers used
    in the definition of a given environment.
    Such set defines a common vocabulary possibly
    shared across multiple environments, thus
    ensuring a consistent numbering convention
    that a machine learning model can learn from."""

    # NOTE: the list defaults below are created once and shared by every
    # instance that relies on them; do not mutate them in place.

    # Array of all possible node property identifiers
    properties: List[PropertyName] = []
    # Array of all possible port names
    ports: List[PortName] = ["Null"]
    # Array of all possible local vulnerabilities names
    local_vulnerabilities: List[VulnerabilityID] = []
    # Array of all possible remote vulnerabilities names
    remote_vulnerabilities: List[VulnerabilityID] = []


def iterate_network_nodes(network: nx.graph.Graph) -> Iterator[Tuple[NodeID, NodeInfo]]:
    """Iterates over the nodes in the network.

    Yields `(node_id, node_info)` pairs, where `node_info` is read from the
    "data" attribute of each graph node. The result is a single-pass generator.
    """
    for nodeid, nodevalue in network.nodes.items():
        node_data: NodeInfo = nodevalue["data"]
        yield nodeid, node_data


# NOTE: Using `NamedTuple` instead of `dataclass` breaks deserialization
# with PyYaml 2.8.1 due to a new recursive references to the networkx graph in the field
#   edges: !!python/object:networkx.classes.reportviews.EdgeView
#     _adjdict: *id018
#     _graph: *id019
@dataclass
class Environment:
    """The static graph defining the network of computers"""

    network: nx.DiGraph
    vulnerability_library: VulnerabilityLibrary
    identifiers: Identifiers
    # Timestamps use a default factory so they are taken when the instance is
    # created. A plain `datetime.utcnow()` default would be evaluated only
    # once, at import time, and shared by every Environment.
    creationTime: datetime = field(default_factory=datetime.utcnow)
    lastModified: datetime = field(default_factory=datetime.utcnow)
    # a version tag indicating the environment schema version
    version: str = VERSION_TAG

    def nodes(self) -> Iterator[Tuple[NodeID, NodeInfo]]:
        """Iterates over the nodes in the network"""
        return iterate_network_nodes(self.network)

    def get_node(self, node_id: NodeID) -> NodeInfo:
        """Retrieve info for the node with the specified ID"""
        node_info: NodeInfo = self.network.nodes[node_id]["data"]
        return node_info

    def plot_environment_graph(self) -> None:
        """Plot the full environment graph, coloring each node by its value"""
        node_values = [attributes["data"].value for _, attributes in self.network.nodes.items()]
        nx.draw(self.network, with_labels=True, node_color=node_values, cmap=plt.cm.Oranges)  # type:ignore


def create_network(nodes: Dict[NodeID, NodeInfo]) -> nx.DiGraph:
    """Create a network with a set of nodes and no edges.

    Each `NodeInfo` is stored under the "data" attribute of its graph node.
    """
    graph = nx.DiGraph()
    graph.add_nodes_from([(node_id, {"data": node_info}) for node_id, node_info in nodes.items()])
    return graph


# Helpers to infer constants from an environment


def collect_ports_from_vuln(vuln: VulnerabilityInfo) -> List[PortName]:
    """Returns all the port names referenced in a given vulnerability.

    Only `LeakedCredentials` outcomes reference ports; any other outcome
    yields an empty list.
    """
    if isinstance(vuln.outcome, LeakedCredentials):
        return [c.port for c in vuln.outcome.credentials]
    return []


def collect_vulnerability_ids_from_nodes_bytype(
    nodes: Iterable[Tuple[NodeID, NodeInfo]],
    global_vulnerabilities: VulnerabilityLibrary,
    type: VulnerabilityType,
) -> List[VulnerabilityID]:
    """Collect and return all IDs of all vulnerability of the specified type
    that are referenced in a given set of nodes and vulnerability library.

    The result is sorted and free of duplicates. `nodes` is iterated once.
    """
    vulnerability_ids = {
        vuln_id
        for _, node_info in nodes
        for vuln_id, vuln in node_info.vulnerabilities.items()
        if vuln.type == type
    }
    vulnerability_ids.update(
        vuln_id for vuln_id, vuln in global_vulnerabilities.items() if vuln.type == type
    )
    return sorted(vulnerability_ids)


def collect_properties_from_nodes(nodes: Iterable[Tuple[NodeID, NodeInfo]]) -> List[PropertyName]:
    """Collect and return sorted list of all property names used in a given set of nodes"""
    return sorted({p for _, node_info in nodes for p in node_info.properties})


def collect_ports_from_nodes(
    nodes: Iterable[Tuple[NodeID, NodeInfo]],
    vulnerability_library: VulnerabilityLibrary,
) -> List[PortName]:
    """Collect and return all port names used in a given set of nodes
    and global vulnerability library.

    Ports come from three sources: credentials leaked by the global
    vulnerabilities, credentials leaked by node-specific vulnerabilities,
    and the names of the services listening on each node. The result is
    sorted and free of duplicates.
    """
    # `nodes` is traversed twice below. Materialize it first so that a
    # single-pass iterator (e.g. the generator returned by
    # `Environment.nodes()`) is not exhausted after the first traversal,
    # which would silently drop all service port names.
    node_list = list(nodes)

    ports = {
        port
        for vuln in vulnerability_library.values()
        for port in collect_ports_from_vuln(vuln)
    }
    ports.update(
        port
        for _, node_info in node_list
        for vuln in node_info.vulnerabilities.values()
        for port in collect_ports_from_vuln(vuln)
    )
    ports.update(service.name for _, node_info in node_list for service in node_info.services)
    return sorted(ports)


def collect_ports_from_environment(environment: Environment) -> List[PortName]:
    """Collect and return all port names used in a given environment"""
    return collect_ports_from_nodes(environment.nodes(), environment.vulnerability_library)


def infer_constants_from_nodes(
    nodes: Iterable[Tuple[NodeID, NodeInfo]],
    vulnerabilities: Dict[VulnerabilityID, VulnerabilityInfo],
) -> Identifiers:
    """Infer global environment constants from a given set of nodes
    and a global vulnerability library"""
    # The nodes are handed to four independent collectors. Materialize them
    # once so that a single-pass iterator is not consumed by the first
    # collector, leaving the remaining ones with no nodes to inspect.
    node_list = list(nodes)
    return Identifiers(
        properties=collect_properties_from_nodes(node_list),
        ports=collect_ports_from_nodes(node_list, vulnerabilities),
        local_vulnerabilities=collect_vulnerability_ids_from_nodes_bytype(
            node_list, vulnerabilities, VulnerabilityType.LOCAL
        ),
        remote_vulnerabilities=collect_vulnerability_ids_from_nodes_bytype(
            node_list, vulnerabilities, VulnerabilityType.REMOTE
        ),
    )


def infer_constants_from_network(network: nx.Graph, vulnerabilities: Dict[VulnerabilityID, VulnerabilityInfo]) -> Identifiers:
    """Infer global environment constants from a given network"""
    return infer_constants_from_nodes(iterate_network_nodes(network), vulnerabilities)


# Network creation

# A sample set of environment constants
SAMPLE_IDENTIFIERS = Identifiers(
    ports=["RDP", "SSH", "SMB", "HTTP", "HTTPS", "WMI", "SQL"],
    properties=["Windows", "Linux", "HyperV-VM", "Azure-VM", "Win7", "Win10", "PortRDPOpen", "GuestAccountEnabled"],
)


def assign_random_labels(
    graph: nx.DiGraph,
    vulnerabilities: Optional[VulnerabilityLibrary] = None,
    identifiers: Identifiers = SAMPLE_IDENTIFIERS,
) -> nx.DiGraph:
    """Create an environment network by randomly assigning node information
    (properties, firewall configuration, vulnerabilities)
    to the nodes of a given graph structure.

    Args:
        graph: the graph structure to label. It is not modified: a relabeled
            copy whose node IDs are strings is built and returned.
        vulnerabilities: global vulnerability library to randomly draw
            node vulnerabilities from. Defaults to an empty library.
        identifiers: vocabulary of ports and properties to sample from.

    Returns:
        The relabeled graph where every node carries a `NodeInfo` under its
        "data" attribute. One randomly chosen node is the agent entry point
        (agent installed, Admin privilege, not reimagable, value 0).

    Raises:
        ValueError: if the graph has no nodes (raised by `random.randrange`).
    """
    # `None` stands for "no global vulnerabilities"; this avoids a mutable
    # default argument shared across calls.
    global_vulnerabilities: VulnerabilityLibrary = {} if vulnerabilities is None else vulnerabilities

    # convert node IDs to string
    graph = nx.relabel_nodes(graph, {i: str(i) for i in graph.nodes})

    def create_random_firewall_configuration() -> FirewallConfiguration:
        # Each direction allows a random subset (possibly empty) of the known ports.
        return FirewallConfiguration(
            outgoing=[
                FirewallRule(port=p, permission=RulePermission.ALLOW)
                for p in random.sample(identifiers.ports, k=random.randint(0, len(identifiers.ports)))
            ],
            incoming=[
                FirewallRule(port=p, permission=RulePermission.ALLOW)
                for p in random.sample(identifiers.ports, k=random.randint(0, len(identifiers.ports)))
            ],
        )

    def create_random_properties() -> List[PropertyName]:
        return list(random.sample(identifiers.properties, k=random.randint(0, len(identifiers.properties))))

    def pick_random_global_vulnerabilities() -> VulnerabilityLibrary:
        # Draw a threshold, then keep each vulnerability whose own draw exceeds it.
        count = random.random()
        return {k: v for (k, v) in global_vulnerabilities.items() if random.random() > count}

    def add_leak_neighbors_vulnerability(library: VulnerabilityLibrary, node_id: NodeID) -> None:
        """Create a vulnerability for each node that reveals its immediate neighbors"""
        neighbors = {t for (s, t) in graph.edges() if s == node_id}
        if len(neighbors) > 0:
            library["RecentlyAccessedMachines"] = VulnerabilityInfo(
                description="AzureVM info, including public IP address",
                type=VulnerabilityType.LOCAL,
                outcome=LeakedNodesId(list(neighbors)),
            )

    def create_random_vulnerabilities(node_id: NodeID) -> VulnerabilityLibrary:
        library = pick_random_global_vulnerabilities()
        add_leak_neighbors_vulnerability(library, node_id)
        return library

    # Pick a random node as the agent entry node
    entry_node_index = random.randrange(len(graph.nodes))
    entry_node_id, _ = list(graph.nodes(data=True))[entry_node_index]
    graph.nodes[entry_node_id].clear()
    node_data = NodeInfo(
        services=[],
        value=0,
        properties=create_random_properties(),
        vulnerabilities=create_random_vulnerabilities(entry_node_id),
        firewall=create_random_firewall_configuration(),
        agent_installed=True,
        reimagable=False,
        privilege_level=PrivilegeLevel.Admin,
    )
    graph.nodes[entry_node_id].update({"data": node_data})

    def create_random_node_data(node_id: NodeID) -> NodeInfo:
        return NodeInfo(
            services=[],
            value=random.randint(0, 100),
            properties=create_random_properties(),
            vulnerabilities=create_random_vulnerabilities(node_id),
            firewall=create_random_firewall_configuration(),
            agent_installed=False,
            privilege_level=PrivilegeLevel.NoAccess,
        )

    # Replace any pre-existing attributes of the remaining nodes with random data.
    for node in list(graph.nodes):
        if node != entry_node_id:
            graph.nodes[node].clear()
            graph.nodes[node].update({"data": create_random_node_data(node)})

    return graph


# Serialization


def setup_yaml_serializer() -> None:
    """Setup a clean YAML formatter for object of type Environment.

    Registers custom YAML tags so that `Precondition` objects are written as
    `!BooleanExpression` scalars and `VulnerabilityType` members as
    `!VulnerabilityType` scalars holding the member name. The matching
    constructors are registered both on `yaml.SafeLoader` and through
    `yaml.add_constructor`.
    """

    def represent_precondition(dumper, data):
        return dumper.represent_scalar("!BooleanExpression", str(data.expression))

    def construct_precondition(loader, expression):
        return Precondition(loader.construct_scalar(expression))

    def represent_vulnerability_type(dumper, data):
        return dumper.represent_scalar("!VulnerabilityType", str(data.name))

    def construct_vulnerability_type(loader, expression):
        return VulnerabilityType[loader.construct_scalar(expression)]

    yaml.add_representer(Precondition, represent_precondition)  # type: ignore
    yaml.SafeLoader.add_constructor("!BooleanExpression", construct_precondition)  # type: ignore
    yaml.add_constructor("!BooleanExpression", construct_precondition)  # type: ignore

    yaml.add_representer(VulnerabilityType, represent_vulnerability_type)  # type: ignore

    yaml.SafeLoader.add_constructor("!VulnerabilityType", construct_vulnerability_type)  # type: ignore
    yaml.add_constructor("!VulnerabilityType", construct_vulnerability_type)  # type: ignore