StateElement

A StateElement is a combination of a value and a space. Under the hood, StateElement subclasses numpy.ndarray: it adds a layer that checks whether the values are contained in the space (and decides what to do if they are not). As a result, many NumPy methods work on a StateElement just as they do on an array.
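As a rough illustration of what a checking layer on top of numpy.ndarray can look like, here is a minimal sketch of an ndarray subclass that rejects values outside box bounds. The class BoxElement, its constructor and the use of ValueError are hypothetical. CoopIHC's StateElement is more general (categorical spaces, configurable out-of-bounds handling) and may be implemented differently.

```python
import numpy


class BoxElement(numpy.ndarray):
    """Hypothetical minimal ndarray subclass whose values must lie in [low, high]."""

    def __new__(cls, value, low, high):
        # View-cast a float copy of the input as a BoxElement
        obj = numpy.array(value, dtype=float).view(cls)
        obj.low = numpy.asarray(low, dtype=float)
        obj.high = numpy.asarray(high, dtype=float)
        if not obj._contains(numpy.asarray(obj)):
            raise ValueError("value not contained in the space")
        return obj

    def __array_finalize__(self, obj):
        # Called for views, slices and ufunc results: propagate the bounds unchanged
        # (a complete implementation would also slice them to match the new shape)
        self.low = getattr(obj, "low", None)
        self.high = getattr(obj, "high", None)

    def _contains(self, values):
        return bool(numpy.all((values >= self.low) & (values <= self.high)))

    def __setitem__(self, key, value):
        # Check a plain-ndarray copy first, so a rejected write leaves self unchanged
        candidate = numpy.array(self)
        candidate[key] = value
        if not self._contains(candidate):
            raise ValueError("value not contained in the space")
        super().__setitem__(key, value)
```

Ordinary NumPy methods such as max() or sum() keep working because the object is still an ndarray; only construction and item assignment go through the extra check.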

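Instantiating a StateElement is straightforward with the helper functions cat_element, discrete_array_element and array_element, which create elements in categorical, discrete and continuous spaces respectively: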

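    import numpy

    # import path assumed from the CoopIHC package layout
    from coopihc.base.elements import cat_element, discrete_array_element, array_element

    # Value in a categorical set {0, 1, 2}
    x = cat_element(3)
    # Discrete value in {2, 3, 4, 5}, initialized to 2
    y = discrete_array_element(low=2, high=5, init=2)
    # Continuous 2x2 value in [-1, 1]; out-of-bounds values raise an error
    z = array_element(
        low=-numpy.ones((2, 2)), high=numpy.ones((2, 2)), out_of_bounds_mode="error"
    )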

Note

The examples above use the preferred input shape, but StateElement treats the input as an array and tries to view-cast input that does not match the expected shape. For example, x = StateElement(2, discr_space, out_of_bounds_mode="error"), where discr_space is a discrete space, is also valid input.

StateElement has an out_of_bounds_mode keyword argument (defaults to ‘warning’) that specifies what to do when a value is not contained in the space:

  • “error” -> raises a StateNotContainedError

  • “warning” -> warns with a StateNotContainedWarning, but accepts the input

  • “clip” -> clips the data so that it belongs to the space

  • “silent” -> accepts values not in the space silently (the behavior should then be equivalent to that of numpy.ndarray)

  • “raw” -> applies no data transformation. This is faster than the other options because the preprocessing of input data is short-circuited, but it provides no tolerance for misspecified input.
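The five modes can be sketched as a single dispatch function for a box space. This is a hypothetical illustration, not CoopIHC's code: process_input is an invented name, and the two exception classes are local stand-ins that only reuse CoopIHC's names.

```python
import warnings

import numpy


class StateNotContainedError(Exception):
    """Local stand-in named after CoopIHC's exception (not the library's class)."""


class StateNotContainedWarning(UserWarning):
    """Local stand-in named after CoopIHC's warning (not the library's class)."""


MODES = ("error", "warning", "clip", "silent", "raw")


def process_input(value, low, high, out_of_bounds_mode="warning"):
    """Hypothetical out-of-bounds handling for a box space [low, high]."""
    if out_of_bounds_mode not in MODES:
        raise ValueError(f"unknown out_of_bounds_mode: {out_of_bounds_mode!r}")
    if out_of_bounds_mode == "raw":
        # Short-circuit: no conversion and no containment check at all
        return value
    value = numpy.asarray(value, dtype=float)
    contained = bool(numpy.all((value >= low) & (value <= high)))
    if contained or out_of_bounds_mode == "silent":
        return value
    if out_of_bounds_mode == "error":
        raise StateNotContainedError(f"{value} not in [{low}, {high}]")
    if out_of_bounds_mode == "warning":
        warnings.warn(f"{value} not in [{low}, {high}]", StateNotContainedWarning)
        return value
    # The only remaining mode is "clip": project the value onto the space
    return numpy.clip(value, low, high)
```

Note how “raw” returns before any conversion, which is why it is fastest but accepts even input that is not numeric.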

Note

Broadcasting/view-casting and type casting are applied, if necessary, in the first four modes when the space has “contains” set to “soft”; they are never applied in “raw” mode.
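The difference between a “soft” check and a strict one can be sketched with plain NumPy. soft_contains and hard_contains are hypothetical helpers that assume a box space given by low/high arrays and a dtype; they are not CoopIHC functions, and in “raw” mode neither would be called.

```python
import numpy


def soft_contains(value, low, high, dtype):
    """Hypothetical "soft" check: cast and broadcast the input before testing bounds."""
    value = numpy.asarray(value).astype(dtype)  # type casting
    try:
        # Broadcast/view-cast to the space's shape, e.g. a scalar -> shape (2,)
        value = numpy.broadcast_to(value, numpy.shape(low))
    except ValueError:
        return False  # the shapes cannot be reconciled
    return bool(numpy.all((value >= low) & (value <= high)))


def hard_contains(value, low, high, dtype):
    """Hypothetical strict check: shape and dtype must already match exactly."""
    value = numpy.asarray(value)
    if value.shape != numpy.shape(low) or value.dtype != numpy.dtype(dtype):
        return False
    return bool(numpy.all((value >= low) & (value <= high)))
```

With a space of shape (2,), the scalar 3 passes the soft check (it is broadcast to [3, 3]) but fails the strict one.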

Other mechanisms

  • You can reset them, either to a random value or to a given value
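      # import path assumed from the CoopIHC package layout
      from coopihc.base.utils import StateNotContainedError

      # Random resets: sample a new value from the space
      x.reset()
      y.reset()
      z.reset()

      # Forced resets: set a given value
      x.reset(value=2)
      y.reset(2)
      z.reset(0.59 * numpy.ones((2, 2)))

      # Forced resets also perform input checks. z was created with
      # out_of_bounds_mode="error", so an out-of-bounds value raises.
      try:
          z.reset(5 * numpy.ones((2, 2)))
      except StateNotContainedError:
          print("raised StateNotContainedError as expected")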
    
  • You can iterate over them

      import pytest

      # Iterating over a continuous element works like in NumPy: first over rows, then over columns
      x = array_element(
          init=numpy.array([[0.2, 0.3], [0.4, 0.5]]),
          low=-numpy.ones((2, 2)),
          high=numpy.ones((2, 2)),
      )

      for row in x:
          for value in row:
              print(value)

      # You can't iterate over a categorical element: it holds a single value
      x = cat_element(4)
      with pytest.raises(TypeError):
          next(iter(x))
    
  • You can compare them. By default, equals() compares values only; a “hard” comparison also checks that the spaces are equal.

      x = discrete_array_element(init=1, low=1, high=4)
      w = discrete_array_element(init=1, low=1, high=3)
      # Same value, different spaces: the default comparison only looks at values
      assert w.equals(x)
      # A hard comparison also checks the spaces
      assert not w.equals(x, "hard")
    
  • You can extract values with or without their spaces. To extract the space together with the value, add the dictionary {"space": True} as the last index; this mechanism abuses Python's indexing (slice) notation.

      # CoopIHC abuses Python's indexing mechanism: you can extract values together with spaces
      x = cat_element(3)
      assert x[..., {"space": True}] == x
      assert x[...] == x

      x = array_element(
          init=numpy.array([[0.0, 0.1], [0.2, 0.3]]),
          low=-numpy.ones((2, 2)),
          high=numpy.ones((2, 2)),
      )
      # Plain indexing returns the value only
      assert x[0, 0] == 0.0
      # Adding {"space": True} returns a StateElement with the matching sub-space
      assert x[0, 0, {"space": True}] == array_element(init=0.0, low=-1, high=1)
    
  • You can cast values from one space to another. Casting between discrete and continuous spaces supports two modes, “center” and “edges”.

      import matplotlib.pyplot as plt
      import numpy

      # --------------------------- Init viz.
      fig = plt.figure()
      axd2c = fig.add_subplot(221)
      axc2d = fig.add_subplot(222)
      axc2c = fig.add_subplot(223)
      axd2d = fig.add_subplot(224)
      axd2c.set_title("Discrete to continuous")
      axc2d.set_title("Continuous to discrete")
      axc2c.set_title("Continuous to continuous")
      axd2d.set_title("Discrete to discrete")

      # --------------------------  Discrete 2 Continuous
      # Target space [-1.5, 1.5], created once; its init value must lie inside the space
      y = array_element(init=0.0, low=-1.5, high=1.5)
      _center = []
      _edges = []
      for i in [1, 2, 3]:
          x = discrete_array_element(init=i, low=1, high=3)
          _center.append(x.cast(y, mode="center").tolist())
          _edges.append(x.cast(y, mode="edges").tolist())

      # Small vertical offsets keep the two series from overlapping
      axd2c.plot(numpy.array([1, 2, 3]), numpy.array(_center) - 0.05, "+", label="center")
      axd2c.plot(numpy.array([1, 2, 3]), numpy.array(_edges) + 0.05, "o", label="edges")
      axd2c.legend()

      # ------------------------   Continuous 2 Discrete
      # Target discrete space {1, 2, 3}, created once outside the loop
      x = discrete_array_element(init=1, low=1, high=3)
      center = []
      edges = []
      inputs = numpy.linspace(-1.5, 1.5, 100)
      for i in inputs:
          y = array_element(init=i, low=-1.5, high=1.5)
          center.append(y.cast(x, mode="center").tolist())
          edges.append(y.cast(x, mode="edges").tolist())

      axc2d.plot(inputs, numpy.array(center) - 0.05, "+", label="center")
      axc2d.plot(inputs, numpy.array(edges) + 0.05, "o", label="edges")
      axc2d.legend()

      # ------------------------- Continuous 2 Continuous
      # Target space: 2x2 box [0, 4], created once outside the loop
      y = array_element(
          low=numpy.full((2, 2), 0), high=numpy.full((2, 2), 4), dtype=numpy.float32
      )
      output = []
      inputs = numpy.linspace(-1, 1, 100)
      for i in inputs:
          x = array_element(
              init=numpy.full((2, 2), i),
              low=numpy.full((2, 2), -1),
              high=numpy.full((2, 2), 1),
              dtype=numpy.float32,
          )
          # All four entries are equal, so only the [0, 0] entry is plotted
          output.append(x.cast(y)[0, 0].tolist())

      axc2c.plot(inputs, numpy.array(output), "-")

      # -------------------------- Discrete 2 Discrete
      # Target discrete space {11, 12, 13, 14}
      y = discrete_array_element(init=11, low=11, high=14)
      output = []
      for i in [1, 2, 3, 4]:
          x = discrete_array_element(init=i, low=1, high=4)
          output.append(x.cast(y).tolist())

      axd2d.plot([1, 2, 3, 4], output, "+")

      # ----------------------- Show viz
      plt.tight_layout()
      plt.show()
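The reset mechanism described above can be pictured with a small pure-Python stand-in: reset() without a value samples from the space, and reset(value) assigns the given value after checking it. DiscreteElement and its rng parameter are hypothetical. In CoopIHC, the outcome of a failed check depends on out_of_bounds_mode, whereas this sketch always raises.

```python
import random


class DiscreteElement:
    """Hypothetical stand-in for a discrete element holding an integer in [low, high]."""

    def __init__(self, low, high, init=None):
        self.low, self.high = low, high
        self.value = None
        self.reset(low if init is None else init)

    def reset(self, value=None, rng=random):
        if value is None:
            # Random reset: sample uniformly from the space (both bounds inclusive)
            value = rng.randint(self.low, self.high)
        elif not self.low <= value <= self.high:
            # Forced reset: the given value is still checked against the space
            raise ValueError(f"{value} not in [{self.low}, {self.high}]")
        self.value = value
        return self.value
```

Passing a seeded random.Random instance as rng makes random resets reproducible.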
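Because StateElement subclasses numpy.ndarray, the iteration rules shown earlier are presumably NumPy's own. The following plain-NumPy snippet (no CoopIHC involved) shows both behaviors: nested loops visit a 2x2 array row by row, and a 0-d array raises TypeError. That cat_element stores its single value as a 0-d array is an assumption that would explain the error in the categorical example.

```python
import numpy

# Plain NumPy stand-ins (assumption: StateElement inherits ndarray iteration unchanged)
grid = numpy.array([[0.2, 0.3], [0.4, 0.5]])
visited = []
for row in grid:          # first over rows...
    for value in row:     # ...then over the entries of each row
        visited.append(float(value))

single = numpy.array(3)   # a 0-d array holds one value and has no rows
try:
    iter(single)
    iterable = True
except TypeError:         # NumPy refuses to iterate over a 0-d array
    iterable = False
```

The visiting order matches grid.ravel(), i.e. row-major (C) order.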
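The two comparison levels from the equals() example can be sketched with a frozen dataclass. IntervalElement, its fields and the default mode name “soft” are hypothetical; only the “hard” keyword comes from the documented example.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class IntervalElement:
    """Hypothetical value in the integer interval [low, high]."""

    value: int
    low: int
    high: int

    def equals(self, other, mode="soft"):
        same_value = self.value == other.value
        if mode == "hard":
            # Hard comparison: the spaces must match too
            return same_value and (self.low, self.high) == (other.low, other.high)
        # Default comparison: values only
        return same_value
```

This mirrors the documented behavior: two elements holding 1 in {1, ..., 4} and {1, 2, 3} are equal by default but not under a hard comparison.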
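The space-extraction trick works because Python passes everything between the brackets to __getitem__ as a single tuple: x[0, 0, {"space": True}] arrives as (0, 0, {"space": True}), so a trailing dictionary can be detected and stripped before indexing. SpacedArray is a hypothetical container sketching this idea with NumPy; it is not how CoopIHC implements StateElement.__getitem__.

```python
import numpy


class SpacedArray:
    """Hypothetical container: an array of values plus per-entry box bounds."""

    def __init__(self, values, low, high):
        self.values = numpy.asarray(values, dtype=float)
        # Store bounds with the same shape as the values so one key indexes all three
        self.low = numpy.broadcast_to(numpy.asarray(low, dtype=float), self.values.shape)
        self.high = numpy.broadcast_to(numpy.asarray(high, dtype=float), self.values.shape)

    def __getitem__(self, key):
        key = key if isinstance(key, tuple) else (key,)
        options = {}
        if key and isinstance(key[-1], dict):
            # Strip the trailing options dictionary before indexing
            options, key = key[-1], key[:-1]
        if not options.get("space", False):
            return self.values[key]  # value only
        # Value together with the matching slice of the space
        return SpacedArray(self.values[key], self.low[key], self.high[key])
```

Indexing values and bounds with the same key keeps the extracted sub-space aligned with the extracted value.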
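For intuition about the “center” and “edges” casting modes, here is one plausible reading written as plain functions. In “edges” mode the first and last discrete values map onto the interval bounds; in “center” mode the interval is split into one equal-width bin per discrete value, and each value maps to its bin center. Continuous-to-continuous casting is sketched as an affine rescaling. These formulas are assumptions for illustration, not CoopIHC's cast implementation; compare them with the plots produced by the example above before relying on them.

```python
import math


def discrete_to_continuous(v, dlow, dhigh, clow, chigh, mode="center"):
    """Map integer v in [dlow, dhigh] to [clow, chigh] (hypothetical formulas)."""
    n = dhigh - dlow + 1  # number of discrete values (assumed >= 2 for "edges")
    k = v - dlow          # 0-based index of v
    if mode == "edges":
        # First value -> clow, last value -> chigh, evenly spaced in between
        return clow + k * (chigh - clow) / (n - 1)
    # "center": one equal-width bin per discrete value, return the bin center
    return clow + (k + 0.5) * (chigh - clow) / n


def continuous_to_discrete(x, clow, chigh, dlow, dhigh, mode="center"):
    """Map x in [clow, chigh] to an integer in [dlow, dhigh] (hypothetical formulas)."""
    n = dhigh - dlow + 1
    u = (x - clow) / (chigh - clow)  # relative position in [0, 1]
    if mode == "edges":
        k = round(u * (n - 1))       # nearest evenly spaced point
    else:
        # Index of the bin containing x; x == chigh belongs to the last bin
        k = min(math.floor(u * n), n - 1)
    return dlow + k


def continuous_to_continuous(x, low_a, high_a, low_b, high_b):
    """Affine rescaling of x from [low_a, high_a] to [low_b, high_b]."""
    return low_b + (x - low_a) * (high_b - low_b) / (high_a - low_a)
```

Under this reading, {1, 2, 3} maps to {-1.5, 0, 1.5} in “edges” mode and to {-1, 0, 1} in “center” mode on [-1.5, 1.5].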