Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
microsoft
GitHub Repository: microsoft/CyberBattleSim
Path: blob/main/cyberbattle/_env/graph_spaces.py
597 views
1
# Copyright (c) Microsoft Corporation.
2
# Licensed under the MIT License.
3
4
from typing import Optional
5
6
import networkx as nx
7
from gymnasium.spaces import Space, Dict
8
9
10
class BaseGraph(Space):
    """Space whose members are networkx graphs with a bounded number of nodes.

    Every node and every edge of a member graph carries a property dictionary
    that must belong to ``node_property_space`` / ``edge_property_space``
    respectively. Concrete subclasses choose the networkx graph type by
    setting ``_nx_class``.
    """

    # networkx graph class instantiated by `sample` and required by `contains`;
    # assigned by the concrete subclasses.
    _nx_class: type

    def __init__(
        self,
        max_num_nodes: int,
        node_property_space: Optional[Dict] = None,
        edge_property_space: Optional[Dict] = None,
    ):
        """Create the space.

        Args:
            max_num_nodes: Upper bound (inclusive) on the number of nodes of a
                member graph.
            node_property_space: ``Dict`` space describing the attributes of
                each node. Defaults to an empty ``Dict()`` (no attributes).
            edge_property_space: ``Dict`` space describing the attributes of
                each edge. Defaults to an empty ``Dict()`` (no attributes).
        """
        self.max_num_nodes = max_num_nodes
        self.node_property_space = Dict() if node_property_space is None else node_property_space
        self.edge_property_space = Dict() if edge_property_space is None else edge_property_space
        super().__init__()

    def sample(self, mask=None):
        """Draw a random graph with between 0 and ``max_num_nodes`` nodes.

        Nodes are labelled ``0 .. n-1`` and each receives properties sampled
        from ``node_property_space``. When there are at least two nodes,
        ``n - 1`` edges are added: nodes are visited in random order and each
        newly visited node is linked from a randomly chosen, previously
        visited node, so the edges form a random tree spanning all nodes.
        Edge properties are sampled from ``edge_property_space``.

        Args:
            mask: Accepted for signature compatibility; it is not used.

        Returns:
            A new instance of ``_nx_class``.
        """
        # NOTE(review): property values are drawn through each sub-space's own
        # `sample()`; whether seeding this space also seeds those sub-spaces is
        # not visible from this class - confirm if reproducibility matters.
        num_nodes = self.np_random.integers(0, self.max_num_nodes + 1)
        graph = self._nx_class()

        # add nodes with properties
        for node_id in range(num_nodes):
            node_properties = {k: s.sample() for k, s in self.node_property_space.spaces.items()}
            graph.add_node(node_id, **node_properties)

        if num_nodes < 2:
            # Zero or one node: there is nothing to connect.
            return graph

        # add some edges with properties
        seen, unseen = [], list(range(num_nodes))
        self.__pop_random(unseen, seen)  # pick the starting node of the tree
        while unseen:
            # Choose the source among already-visited nodes *before* moving the
            # target over, so an edge never links a node to itself.
            node_id_from = self.__sample_random(seen)
            node_id_to = self.__pop_random(unseen, seen)
            edge_properties = {k: s.sample() for k, s in self.edge_property_space.spaces.items()}
            graph.add_edge(node_id_from, node_id_to, **edge_properties)

        return graph

    def __pop_random(self, unseen: list, seen: list):
        """Move one randomly chosen element from ``unseen`` to ``seen`` and return it."""
        i = self.np_random.choice(len(unseen))
        x = unseen[i]
        seen.append(x)
        del unseen[i]
        return x

    def __sample_random(self, seen: list):
        """Return one randomly chosen element of ``seen`` without removing it."""
        i = self.np_random.choice(len(seen))
        return seen[i]

    def contains(self, x):
        """Return True if ``x`` is a member of this space.

        ``x`` is a member when it is an instance of ``_nx_class``, has at most
        ``max_num_nodes`` nodes, and every node / edge attribute dictionary
        belongs to ``node_property_space`` / ``edge_property_space``.
        """
        if not isinstance(x, self._nx_class):
            return False
        # Enforce the size bound that `sample` honours; without this check a
        # graph of any size would be accepted.
        if x.number_of_nodes() > self.max_num_nodes:
            return False
        return all(
            node_property in self.node_property_space for node_property in x.nodes.values()
        ) and all(
            edge_property in self.edge_property_space for edge_property in x.edges.values()
        )
66
67
68
class Graph(BaseGraph):
    """Graph space whose members are ``networkx.Graph`` instances."""

    _nx_class = nx.Graph
70
71
72
class DiGraph(BaseGraph):
    """Graph space whose members are ``networkx.DiGraph`` instances."""

    _nx_class = nx.DiGraph
74
75
76
class MultiGraph(BaseGraph):
    """Graph space whose members are ``networkx.MultiGraph`` instances."""

    _nx_class = nx.MultiGraph
78
79
80
class MultiDiGraph(BaseGraph):
    """Graph space whose members are ``networkx.MultiDiGraph`` instances."""

    _nx_class = nx.MultiDiGraph
82
83
84
if __name__ == "__main__":
    # Manual demo: sample one directed graph from the space, check that it is
    # a member, print its nodes and edges, then display it with matplotlib.
    from gymnasium.spaces import Box, Discrete
    import matplotlib.pyplot as plt  # type:ignore

    # Per-node attributes: a 3-vector in [0, 1] and a category in 0..6.
    node_space = Dict({"vector": Box(0, 1, (3,)), "category": Discrete(7)})
    # Per-edge attribute: a scalar weight in [0, 1].
    edge_space = Dict({"weight": Box(0, 1, ())})
    space = DiGraph(
        max_num_nodes=10,
        node_property_space=node_space,
        edge_property_space=edge_space,
    )

    space.seed(42)
    graph = space.sample()
    assert graph in space

    for node_id, node_properties in graph.nodes(data=True):
        print(f"node_id: {node_id}, node_properties: {node_properties}")

    for node_id_from, node_id_to, edge_properties in graph.edges(data=True):
        print(f"node_id_from: {node_id_from}, node_id_to: {node_id_to}, " f"edge_properties: {edge_properties}")

    # Lay the graph out once and reuse the positions for every drawing layer.
    pos = nx.spring_layout(graph)
    for draw in (nx.draw_networkx_nodes, nx.draw_networkx_edges, nx.draw_networkx_labels):
        draw(graph, pos)
    plt.show()
110
111