# Path: blob/main/cyberbattle/simulation/environment_generation_test.py
# 597 views
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""
The unit tests for the environment_generation functions
"""

from collections import Counter
from typing import Dict, List

import pytest

from cyberbattle.simulation import commandcontrol
from . import environment_generation
from . import model

# Vulnerability catalogs exposed by environment_generation, one per OS.
windows_vulns: Dict[str, model.VulnerabilityInfo] = environment_generation.potential_windows_vulns
linux_vulns: Dict[str, model.VulnerabilityInfo] = environment_generation.potential_linux_vulns

# Node-state property lists exposed by environment_generation, one per OS.
# Each OS binds to its own list (the Windows one must not alias the Linux one).
windows_node_states: List[model.PropertyName] = environment_generation.potential_windows_node_states
linux_node_states: List[model.PropertyName] = environment_generation.potential_linux_node_states

# Port names that environment_generation can open on a generated node.
potential_ports: List[model.PortName] = environment_generation.potential_ports


def test_create_random_environment() -> None:
    """
    The unit tests for create_random_environment function.

    Checks that an empty name and a negative size are rejected with
    ValueError, and that valid arguments produce a model.Environment.
    """
    with pytest.raises(ValueError, match=r"Please supply a non empty string for the name"):
        environment_generation.create_random_environment("", 2)

    # NOTE(review): the two adjacent raw strings concatenate with no space
    # ("...positiveinteger..."); presumably this mirrors the exact error text
    # raised by create_random_environment -- confirm before changing either side.
    with pytest.raises(ValueError, match=r"Please supply a positive non zero positive" r"integer for the size of the environment"):
        environment_generation.create_random_environment("Test_environment", -5)

    result: model.Environment = environment_generation.create_random_environment("Test_environment 2", 4)
    assert isinstance(result, model.Environment)


def test_random_environment_list_attacks() -> None:
    """
    Regression test for issue #23, caused by bug
    https://github.com/bastikr/boolean.py/issues/82 in boolean.py.

    Listing every attack of a random environment must complete without raising.
    """
    env = environment_generation.create_random_environment("test", 10)
    c2 = commandcontrol.CommandControl(env)
    c2.print_all_attacks()


def test_create_random_node() -> None:
    """
    The unit tests for create_random_node() function.

    Checks the errors raised for an empty endpoint list and for an
    unsupported OS, then checks that both supported OS types ("Linux" and
    "Windows") yield a model.NodeInfo.
    """
    # check that the correct exceptions are generated
    with pytest.raises(ValueError, match=r"No endpoints supplied"):
        environment_generation.create_random_node("Linux", [])

    with pytest.raises(ValueError, match=r"Unsupported OS Type please enter Linux or Windows"):
        environment_generation.create_random_node("Solaris", potential_ports)

    # Exercise every OS type the error message above advertises as supported.
    for os_type in ("Linux", "Windows"):
        test_node: model.NodeInfo = environment_generation.create_random_node(os_type, potential_ports)
        assert isinstance(test_node, model.NodeInfo)


def test_get_properties_from_vulnerabilities() -> None:
    """
    This function tests the get_properties_from_vulnerabilities function.

    Checks that the properties derived from each OS's vulnerability catalog
    include the OS name and a set of expected port/OS properties.
    """
    # testing on linux vulns
    props: List[model.PropertyName] = environment_generation.get_properties_from_vulnerabilities("Linux", linux_vulns)
    assert "Linux" in props
    assert "PortSSHOpen" in props
    assert "PortSMBOpen" in props

    # testing on Windows vulns
    windows_props: List[model.PropertyName] = environment_generation.get_properties_from_vulnerabilities("Windows", windows_vulns)
    assert "Windows" in windows_props
    assert "PortRDPOpen" in windows_props
    assert "PortSMBOpen" in windows_props
    assert "DomainJoined" in windows_props
    assert "Win10" in windows_props
    assert "Win7" in windows_props


def test_create_firewall_rules() -> None:
    """
    This function tests the create_firewall_rules function.

    For a given list of open ports, every port in the full port list should
    get an ALLOW rule if it is in the given list and a BLOCK rule otherwise,
    identically for incoming and outgoing traffic. Rules are compared as
    multisets (Counter), so their order does not matter.
    """
    empty_ports: List[model.PortName] = []
    potential_port_list: List[model.PortName] = ["RDP", "SSH", "HTTP", "HTTPs", "SMB", "SQL", "FTP", "WMI"]
    half_ports: List[model.PortName] = ["SSH", "HTTPs", "SQL", "FTP", "WMI"]
    all_blocked: List[model.FirewallRule] = [model.FirewallRule(port, model.RulePermission.BLOCK) for port in potential_port_list]
    all_allowed: List[model.FirewallRule] = [model.FirewallRule(port, model.RulePermission.ALLOW) for port in potential_port_list]
    half_allowed: List[model.FirewallRule] = [
        model.FirewallRule(port, model.RulePermission.ALLOW) if port in half_ports else model.FirewallRule(port, model.RulePermission.BLOCK)
        for port in potential_port_list
    ]

    # an empty port list should lead to every port being blocked
    results: model.FirewallConfiguration = environment_generation.create_firewall_rules(empty_ports)
    assert Counter(results.incoming) == Counter(all_blocked)
    assert Counter(results.outgoing) == Counter(all_blocked)

    # the full list of supported ports should lead to every port being allowed
    results = environment_generation.create_firewall_rules(potential_ports)
    assert Counter(results.incoming) == Counter(all_allowed)
    assert Counter(results.outgoing) == Counter(all_allowed)

    # a subset of ports should be allowed while the remaining ones are blocked
    results = environment_generation.create_firewall_rules(half_ports)
    assert Counter(results.incoming) == Counter(half_allowed)
    assert Counter(results.outgoing) == Counter(half_allowed)