"""Source code for coopihc.base.elements: shortcut constructors for COOPIHC spaces and state elements."""

from xml.dom.minidom import Attr
from coopihc.base.Space import Space
from coopihc.base.StateElement import StateElement
from coopihc.base.State import State

import numpy


# def _numpy_max_info(dtype):
#     if numpy.issubdtype(dtype, numpy.integer):
#         return numpy.iinfo(dtype).max
#     else:
#         return numpy.finfo(dtype).max


# ======================== Space Shortcuts ========================
# def lin_space(start, stop, num=50, endpoint=True, dtype=numpy.int64, seed=None, **kwargs):
#     # lin_space(num=50, start=0, stop=None, endpoint=False, dtype=numpy.int64):
#     """Linearly spaced discrete space.

#     Wrap numpy's linspace to produce a space that is compatible with COOPIHC. Parameters of this function are defined in https://numpy.org/doc/stable/reference/generated/numpy.linspace.html


#     :return: _description_
#     :rtype: _type_
#     """
#     if stop is None:
#         stop = num + start
#     return Space(
#         array=numpy.linspace(
#             start, stop, num=num, endpoint=endpoint, dtype=dtype
#         ),
#         seed=seed,
#         **kwargs
#     )


def integer_set(N, dtype=numpy.int64, **kwargs):
    """Build the categorical set {0, 1, ..., N-1}.

    The N members are generated with ``numpy.linspace`` (endpoint excluded)
    and handed to ``Space`` together with ``dtype``. Any extra keyword
    arguments are forwarded to ``Space`` unchanged.

    :param N: cardinality of the set
    :type N: int
    :param dtype: dtype forwarded to ``Space``, defaults to numpy.int64
    :return: space whose members are 0, 1, ..., N-1
    :rtype: Space
    """
    members = numpy.linspace(0, N, num=N, endpoint=False)
    return Space(array=members, dtype=dtype, **kwargs)
def integer_space(N=None, start=0, stop=None, dtype=numpy.int64, **kwargs):
    """Discrete interval {start, start + 1, ..., stop} (bounds inclusive).

    Wrapper around ``box_space``. When ``stop`` is not given, it is derived
    from ``N`` so that the interval holds ``N`` consecutive integers, i.e.
    ``stop = start + N - 1``. When neither ``N`` nor ``stop`` is given,
    ``stop`` is ``iinfo(dtype).max + start - 1``, clamped so it never
    exceeds ``iinfo(dtype).max``.

    :param N: number of integers in the interval (ignored when ``stop`` is given)
    :type N: int, optional
    :param start: lower bound of the interval, defaults to 0
    :type start: int, optional
    :param stop: upper bound of the interval, defaults to None
    :type stop: int, optional
    :param dtype: integer dtype of the space, defaults to numpy.int64
    :param kwargs: extra keyword arguments forwarded to ``box_space``
    :return: Integer Space
    :rtype: Space
    """
    if stop is None:
        if N is None:
            max_value = numpy.iinfo(dtype).max
            # Without the clamp, start >= 2 would push stop past what dtype
            # can represent and overflow once converted to an array.
            stop = min(max_value + start - 1, max_value)
        else:
            stop = N + start - 1
    return box_space(
        low=numpy.array(start), high=numpy.array(stop), dtype=dtype, **kwargs
    )
def box_space(high=None, low=None, **kwargs):
    """[low, high] Numeric box space.

    :param high: upper bound(s); None (the default) means ``numpy.ones((1, 1))``
    :type high: array-like, optional
    :param low: lower bound(s); defaults to ``-high`` when None
    :type low: array-like, optional
    :param kwargs: extra keyword arguments forwarded to ``Space`` (e.g. ``dtype``)
    :return: space bounded by ``low`` and ``high``
    :rtype: Space
    """
    if high is None:
        # Build a fresh default on every call. A default written in the
        # signature is evaluated once, so every call would share (and could
        # corrupt through in-place mutation) the same array object.
        high = numpy.ones((1, 1))
    if low is None:
        low = -high
    return Space(low=low, high=high, **kwargs)
# ======================== StateElement Shortcuts ========================


def _get_shape_from_objects(shape, *obj):
    """Return the highest-rank shape among ``shape`` and the shapes of ``obj``.

    Objects without a ``.shape`` attribute are converted with
    ``numpy.asarray`` first (so ``None`` and scalars give ``()``). A ``None``
    ``shape`` is ignored. An integer ``shape`` is treated as the 1-D shape
    ``(shape,)``.

    NOTE(review): candidates are ranked by ``len`` with a stable sort and the
    last one wins. On equal rank, a later object's shape therefore takes
    precedence over the explicit ``shape`` argument.

    :param shape: explicit shape, or None
    :param obj: objects (arrays, scalars, None) whose shapes are candidates
    :return: the selected shape
    """
    if isinstance(shape, (int, numpy.integer)):
        # len() below needs a sequence; an int previously raised TypeError.
        shape = (int(shape),)
    _shape = [shape]
    for _o in obj:
        try:
            _shape.append(_o.shape)
        except AttributeError:  # has no shape
            _shape.append(numpy.asarray(_o).shape)
    return sorted([x for x in _shape if x is not None], key=len)[-1]


