StateElement
A StateElement is a combination of a value and a space. Under the hood, StateElement subclasses numpy.ndarray; essentially, it adds a layer that checks whether the values are contained in the space (and decides what to do if they are not). As a result, many NumPy methods work on a StateElement.
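To make the "ndarray plus containment check" idea concrete, here is a minimal sketch in plain NumPy. It is not CoopIHC's implementation: `BoundedArray` is a hypothetical class whose "space" is just a pair of `low`/`high` bounds, and it always raises `ValueError` instead of honoring an out-of-bounds mode.

```python
import numpy


class BoundedArray(numpy.ndarray):
    """Hypothetical analogue of a value + space container (not CoopIHC code).

    The "space" is only a pair of low/high bounds stored on an ndarray subclass.
    """

    def __new__(cls, value, low, high):
        # View-cast the input into the subclass so all ndarray machinery applies.
        obj = numpy.asarray(value).view(cls)
        obj.low = numpy.asarray(low)
        obj.high = numpy.asarray(high)
        if not obj.contains():
            raise ValueError(f"{value!r} is not within [{low}, {high}]")
        return obj

    def __array_finalize__(self, obj):
        # Called for views and slices: carry the bounds over from the source array.
        if obj is None:
            return
        self.low = getattr(obj, "low", None)
        self.high = getattr(obj, "high", None)

    def contains(self):
        # Compare on a plain ndarray view so the result is an ordinary boolean.
        values = self.view(numpy.ndarray)
        return bool(numpy.all((values >= self.low) & (values <= self.high)))


x = BoundedArray([[0.2, 0.3], [0.4, 0.5]], low=-1, high=1)
print(x.sum(), x[0].low, x[0].high)
```

Because `BoundedArray` is an `ndarray` view, slicing and NumPy methods such as `x.sum()` keep working, and `__array_finalize__` carries the bounds over to slices and other views.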
Instantiating a StateElement is straightforward.
import numpy
from coopihc.base.elements import array_element, cat_element, discrete_array_element

# Value in a categorical set
x = cat_element(3)
# Discrete value
y = discrete_array_element(low=2, high=5, init=2)
# Continuous value
z = array_element(
    low=-numpy.ones((2, 2)), high=numpy.ones((2, 2)), out_of_bounds_mode="error"
)
Note
The examples above use the preferred input shape, but StateElement will treat the input as an array and try to view-cast input that does not match the expected shape. For example, x = StateElement(2, discr_space, out_of_bounds_mode="error") is also valid input.
StateElement has an out_of_bounds_mode keyword argument (defaults to “warning”) that specifies what to do when a value is not contained in the space:
“error” –> raises a StateNotContainedError
“warning” –> warns with a StateNotContainedWarning, but accepts the input
“clip” –> clips the data to force it to belong to the space
“silent” –> values not in the space are accepted silently (the behavior should then be equivalent to numpy.ndarray)
“raw” –> no data transformation is applied. This is faster than the other options, because the preprocessing of input data is short-circuited. However, it provides no tolerance for misspecified input.
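The five modes can be summarized as a single dispatch function. This is a hypothetical sketch for a box-shaped space given by `low`/`high` bounds; `apply_bounds_mode`, `NotContainedError` and `NotContainedWarning` are illustrative stand-ins, not CoopIHC names.

```python
import warnings

import numpy


class NotContainedError(ValueError):
    """Hypothetical stand-in for StateNotContainedError."""


class NotContainedWarning(UserWarning):
    """Hypothetical stand-in for StateNotContainedWarning."""


def apply_bounds_mode(value, low, high, mode="warning"):
    """Sketch of the five out_of_bounds_mode behaviours for a box [low, high]."""
    if mode == "raw":
        # No preprocessing and no check at all: fastest, but nothing is validated.
        return value
    value = numpy.asarray(value)
    inside = bool(numpy.all((value >= low) & (value <= high)))
    if inside or mode == "silent":
        # Valid input, or invalid input that is deliberately accepted as-is.
        return value
    if mode == "error":
        raise NotContainedError(f"{value} not in [{low}, {high}]")
    if mode == "warning":
        warnings.warn(f"{value} not in [{low}, {high}]", NotContainedWarning, stacklevel=2)
        return value
    if mode == "clip":
        # Force the data back into the space.
        return numpy.clip(value, low, high)
    raise ValueError(f"unknown mode {mode!r}")


print(apply_bounds_mode([0.5, 1.5], -1, 1, mode="clip"))
```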
Note
Broadcasting/view-casting and type casting are applied when necessary in the first four modes if the space has “contains” set to “soft”, but never with “raw”.
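One way to picture the “soft” preprocessing is sketched below, assuming a space described only by a shape, a dtype and bounds. `soft_prepare` and `hard_contains` are hypothetical helpers; CoopIHC's actual casting rules may differ.

```python
import math

import numpy


def soft_prepare(value, shape, dtype):
    """Hypothetical 'soft' preprocessing: coerce the input to the space's
    shape (reshape or broadcast) and dtype before any bounds check."""
    arr = numpy.asarray(value)
    if arr.shape != shape:
        if arr.size == math.prod(shape):
            # Same number of elements: view-cast to the expected shape.
            arr = arr.reshape(shape)
        else:
            # Otherwise try NumPy broadcasting (raises ValueError if impossible).
            arr = numpy.broadcast_to(arr, shape)
    # astype returns a writable copy in the space's dtype.
    return arr.astype(dtype)


def hard_contains(value, shape, dtype, low, high):
    """Hypothetical 'hard' check: shape and dtype must already match exactly."""
    arr = numpy.asarray(value)
    if arr.shape != shape or arr.dtype != numpy.dtype(dtype):
        return False
    return bool(numpy.all((arr >= low) & (arr <= high)))


print(soft_prepare(2, (1, 1), numpy.int64))
```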
Other mechanisms
You can reset their values, either randomly or to a given value:
from coopihc.base.utils import StateNotContainedError

# Random resets
x.reset()
y.reset()
z.reset()

# Forced resets
x.reset(value=2)
y.reset(2)
z.reset(0.59 * numpy.ones((2, 2)))

# Forced resets also go through the input checks:
try:
    x.reset(value=5)
except StateNotContainedError:
    print(f"raised {StateNotContainedError.__name__} as expected")
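The distinction between random and forced resets can be sketched without CoopIHC, assuming a box space with scalar or array bounds. `reset_value` and its `discrete` flag are hypothetical; the point is only that a random reset samples inside the space, while a forced reset must pass the same containment check as any other input.

```python
import numpy


def reset_value(low, high, discrete, value=None, rng=None):
    """Hypothetical reset: sample uniformly from [low, high] when no value is
    given, otherwise validate the forced value against the same bounds."""
    if value is None:
        rng = numpy.random.default_rng() if rng is None else rng
        if discrete:
            # Generator.integers excludes the upper endpoint, hence high + 1.
            return rng.integers(low, high + 1)
        return rng.uniform(low, high)
    value = numpy.asarray(value)
    if not numpy.all((value >= low) & (value <= high)):
        # A forced reset is checked like any other input.
        raise ValueError(f"{value} not in [{low}, {high}]")
    return value


rng = numpy.random.default_rng(0)
print(reset_value(2, 5, discrete=True, rng=rng))
print(reset_value(-numpy.ones((2, 2)), numpy.ones((2, 2)), discrete=False, rng=rng))
```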
You can iterate over them:
import pytest

# Iterating over a continuous element works like in NumPy: first over rows, then over columns
x = array_element(
    init=numpy.array([[0.2, 0.3], [0.4, 0.5]]),
    low=-numpy.ones((2, 2)),
    high=numpy.ones((2, 2)),
)

for row in x:
    for entry in row:
        print(entry)

# You can't iterate over a categorical element
x = cat_element(4)
with pytest.raises(TypeError):
    next(iter(x))
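Plain NumPy behaves the same way, which suggests (as an assumption, not a statement about CoopIHC internals) that a categorical element holds a single 0-d value: 2-D arrays iterate row by row, while iterating a 0-d array raises `TypeError`.

```python
import numpy

# Iterating a 2-D array yields rows; iterating a row yields its entries.
grid = numpy.array([[0.2, 0.3], [0.4, 0.5]])
flat = [float(entry) for row in grid for entry in row]
print(flat)

# A 0-d array holds a single scalar and is not iterable.
scalar = numpy.array(4)
try:
    iter(scalar)
    iterable = True
except TypeError:
    iterable = False
print(iterable)
```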
You can compare them. This includes a “hard” comparison, which also checks whether the spaces are equal:
x = discrete_array_element(init=1, low=1, high=4)
w = discrete_array_element(init=1, low=1, high=3)
# Same values, so the default comparison succeeds
assert w.equals(x)
# A hard comparison also checks the spaces, which differ here
assert not w.equals(x, "hard")
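The soft/hard distinction amounts to comparing either the values alone or the values and the spaces. A minimal sketch with hypothetical dataclasses `Interval` and `ValueInSpace`, where the space is simply an integer interval:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Interval:
    """Hypothetical stand-in for a discrete space: the integers low..high."""

    low: int
    high: int


@dataclass(frozen=True)
class ValueInSpace:
    """Hypothetical value/space pair with a soft and a hard comparison."""

    value: int
    space: Interval

    def equals(self, other, mode="soft"):
        same_value = self.value == other.value
        if mode == "soft":
            # Soft: only the values have to match.
            return same_value
        if mode == "hard":
            # Hard: the spaces have to match as well (dataclass field equality).
            return same_value and self.space == other.space
        raise ValueError(f"unknown comparison mode {mode!r}")


x = ValueInSpace(1, Interval(1, 4))
w = ValueInSpace(1, Interval(1, 3))
print(w.equals(x), w.equals(x, "hard"))
```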
You can extract values with or without their spaces. Extracting the spaces together with the values is done through a mechanism that abuses Python's slice notation:
# CoopIHC abuses Python's indexing mechanism: you can extract values together with spaces
x = cat_element(3)
assert x[..., {"space": True}] == x
assert x[...] == x

x = array_element(
    init=numpy.array([[0.0, 0.1], [0.2, 0.3]]),
    low=-numpy.ones((2, 2)),
    high=numpy.ones((2, 2)),
)
assert x[0, 0] == 0.0
assert x[0, 0, {"space": True}] == array_element(init=0.0, low=-1, high=1)
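The indexing trick works because Python passes `x[0, 0, {"space": True}]` to `__getitem__` as the tuple `(0, 0, {"space": True})`, so a trailing dict can be detected and stripped before the data is indexed. The class below is a hypothetical sketch of that idea, not CoopIHC's code; it slices the bounds with the same index so the returned space stays aligned with the value.

```python
import numpy


class IndexableElement:
    """Hypothetical value + box-space container illustrating the indexing trick."""

    def __init__(self, value, low, high):
        self.value = numpy.asarray(value, dtype=float)
        # Give the bounds the value's shape so they can be sliced the same way.
        self.low = numpy.broadcast_to(numpy.asarray(low, dtype=float), self.value.shape)
        self.high = numpy.broadcast_to(numpy.asarray(high, dtype=float), self.value.shape)

    def __getitem__(self, key):
        # x[a] passes a; x[a, b, {...}] passes the tuple (a, b, {...}).
        if not isinstance(key, tuple):
            key = (key,)
        options = {}
        if key and isinstance(key[-1], dict):
            options, key = key[-1], key[:-1]
        if not key:
            # x[{"space": True}] on its own selects everything.
            key = (Ellipsis,)
        value = self.value[key]
        if options.get("space", False):
            return IndexableElement(value, self.low[key], self.high[key])
        return value


e = IndexableElement([[0.0, 0.1], [0.2, 0.3]], low=-1, high=1)
sub = e[0, 0, {"space": True}]
print(e[0, 0], sub.value, sub.low, sub.high)
```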
You can cast values from one space to another. This includes two modes, “center” and “edges”, for casting between discrete and continuous spaces:
import matplotlib.pyplot as plt
import numpy

# --------------------------- Init viz.
fig = plt.figure()
axd2c = fig.add_subplot(221)
axc2d = fig.add_subplot(222)
axc2c = fig.add_subplot(223)
axd2d = fig.add_subplot(224)

# -------------------------- Discrete to continuous
_center = []
_edges = []
for i in [1, 2, 3]:
    x = discrete_array_element(init=i, low=1, high=3)
    # Only the target space matters for the cast; init stays inside it to avoid warnings
    y = array_element(init=0.0, low=-1.5, high=1.5)
    _center.append(x.cast(y, mode="center").tolist())
    _edges.append(x.cast(y, mode="edges").tolist())

axd2c.plot(numpy.array([1, 2, 3]), numpy.array(_center) - 0.05, "+", label="center")
axd2c.plot(numpy.array([1, 2, 3]), numpy.array(_edges) + 0.05, "o", label="edges")
axd2c.legend()

# ------------------------ Continuous to discrete
center = []
edges = []
for i in numpy.linspace(-1.5, 1.5, 100):
    # Only the target space matters for the cast; init stays inside it to avoid warnings
    x = discrete_array_element(init=1, low=1, high=3)
    y = array_element(init=i, low=-1.5, high=1.5)

    ret_stateElem = y.cast(x, mode="center")
    center.append(ret_stateElem.tolist())

    ret_stateElem = y.cast(x, mode="edges")
    edges.append(ret_stateElem.tolist())

axc2d.plot(
    numpy.linspace(-1.5, 1.5, 100), numpy.array(center) - 0.05, "+", label="center"
)
axc2d.plot(
    numpy.linspace(-1.5, 1.5, 100), numpy.array(edges) + 0.05, "o", label="edges"
)
axc2d.legend()

# ------------------------- Continuous to continuous
output = []
for i in numpy.linspace(-1, 1, 100):
    x = array_element(
        init=numpy.full((2, 2), i),
        low=numpy.full((2, 2), -1),
        high=numpy.full((2, 2), 1),
        dtype=numpy.float32,
    )
    y = array_element(
        low=numpy.full((2, 2), 0), high=numpy.full((2, 2), 4), dtype=numpy.float32
    )
    output.append(x.cast(y)[0, 0].tolist())

axc2c.plot(numpy.linspace(-1, 1, 100), numpy.array(output), "-")

# -------------------------- Discrete to discrete
output = []
for i in [1, 2, 3, 4]:
    x = discrete_array_element(init=i, low=1, high=4)
    y = discrete_array_element(init=11, low=11, high=14)
    output.append(x.cast(y).tolist())

axd2d.plot([1, 2, 3, 4], output, "+")

# ----------------------- Show viz.
plt.tight_layout()
# plt.show()
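As a rough model of what the four panels compare, the arithmetic below gives one plausible reading of the casts. It is an assumption, not CoopIHC's code: continuous-to-continuous and discrete-to-discrete casts are treated as affine rescalings of one interval onto another, “edges” maps the first and last discrete values onto the interval bounds, and “center” splits the interval into equal bins and uses the bin centers. All function names are hypothetical.

```python
def rescale(x, src_low, src_high, dst_low, dst_high):
    """Affine map of [src_low, src_high] onto [dst_low, dst_high]."""
    return dst_low + (x - src_low) * (dst_high - dst_low) / (src_high - src_low)


def discrete_to_continuous(n, n_low, n_high, low, high, mode="center"):
    """Hypothetical reading of the two discrete-to-continuous cast modes."""
    count = n_high - n_low + 1
    if mode == "edges":
        # First and last discrete values land exactly on the interval bounds.
        return rescale(n, n_low, n_high, low, high)
    if mode == "center":
        # Split [low, high] into `count` equal bins; n lands on its bin center.
        width = (high - low) / count
        return low + (n - n_low + 0.5) * width
    raise ValueError(f"unknown mode {mode!r}")


def continuous_to_discrete(x, low, high, n_low, n_high, mode="center"):
    """Hypothetical inverse: map x in [low, high] back to one of n_low..n_high."""
    count = n_high - n_low + 1
    if mode == "edges":
        # Nearest discrete value after rescaling (round() rounds halves to even).
        return n_low + round((x - low) * (n_high - n_low) / (high - low))
    if mode == "center":
        # Index of the bin containing x, clamped so that x == high is in the last bin.
        idx = min(int((x - low) / (high - low) * count), count - 1)
        return n_low + idx
    raise ValueError(f"unknown mode {mode!r}")


for n in [1, 2, 3]:
    print(n, discrete_to_continuous(n, 1, 3, -1.5, 1.5, "center"),
          discrete_to_continuous(n, 1, 3, -1.5, 1.5, "edges"))
```

Under these assumptions, the continuous-to-continuous panel would show a straight line from 0 to 4, and the discrete-to-discrete panel a constant offset of 10.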