import copy
import json
from tabulate import tabulate
import numpy
import warnings
import itertools
from coopihc.base.StateElement import StateElement
from coopihc.base.utils import (
NotKnownSerializationWarning,
StateElementAssignmentWarning,
)
from coopihc.base.Space import CatSet
[docs]class State(dict):
"""State
The container class for States. State subclasses dictionnary and adds a few methods:
* reset(dic = reset_dic), which passes reset values to the StateElements it holds and triggers their reset method
* filter(mode=mode, filterdict=filterdict), which filters out the state to extract some information.
* serialize(), which transforms the state into a format that can be serializable, e.g. to send as a JSON format.
Initializing a State is straightforward:
.. code-block:: python
state = State()
substate = State()
substate["x1"] = discrete_array_element(init=1, low=1, high=3)
substate["x3"] = array_element(
init=1.5 * numpy.ones((2, 2)), low=numpy.ones((2, 2)), high=2 * numpy.ones((2, 2))
)
substate2 = State()
substate2["y1"] = discrete_array_element(init=1, low=1, high=3)
state["sub1"] = substate
state["sub2"] = substate2
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def __eq__(self, other):
"""__eq__
equality checks on arrays (soft) à la Numpy
.. code-block:: python
_example_state = example_game_state()
obs = {
"game_info": {"turn_index": numpy.array(0), "round_index": numpy.array(0)},
"task_state": {"position": numpy.array(2), "targets": numpy.array([0, 1])},
"user_action": {"action": numpy.array(0)},
"assistant_action": {"action": numpy.array(2)},
}
del _example_state["user_state"]
del _example_state["assistant_state"]
assert _example_state == obs
assert _example_state.equals(obs, mode="soft")
"""
return self.equals(other, mode="soft")
[docs] def equals(self, other, mode="hard"):
"""equals
equality checks that also checks for spaces (hard).
.. code-block:: python
_example_state = example_game_state()
obs = {
"game_info": {"turn_index": numpy.array(0), "round_index": numpy.array(0)},
"task_state": {"position": numpy.array(2), "targets": numpy.array([0, 1])},
"user_action": {"action": numpy.array(0)},
"assistant_action": {"action": numpy.array(2)},
}
del _example_state["user_state"]
del _example_state["assistant_state"]
assert not _example_state.equals(obs, mode="hard")
"""
for (key, value), (okey, ovalue) in itertools.zip_longest(
self.items(), other.items()
):
cond = value.equals(ovalue, mode=mode)
if not isinstance(cond, bool):
try:
cond = cond.all()
except:
cond = all(cond)
if not (key == okey and cond):
return False
return True
def __getattr__(self, name):
try:
return self[name]
except KeyError:
raise AttributeError(name)
def __setitem__(self, key, value):
if isinstance(value, (State, StateElement)):
if isinstance(value, StateElement):
try:
if self[key].space != value.space:
warnings.warn(
StateElementAssignmentWarning(
f"You are trying to assign StateElement {value} with space {value.space} to a state which has previous StateElement {self[key]} with space {self[key].space}. To suppress this warning, either make sure your assignment is not of type StateElement, or delete the old statelement beforehand if you want it to be replaced"
)
)
self[key][...] = value[...]
return
except KeyError:
return super().__setitem__(key, value)
return super().__setitem__(key, value)
try:
self[key][...] = value
return
except KeyError:
return super().__setitem__(key, value)
[docs] def reset(self, dic={}):
"""Initialize the state. See StateElement
Example usage:
.. code-block:: python
# Normal reset
state.reset()
# Forced reset
reset_dic = {
"sub1": {"x1": 3},
"sub2": {"y1": 3},
}
state.reset(dic=reset_dic)
"""
for key, value in self.items():
reset_dic = dic.get(key)
value.reset(reset_dic)
[docs] def filter(self, mode="array", filterdict=None):
"""Extract some part of the state information
An example for filterdict's structure is as follows:
.. code-block:: python
filterdict = dict(
{
"sub1": dict({"x1": 0, "x3": slice(0, 1)}),
"sub2": dict({"y1": 0}),
}
)
This will filter out
* the first component (index 0) for subsubstate x1 in substate sub1,
* the first and second components for subsubstate x3 in substate sub1,
* the first component for subsubstate y1 in substate sub2.
Example usage:
.. code-block:: python
# Filter out spaces
f_state = state.filter(mode="space", filterdict=filterdict)
# Filter out as arrays
f_state = state.filter(mode="array", filterdict=filterdict)
# Filter out as StateElement
f_state = state.filter(mode="stateelement", filterdict=filterdict)
# Get spaces
f_state = state.filter(mode="space")
# Get arrays
f_state = state.filter(mode="array")
# Get Gym Compatible arrays
f_state = state.filter(mode="array-Gym")
:param mode: "array" or "spaces" or "stateelement", defaults to "array". If "stateelement", returns a dictionnary with the selected stateelements. If "spaces", returns the same dictionnary, but with only the spaces (no array information). If "array", returns the same dictionnary, but with only the value arrays (no space information).
:type mode: str, optional
:param filterdict: the dictionnary that indicates which components to filter out, defaults to None
:type filterdict: dictionnary, optional
:return: The dictionnary with the filtered state
:rtype: dictionnary
"""
new_state = {}
if filterdict is None:
filterdict = self
for key, value in filterdict.items():
if isinstance(self[key], State):
new_state[key] = self[key].filter(mode=mode, filterdict=value)
elif isinstance(self[key], StateElement):
# to make S.filter("values", S) possible.
# Warning: values == filterdict[key] != self[key]
if isinstance(value, StateElement):
try:
value = slice(0, len(value), 1)
except TypeError: # Deal with 0-D arrays
value = slice(0, 1, 1) # Ellipsis # slice(0, 1, 1)
if mode == "space":
_SEspace = self[key].space
new_state[key] = _SEspace
elif mode == "array":
new_state[key] = (self[key][value]).view(numpy.ndarray)
elif mode == "array-Gym":
v = (self[key][value]).view(numpy.ndarray)
if isinstance(self[key].space, CatSet):
new_state[key] = int(v)
elif self[key].space.shape == ():
new_state[key] = numpy.atleast_1d(v)
else:
new_state[key] = v
elif mode == "stateelement":
new_state[key] = self[key][value, {"space": True}]
else:
new_state[key] = self[key]
return new_state
def __content__(self):
return list(self.keys())
# Here we override copy and deepcopy simply because there seems to be some overhead in the default deepcopy implementation. It turns out the gain is almost None, but keep it here as a reminder that deepcopy needs speeding up. Adapted from StateElement code
def __copy__(self):
cls = self.__class__
copy_object = cls.__new__(cls)
copy_object.__dict__.update(self.__dict__)
copy_object.update(self)
return copy_object
def __deepcopy__(self, memodict={}):
cls = self.__class__
deepcopy_object = cls.__new__(cls)
memodict[id(self)] = deepcopy_object
deepcopy_object.__dict__.update(self.__dict__)
for k, v in self.items():
deepcopy_object[k] = copy.deepcopy(v, memodict)
return deepcopy_object
[docs] def serialize(self):
"""Makes the state serializable.
.. code-block:: python
state.serialize()
:return: serializable dictionnary
:rtype: dict
"""
ret_dict = {}
for key, value in dict(self).items():
try:
value_ = json.dumps(value)
except TypeError:
try:
value_ = value.serialize()
except AttributeError:
warnings.warns(
NotKnownSerializationWarning(
"warning: I don't know how to serialize {}. I'm sending the whole internal dictionnary of the object. Consider adding a serialize() method to your custom object".format(
value.__str__()
)
)
)
value_ = value.__dict__
ret_dict[key] = value_
return ret_dict
def _tabulate(self):
"""_tabulate
See __str__ for usage
"""
table = []
line_no = 0
for n, (key, value) in enumerate(self.items()):
tab, tablines = value._tabulate()
nline = 1 # deal with empty substates
for nline, line in enumerate(tab):
if isinstance(value, State) and nline != 0:
key = " "
line.insert(0, key)
table.extend(tab)
line_no += (n + 1) * (nline + 1)
return (table, line_no)
def __str__(self):
return tabulate(self._tabulate()[0])