from xml.dom.minidom import Attr
from coopihc.base.Space import Space
from coopihc.base.StateElement import StateElement
from coopihc.base.State import State
import numpy
# def _numpy_max_info(dtype):
#     if numpy.issubdtype(dtype, numpy.integer):
#         return numpy.iinfo(dtype).max
#     else:
#         return numpy.finfo(dtype).max
# ======================== Space Shortcuts ========================
# def lin_space(start, stop, num=50, endpoint=True, dtype=numpy.int64, seed=None, **kwargs):
#     # lin_space(num=50, start=0, stop=None, endpoint=False, dtype=numpy.int64):
#     """Linearly spaced discrete space.
#     Wrap numpy's linspace to produce a space that is compatible with COOPIHC. Parameters of this function are defined in https://numpy.org/doc/stable/reference/generated/numpy.linspace.html
#     :return: _description_
#     :rtype: _type_
#     """
#     if stop is None:
#         stop = num + start
#     return Space(
#         array=numpy.linspace(
#             start, stop, num=num, endpoint=endpoint, dtype=dtype
#         ),
#         seed=seed,
#         **kwargs
#     )
def integer_set(N, dtype=numpy.int64, **kwargs):
    """Set of the N integers {0, 1, ..., N-1}.

    The members are generated with ``numpy.linspace`` (N evenly spaced
    values starting at 0, with N itself excluded) and handed to ``Space``
    through its ``array`` argument.

    :param N: cardinality of the set
    :type N: int
    :param dtype: dtype forwarded to ``Space``, defaults to numpy.int64
    :param kwargs: any extra keyword arguments, forwarded to ``Space``
    :return: the Space built from the enumerated values
    :rtype: Space
    """
    members = numpy.linspace(0, N, num=N, endpoint=False)
    return Space(array=members, dtype=dtype, **kwargs)
def integer_space(N=None, start=0, stop=None, dtype=numpy.int64):
    """Integer interval with bounds ``start`` and ``stop``.

    Wrapper around box_space. When ``stop`` is not supplied it is derived
    as ``N + start - 1``; if ``N`` is not supplied either, the largest
    value representable by ``dtype`` (``numpy.iinfo(dtype).max``) is used
    in its place.

    :param N: number of integers counted from ``start`` when ``stop`` is
        omitted, defaults to None
    :type N: int, optional
    :param start: lower bound, defaults to 0
    :type start: int, optional
    :param stop: upper bound; takes precedence over ``N`` when given
    :type stop: int, optional
    :param dtype: integer dtype forwarded to box_space, defaults to
        numpy.int64
    :return: whatever box_space builds for these bounds
    """
    if stop is not None:
        upper = stop
    else:
        # NOTE(review): with N and stop both omitted and start > 1 this
        # value is larger than numpy.iinfo(dtype).max -- confirm how Space
        # treats a bound outside the dtype's range.
        count = numpy.iinfo(dtype).max if N is None else N
        upper = count + start - 1
    lower_bound = numpy.array(start)
    upper_bound = numpy.array(upper)
    return box_space(low=lower_bound, high=upper_bound, dtype=dtype)
def box_space(high=None, low=None, **kwargs):
    """Numeric Space with bounds ``low`` and ``high``.

    :param high: upper bound; must support unary minus when ``low`` is
        omitted. Defaults to ``numpy.ones((1, 1))``
    :type high: array-like, optional
    :param low: lower bound, defaults to ``-high`` (a box symmetric
        around zero)
    :type low: array-like, optional
    :param kwargs: any extra keyword arguments, forwarded to ``Space``
    :return: the Space built from the two bounds
    :rtype: Space
    """
    if high is None:
        # Build the default per call: an ndarray in the signature would be
        # created once at definition time and shared (and mutable) across
        # every call that relies on the default.
        high = numpy.ones((1, 1))
    if low is None:
        low = -high
    return Space(low=low, high=high, **kwargs)
# ======================== StateElement Shortcuts ========================
def _get_shape_from_objects(shape, *obj):
    _shape = [shape]
    for _o in obj:
        try:
            _shape.append(_o.shape)
        except AttributeError:  # has no shape
            _shape.append(numpy.asarray(_o).shape)
    return sorted([x for x in _shape if x is not None], key=len)[-1]
def _set_shape_object(shape, obj, default_value):
    if obj is None:
        obj = numpy.full(shape, default_value)
    else:
        obj = numpy.asarray(obj)
        try:
            obj = obj.reshape(shape)
        except ValueError:
            obj = numpy.full(shape, obj)
    return obj
