Path: blob/main/cyberbattle/_env/discriminatedunion.py
597 views
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""A discriminated union space for Gym"""

from collections import OrderedDict
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union
from typing import Dict as TypingDict, Generic, cast
import numpy as np

from gymnasium import spaces
from gymnasium.utils import seeding

T_cov = TypeVar("T_cov", covariant=True)


class DiscriminatedUnion(spaces.Dict, Generic[T_cov]):
    """
    A discriminated union of simpler spaces.

    Each element of the space is a single-entry mapping ``{kind: value}``:
    ``kind`` is one of the keys of the union and ``value`` is an element of
    the sub-space registered under that key.

    Example usage:

    self.observation_space = discriminatedunion.DiscriminatedUnion(
        {"foo": spaces.Discrete(2), "Bar": spaces.Discrete(3)})

    Generic type T_cov is the type of the contained discriminated values.
    It should be defined as a typed dictionary, e.g.: TypedDict('Choices', {'foo': int, 'Bar': int})

    """

    def __init__(
        self,
        spaces: Union[None, TypingDict[str, spaces.Space]] = None,
        seed: Optional[Union[dict, int, np.random.Generator]] = None,
        **spaces_kwargs: spaces.Space,
    ) -> None:
        """Create a discriminated union space.

        Args:
            spaces: mapping from kind (discriminant key) to sub-space. When
                None, the sub-spaces are taken from ``spaces_kwargs`` instead.
            seed: forwarded to ``spaces.Dict`` for the sub-spaces. It is also
                used for the generator that picks the kind in ``sample``. An
                int seeds that generator, a ``np.random.Generator`` is used
                as-is, and a per-key dict leaves it unseeded.
            **spaces_kwargs: sub-spaces given as keyword arguments; only used
                when ``spaces`` is None.
        """
        if spaces is None:
            # Forward the seed on this path too, so keyword-built unions are
            # seeded the same way as mapping-built ones.
            super().__init__(spaces_kwargs, seed=seed)
        else:
            super().__init__(spaces=spaces, seed=seed)
        self._reseed_union_rng(seed)

    def _reseed_union_rng(self, seed: Union[dict, None, int, np.random.Generator]) -> None:
        """(Re)create ``union_np_random``, the generator ``sample`` uses to pick a kind."""
        if isinstance(seed, np.random.Generator):
            self.union_np_random = seed
        elif isinstance(seed, dict):
            # Per-key seeds only address the sub-spaces; the kind selector
            # gets a fresh, unseeded generator.
            self.union_np_random, _ = seeding.np_random(None)
        else:
            self.union_np_random, _ = seeding.np_random(seed)

    def seed(self, seed: Union[dict, None, int] = None):
        """Seed the sub-spaces (via ``spaces.Dict.seed``) and the kind selector.

        The kind-selection generator is reseeded as well, so the sequence
        produced by ``sample`` is reproducible after calling ``seed``.

        Returns:
            Whatever ``spaces.Dict.seed`` returns.
        """
        seeds = super().seed(seed)
        self._reseed_union_rng(seed)
        return seeds

    def sample(self, mask=None) -> T_cov:  # type: ignore
        """Pick one kind uniformly at random, then sample a value from its sub-space.

        Args:
            mask: optional mapping from kind to the mask for that kind's
                sub-space ``sample``. Kinds absent from the mapping are sampled
                without a mask. The mask does not restrict which kind is chosen.

        Returns:
            A single-entry ``OrderedDict`` ``{kind: value}``.
        """
        items = list(self.spaces.items())
        kth_key, kth_space = items[self.union_np_random.integers(0, len(items))]
        if mask is None:
            value = kth_space.sample()
        else:
            value = kth_space.sample(mask=mask.get(kth_key))
        return cast(T_cov, OrderedDict([(kth_key, value)]))

    def contains(self, x) -> bool:
        """Return True iff ``x`` is a single-entry mapping whose key is a kind
        of this union and whose value belongs to that kind's sub-space."""
        if not isinstance(x, Mapping) or len(x) != 1:
            return False
        # Unpack the (key, value) pair itself, not the key string.
        ((key, value),) = x.items()
        return key in self.spaces and self.spaces[key].contains(value)

    @classmethod
    def is_of_kind(cls, key: str, sample_n: Mapping[str, object]) -> bool:
        """Returns true if a given sample is of the specified discriminated kind"""
        return key in sample_n.keys()

    @classmethod
    def kind(cls, sample_n: Mapping[str, object]) -> str:
        """Returns the discriminated kind of a given sample.

        Asserts that the sample holds exactly one key.
        """
        keys = sample_n.keys()
        assert len(keys) == 1
        return list(keys)[0]

    def __getitem__(self, key: str) -> spaces.Space:
        """Return the sub-space registered under ``key``."""
        return self.spaces[key]

    def __repr__(self) -> str:
        return self.__class__.__name__ + "(" + ", ".join([str(k) + ":" + str(s) for k, s in self.spaces.items()]) + ")"

    def to_jsonable(self, sample_n: Sequence[dict[str, Any]]) -> dict[str, list[Any]]:
        """Serialize samples by delegating to ``spaces.Dict.to_jsonable``."""
        # NOTE(review): spaces.Dict.to_jsonable presumably reads every key of
        # self.spaces from each sample, while a union sample holds only one
        # key. Confirm this works for unions with more than one kind.
        return super().to_jsonable(sample_n)

    def from_jsonable(self, sample_n: TypingDict[str, list]) -> list[OrderedDict[str, Any]]:
        """Deserialize via ``spaces.Dict.from_jsonable``.

        Only a single sample is supported; this is asserted.
        """
        ret = super().from_jsonable(sample_n)
        assert len(ret) == 1
        return ret

    def __eq__(self, other: object) -> bool:
        return isinstance(other, DiscriminatedUnion) and self.spaces == other.spaces


def test_sampling() -> None:
    """Simple sampling test: every sample is a single-kind member of the union."""
    union = DiscriminatedUnion(spaces={"foo": spaces.Discrete(8), "Bar": spaces.Discrete(3)})
    for _ in range(100):
        sample = union.sample()
        assert len(sample) == 1
        assert union.contains(sample)


def test_seeding_is_reproducible() -> None:
    """Reseeding with the same value reproduces the same sequence of samples."""
    union = DiscriminatedUnion(spaces={"foo": spaces.Discrete(8), "Bar": spaces.Discrete(3)})
    union.seed(42)
    first = [union.sample() for _ in range(20)]
    union.seed(42)
    second = [union.sample() for _ in range(20)]
    assert first == second