def _set_shape_object(shape, obj, default_value):
    """Return ``obj`` as a numpy array of the given ``shape``.

    ``None`` becomes an array filled with ``default_value``. Otherwise ``obj``
    is reshaped, and if the reshape fails it is used as the fill value of
    ``numpy.full``.

    :param shape: target shape
    :param obj: value to coerce, or None
    :param default_value: fill value used when ``obj`` is None
    :return: array with the requested shape
    """
    if obj is None:
        obj = numpy.full(shape, default_value)
    else:
        obj = numpy.asarray(obj)
        try:
            obj = obj.reshape(shape)
        except ValueError:
            # Size mismatch: broadcast obj across the requested shape.
            obj = numpy.full(shape, obj)
    return obj
def array_element(
    shape=None, init=None, low=None, high=None, out_of_bounds_mode="warning", **kwargs
):
    """Build a ``StateElement`` whose space is the box [low, high].

    The common shape is inferred from ``shape``, ``init``, ``low`` and
    ``high``. Each of the three values is then coerced to that shape, with
    missing values defaulting to 0 (init), -inf (low) and +inf (high).

    :param shape: explicit shape, or None to infer it
    :param init: initial value, defaults to zeros
    :param low: lower bound(s), defaults to -inf
    :param high: upper bound(s), defaults to +inf
    :param out_of_bounds_mode: forwarded to ``StateElement``, defaults to "warning"
    :param kwargs: forwarded to ``Space`` (e.g. ``dtype``)
    :return: the new state element
    :rtype: StateElement
    """
    target_shape = _get_shape_from_objects(shape, init, low, high)
    value = _set_shape_object(target_shape, init, 0)
    lower = _set_shape_object(target_shape, low, -numpy.inf)
    upper = _set_shape_object(target_shape, high, numpy.inf)
    space = Space(low=lower, high=upper, **kwargs)
    return StateElement(value, space, out_of_bounds_mode=out_of_bounds_mode)
def discrete_array_element(
    N=None, shape=None, init=0, low=None, high=None, dtype=None, **kwargs
):
    """Build an integer-valued ``StateElement`` through ``array_element``.

    When ``N`` is given, missing bounds default to ``low = 0`` and
    ``high = low + N - 1``. When ``N`` is None, the bounds are passed through
    as-is, so missing ones get ``array_element``'s ±inf defaults.
    NOTE(review): confirm how ``Space`` handles infinite bounds combined
    with an integer dtype.

    :param N: number of admissible values, or None
    :param shape: explicit shape, or None to infer it
    :param init: initial value, defaults to 0
    :param low: lower bound(s)
    :param high: upper bound(s)
    :param dtype: dtype of the space, defaults to numpy.int64 when None
    :param kwargs: forwarded to ``array_element``
    :return: the new state element
    :rtype: StateElement
    """
    dtype = numpy.int64 if dtype is None else dtype
    if N is not None:
        low = 0 if low is None else low
        high = N - 1 + low if high is None else high
    return array_element(
        shape=shape, init=init, low=low, high=high, dtype=dtype, **kwargs
    )
def cat_element(N, init=0, out_of_bounds_mode="warning", **kwargs):
    """Build a categorical ``StateElement`` over {0, 1, ..., N-1}.

    :param N: number of categories
    :type N: int
    :param init: initial value, defaults to 0
    :param out_of_bounds_mode: forwarded to ``StateElement``, defaults to "warning"
    :param kwargs: forwarded to ``integer_set`` (e.g. ``dtype``)
    :return: the new state element
    :rtype: StateElement
    """
    categories = integer_set(N, **kwargs)
    return StateElement(init, categories, out_of_bounds_mode=out_of_bounds_mode)
def example_game_state():
    """Return an example nested ``State`` built with the shortcuts above.

    It has the sub-states game_info, task_state, user_state, assistant_state,
    user_action and assistant_action. The assistant's ``beliefs`` start as a
    uniform vector of 8 entries (1/8 each), bounded in [0, 1].

    :return: example game state
    :rtype: State
    """
    # Elements are created in the same order as the original nested call.
    turn_index = cat_element(N=4, init=0, out_of_bounds_mode="raw", dtype=numpy.int8)
    round_index = discrete_array_element(init=0, out_of_bounds_mode="raw")
    game_info = State(turn_index=turn_index, round_index=round_index)

    position = discrete_array_element(N=4, init=2, out_of_bounds_mode="clip")
    targets = discrete_array_element(
        init=numpy.array([0, 1]),
        low=numpy.array([0, 0]),
        high=numpy.array([3, 3]),
    )
    task_state = State(position=position, targets=targets)

    user_state = State(goal=discrete_array_element(N=4))

    n_beliefs = 8
    beliefs = array_element(
        init=numpy.full((n_beliefs,), 1 / n_beliefs),
        low=numpy.zeros((n_beliefs,)),
        high=numpy.ones((n_beliefs,)),
    )
    assistant_state = State(beliefs=beliefs)

    user_action = State(action=discrete_array_element(low=-1, high=1))
    assistant_action = State(
        action=cat_element(4, init=2, out_of_bounds_mode="error")
    )

    return State(
        game_info=game_info,
        task_state=task_state,
        user_state=user_state,
        assistant_state=assistant_state,
        user_action=user_action,
        assistant_action=assistant_action,
    )