def array_element(
    shape=None, init=None, low=None, high=None, out_of_bounds_mode="warning", **kwargs
):
    """StateElement holding an array, with ``low``/``high`` bounds.

    A common shape is resolved from ``shape``, ``init``, ``low`` and
    ``high`` (the highest-rank one wins), then each of the three values is
    brought to that shape. Missing values default to 0 for ``init``,
    ``-numpy.inf`` for ``low`` and ``numpy.inf`` for ``high``.

    :param shape: explicit shape, defaults to None
    :param init: initial value, defaults to None
    :param low: lower bound, defaults to None
    :param high: upper bound, defaults to None
    :param out_of_bounds_mode: forwarded to ``StateElement``, defaults to
        "warning"
    :type out_of_bounds_mode: str, optional
    :param kwargs: any extra keyword arguments, forwarded to ``Space``
    :return: the StateElement built from the value and its Space
    :rtype: StateElement
    """
    common_shape = _get_shape_from_objects(shape, init, low, high)
    initial_value = _set_shape_object(common_shape, init, 0)
    lower_bound = _set_shape_object(common_shape, low, -numpy.inf)
    upper_bound = _set_shape_object(common_shape, high, numpy.inf)
    bounds = Space(low=lower_bound, high=upper_bound, **kwargs)
    return StateElement(initial_value, bounds, out_of_bounds_mode=out_of_bounds_mode)
def discrete_array_element(
    N=None, shape=None, init=0, low=None, high=None, dtype=None, **kwargs
):
    """Integer-valued variant of array_element.

    ``dtype`` defaults to numpy.int64. When ``N`` is given, a missing
    ``low`` becomes 0 and a missing ``high`` becomes ``N - 1 + low``; when
    ``N`` is None the bounds are passed on untouched.

    :param N: number of values counted from ``low``, defaults to None
    :type N: int, optional
    :param shape: explicit shape, defaults to None
    :param init: initial value, defaults to 0
    :param low: lower bound, defaults to None
    :param high: upper bound, defaults to None
    :param dtype: dtype passed on to array_element, defaults to None
        (numpy.int64 is then used)
    :param kwargs: any extra keyword arguments, forwarded to array_element
    :return: whatever array_element returns for these arguments
    """
    element_dtype = numpy.int64 if dtype is None else dtype
    if N is not None:
        low = 0 if low is None else low
        high = N - 1 + low if high is None else high
    return array_element(
        shape=shape, init=init, low=low, high=high, dtype=element_dtype, **kwargs
    )
def cat_element(N, init=0, out_of_bounds_mode="warning", **kwargs):
    """StateElement whose space is the integer set {0, 1, ..., N-1}.

    :param N: cardinality of the set, passed to integer_set
    :type N: int
    :param init: initial value, defaults to 0
    :param out_of_bounds_mode: forwarded to ``StateElement``, defaults to
        "warning"
    :type out_of_bounds_mode: str, optional
    :param kwargs: any extra keyword arguments, forwarded to integer_set
    :return: the StateElement built from ``init`` and the integer set
    :rtype: StateElement
    """
    categories = integer_set(N, **kwargs)
    return StateElement(init, categories, out_of_bounds_mode=out_of_bounds_mode)
def example_game_state():
    """Build an example nested State.

    The returned State holds six sub-states, in this order: ``game_info``
    (``turn_index``, ``round_index``), ``task_state`` (``position``,
    ``targets``), ``user_state`` (``goal``), ``assistant_state``
    (``beliefs``), ``user_action`` (``action``) and ``assistant_action``
    (``action``).

    :return: the assembled example state
    :rtype: State
    """
    game_info = State(
        turn_index=cat_element(
            N=4, init=0, out_of_bounds_mode="raw", dtype=numpy.int8
        ),
        round_index=discrete_array_element(init=0, out_of_bounds_mode="raw"),
    )

    task_state = State(
        position=discrete_array_element(N=4, init=2, out_of_bounds_mode="clip"),
        targets=discrete_array_element(
            init=numpy.array([0, 1]),
            low=numpy.array([0, 0]),
            high=numpy.array([3, 3]),
        ),
    )

    user_state = State(goal=discrete_array_element(N=4))

    # Eight entries of 1/8 each, bounded to [0, 1] per entry.
    uniform_beliefs = numpy.array([1 / 8] * 8)
    assistant_state = State(
        beliefs=array_element(
            init=uniform_beliefs,
            low=numpy.zeros((8,)),
            high=numpy.ones((8,)),
        )
    )

    user_action = State(action=discrete_array_element(low=-1, high=1))
    assistant_action = State(
        action=cat_element(4, init=2, out_of_bounds_mode="error")
    )

    return State(
        game_info=game_info,
        task_state=task_state,
        user_state=user_state,
        assistant_state=assistant_state,
        user_action=user_action,
        assistant_action=assistant_action,
    )