Path: blob/main/cyberbattle/_env/graph_spaces.py
597 views
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Gymnasium spaces whose elements are networkx graphs with typed properties."""

from typing import Optional

import networkx as nx
from gymnasium.spaces import Space, Dict


class BaseGraph(Space):
    """Space of networkx graphs whose nodes and edges carry property dicts.

    Subclasses select the concrete networkx graph type through ``_nx_class``.

    Args:
        max_num_nodes: Inclusive upper bound on the number of nodes produced
            by :meth:`sample`.
        node_property_space: ``Dict`` space describing the attributes attached
            to every node. Defaults to an empty ``Dict()``.
        edge_property_space: ``Dict`` space describing the attributes attached
            to every edge. Defaults to an empty ``Dict()``.
    """

    # Concrete networkx graph class instantiated by `sample` and checked by
    # `contains`; set by each subclass.
    _nx_class: type

    def __init__(
        self,
        max_num_nodes: int,
        node_property_space: Optional[Dict] = None,
        edge_property_space: Optional[Dict] = None,
    ):
        self.max_num_nodes = max_num_nodes
        self.node_property_space = Dict() if node_property_space is None else node_property_space
        self.edge_property_space = Dict() if edge_property_space is None else edge_property_space
        super().__init__()

    def seed(self, seed: Optional[int] = None):
        """Seed this space and its node/edge property sub-spaces.

        The base ``Space.seed`` only reseeds this space's own generator, which
        drives the graph topology. Property values are drawn from the two
        ``Dict`` sub-spaces, so they must be reseeded as well for ``sample``
        to be reproducible.

        Args:
            seed: Seed forwarded to ``Space.seed``; ``None`` requests fresh
                entropy.

        Returns:
            Whatever ``Space.seed`` returns for this space.
        """
        result = super().seed(seed)
        # Derive one sub-seed per property space from the freshly seeded
        # generator; keep them within the positive int32 range.
        node_seed, edge_seed = self.np_random.integers(0, 2**31 - 1, size=2)
        self.node_property_space.seed(int(node_seed))
        self.edge_property_space.seed(int(edge_seed))
        return result

    def sample(self, mask=None):
        """Draw a random graph with between 0 and ``max_num_nodes`` nodes.

        Nodes are numbered ``0 .. num_nodes - 1`` and each receives properties
        sampled from ``node_property_space``. When there are at least two
        nodes, every node except a randomly chosen first one is attached by a
        single edge to a node picked among those already attached, so the
        result has ``num_nodes - 1`` edges.

        Args:
            mask: Accepted for signature compatibility; it is not used.

        Returns:
            A new instance of ``_nx_class``.
        """
        num_nodes = self.np_random.integers(0, self.max_num_nodes + 1)
        graph = self._nx_class()

        # add nodes with properties
        for node_id in range(num_nodes):
            node_properties = {k: s.sample() for k, s in self.node_property_space.spaces.items()}
            graph.add_node(node_id, **node_properties)

        if num_nodes < 2:
            return graph

        # add some edges with properties
        seen, unseen = [], list(range(num_nodes))  # init
        self.__pop_random(unseen, seen)  # pop one node before we start
        while unseen:
            # The source is drawn first, so it can never be the node that is
            # about to be moved into `seen` (no self-loops).
            node_id_from = self.__sample_random(seen)
            node_id_to = self.__pop_random(unseen, seen)
            edge_properties = {k: s.sample() for k, s in self.edge_property_space.spaces.items()}
            graph.add_edge(node_id_from, node_id_to, **edge_properties)

        return graph

    def __pop_random(self, unseen: list, seen: list):
        """Move a randomly chosen element from ``unseen`` to ``seen`` and return it."""
        i = self.np_random.choice(len(unseen))
        x = unseen[i]
        seen.append(x)
        del unseen[i]
        return x

    def __sample_random(self, seen: list):
        """Return a randomly chosen element of ``seen`` without removing it."""
        i = self.np_random.choice(len(seen))
        return seen[i]

    def contains(self, x) -> bool:
        """Return whether ``x`` is a graph belonging to this space.

        ``x`` must be an instance of ``_nx_class``, every node's attribute
        dict must belong to ``node_property_space`` and every edge's attribute
        dict must belong to ``edge_property_space``.

        NOTE: the number of nodes is not compared against ``max_num_nodes``.
        """
        return (
            isinstance(x, self._nx_class)
            and all(node_property in self.node_property_space for node_property in x.nodes.values())
            and all(edge_property in self.edge_property_space for edge_property in x.edges.values())
        )


class Graph(BaseGraph):
    """Space of ``networkx.Graph`` instances."""

    _nx_class = nx.Graph


class DiGraph(BaseGraph):
    """Space of ``networkx.DiGraph`` instances."""

    _nx_class = nx.DiGraph


class MultiGraph(BaseGraph):
    """Space of ``networkx.MultiGraph`` instances."""

    _nx_class = nx.MultiGraph


class MultiDiGraph(BaseGraph):
    """Space of ``networkx.MultiDiGraph`` instances."""

    _nx_class = nx.MultiDiGraph


if __name__ == "__main__":
    # Demo: sample one directed graph, print its contents and draw it.
    from gymnasium.spaces import Box, Discrete
    import matplotlib.pyplot as plt  # type:ignore

    space = DiGraph(
        max_num_nodes=10,
        node_property_space=Dict({"vector": Box(0, 1, (3,)), "category": Discrete(7)}),
        edge_property_space=Dict({"weight": Box(0, 1, ())}),
    )

    space.seed(42)
    graph = space.sample()
    assert graph in space

    for node_id, node_properties in graph.nodes.items():
        print(f"node_id: {node_id}, node_properties: {node_properties}")

    for (node_id_from, node_id_to), edge_properties in graph.edges.items():
        print(
            f"node_id_from: {node_id_from}, node_id_to: {node_id_to}, "
            f"edge_properties: {edge_properties}"
        )

    pos = nx.spring_layout(graph)
    nx.draw_networkx_nodes(graph, pos)
    nx.draw_networkx_edges(graph, pos)
    nx.draw_networkx_labels(graph, pos)
    plt.show()