Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
microsoft
GitHub Repository: microsoft/CyberBattleSim
Path: blob/main/cyberbattle/_env/discriminatedunion.py
597 views
1
# Copyright (c) Microsoft Corporation.
2
# Licensed under the MIT License.
3
4
"""A discriminated union space for Gym"""
5
6
from collections import OrderedDict
7
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union
8
from typing import Dict as TypingDict, Generic, cast
9
import numpy as np
10
11
from gymnasium import spaces
12
from gymnasium.utils import seeding
13
14
# Covariant type variable for the value type produced by DiscriminatedUnion
# (per the class docstring, expected to be a TypedDict of the possible kinds).
T_cov = TypeVar("T_cov", covariant=True)
15
16
17
class DiscriminatedUnion(spaces.Dict, Generic[T_cov]):
    """
    A discriminated union of simpler spaces.

    Each element of the space is a mapping with exactly one entry: the key
    names the selected sub-space (the "kind") and the value is an element of
    that sub-space.

    Example usage:

        self.observation_space = discriminatedunion.DiscriminatedUnion(
            {"foo": spaces.Discrete(2), "Bar": spaces.Discrete(3)})

    Generic type T_cov is the type of the contained discriminated values.
    It should be defined as a typed dictionary, e.g.: TypedDict('Choices', {'foo': int, 'Bar': int})
    """

    @staticmethod
    def _make_union_rng(seed: Optional[Union[dict, int, np.random.Generator]]) -> np.random.Generator:
        """Build the generator used by `sample` to pick which kind to draw.

        A numpy Generator is used as-is. A per-key dict seed has no single value
        that applies to the kind selector, so it yields an unseeded generator.
        An int or None is handed to `seeding.np_random`.
        """
        if isinstance(seed, np.random.Generator):
            return seed
        if isinstance(seed, dict):
            rng, _ = seeding.np_random(None)
            return rng
        rng, _ = seeding.np_random(seed)
        return rng

    def __init__(
        self,
        spaces: Union[None, TypingDict[str, spaces.Space]] = None,
        seed: Optional[Union[dict, int, np.random.Generator]] = None,
        **spaces_kwargs: spaces.Space,
    ) -> None:
        """Create a discriminated union space.

        Args:
            spaces: mapping from kind name to sub-space. When None, the
                sub-spaces are taken from ``spaces_kwargs`` instead.
            seed: optional seed (int, per-key dict, or numpy Generator). It is
                forwarded to ``gymnasium.spaces.Dict`` and also used to seed the
                generator that selects the kind in ``sample``.
            **spaces_kwargs: sub-spaces given as keyword arguments. They are
                forwarded to ``gymnasium.spaces.Dict`` rather than discarded.
        """
        if spaces is None:
            # Passed as a single mapping (not **kwargs) so that gymnasium
            # applies its usual key ordering for mappings.
            super().__init__(spaces_kwargs, seed=seed)
        else:
            super().__init__(spaces=spaces, seed=seed, **spaces_kwargs)

        self.union_np_random = self._make_union_rng(seed)

    def seed(self, seed: Union[dict, None, int] = None):
        """Seed the sub-spaces and the kind selector used by `sample`.

        Returns whatever ``gymnasium.spaces.Dict.seed`` returns.
        """
        seeds = super().seed(seed)
        # Reseed the kind selector as well; otherwise the choice of kind made
        # by `sample` would not be reproducible after reseeding.
        self.union_np_random = self._make_union_rng(seed)
        return seeds

    def sample(self, mask=None) -> T_cov:  # type: ignore
        """Draw one element of the union.

        A kind is picked uniformly at random with ``union_np_random``, then
        that kind's sub-space is sampled.

        Args:
            mask: accepted for signature compatibility with the base class;
                it is ignored.

        Returns:
            A single-entry ``OrderedDict`` ``{kind: value}``.

        Raises:
            ValueError: if the union has no sub-spaces.
        """
        choices = list(self.spaces.items())
        if not choices:
            raise ValueError("Cannot sample from a DiscriminatedUnion with no sub-spaces")
        index_k = int(self.union_np_random.integers(0, len(choices)))
        kth_key, kth_space = choices[index_k]
        return cast(T_cov, OrderedDict([(kth_key, kth_space.sample())]))

    def contains(self, x) -> bool:
        """Return True iff `x` is a valid element of the union.

        A valid element is a one-entry dict whose key is a known kind and
        whose value is contained in that kind's sub-space.
        """
        if not isinstance(x, dict) or len(x) != 1:
            return False
        # Unpack the single (key, value) pair. Iterating `x` alone yields keys only.
        ((k, value),) = x.items()
        return k in self.spaces and self.spaces[k].contains(value)

    @classmethod
    def is_of_kind(cls, key: str, sample_n: Mapping[str, object]) -> bool:
        """Returns true if a given sample is of the specified discriminated kind"""
        return key in sample_n

    @classmethod
    def kind(cls, sample_n: Mapping[str, object]) -> str:
        """Returns the discriminated kind (the single key) of a given sample"""
        keys = sample_n.keys()
        assert len(keys) == 1
        return next(iter(keys))

    def __getitem__(self, key: str) -> spaces.Space:
        """Return the sub-space registered for kind `key`."""
        return self.spaces[key]

    def __repr__(self) -> str:
        parts = ", ".join(f"{k}:{s}" for k, s in self.spaces.items())
        return f"{self.__class__.__name__}({parts})"

    def to_jsonable(self, sample_n: Sequence[dict[str, Any]]) -> dict[str, list[Any]]:
        """Convert a batch of samples to a JSON-friendly structure.

        NOTE(review): this delegates to ``gymnasium.spaces.Dict.to_jsonable``,
        which presumably looks up every sub-space key in each sample. That
        would fail for single-entry union samples when there is more than one
        kind. Confirm before relying on it.
        """
        return super().to_jsonable(sample_n)

    def from_jsonable(self, sample_n: TypingDict[str, list]) -> list[OrderedDict[str, Any]]:
        """Decode samples produced by `to_jsonable`.

        NOTE(review): asserts that exactly one sample is decoded. Confirm this
        is intended for batched inputs.
        """
        ret = super().from_jsonable(sample_n)
        assert len(ret) == 1
        return ret

    def __eq__(self, other: object) -> bool:
        """Two unions are equal when they have equal sub-space mappings."""
        return isinstance(other, DiscriminatedUnion) and self.spaces == other.spaces
94
95
def test_sampling() -> None:
    """Sampling yields single-entry dicts whose value belongs to the chosen sub-space."""
    union = DiscriminatedUnion(spaces={"foo": spaces.Discrete(8), "Bar": spaces.Discrete(3)}, seed=0)
    seen_kinds = set()
    for _ in range(100):
        drawn = union.sample()
        chosen = DiscriminatedUnion.kind(drawn)
        assert chosen in ("foo", "Bar")
        assert union[chosen].contains(drawn[chosen])
        seen_kinds.add(chosen)
    # Both kinds are expected among 100 uniform draws; the chance of missing
    # one is about 2**-99, and the fixed seed makes the outcome deterministic.
    assert seen_kinds == {"foo", "Bar"}
99
100