from coopihc.base.State import State
from coopihc.observation.BaseObservationEngine import BaseObservationEngine
from coopihc.observation.utils import base_task_engine_specification
import copy
class RuleObservationEngine(BaseObservationEngine):
    """RuleObservationEngine

    An observation engine that is specified by rules regarding each particular substate, using a so called mapping. Example usage is given below:

    .. code-block:: python

        obs_eng = RuleObservationEngine(mapping=mapping)
        obs, reward = obs_eng.observe(game_state=example_game_state())

    A mapping is any iterable where an item is:

        (substate, subsubstate, _slice, _func, _args, _nfunc, _nargs)

    The elements in this mapping are applied to create a particular component of the observation space, as follows

    .. code-block:: python

        observation_component = _nfunc(_func(state[substate][subsubstate][_slice], _args), _nargs)

    For example, a valid mapping for the ``example_game_state`` which states that everything should be observed except the game information is as follows:

    .. code-block:: python

        from coopihc.base.utils import example_game_state
        print(example_game_state())

        # Define mapping
        mapping = [
            ("task_state", "position", slice(0, 1, 1), None, None, None, None),
            ("task_state", "targets", slice(0, 2, 1), None, None, None, None),
            ("user_state", "goal", slice(0, 1, 1), None, None, None, None),
            ("assistant_state", "beliefs", slice(0, 8, 1), None, None, None, None),
            ("user_action", "action", slice(0, 1, 1), None, None, None, None),
            ("assistant_action", "action", slice(0, 1, 1), None, None, None, None),
        ]

        # Apply mapping
        obseng = RuleObservationEngine(mapping=mapping)
        obseng.observe(example_game_state())

    As a more complex example, suppose we want to have an observation engine that behaves as above, but which doubles the observation on the ("user_state", "goal") StateElement. We also want to have a noisy observation of the ("task_state", "position") StateElement. We would need the following mapping:

    .. code-block:: python

        import random

        def f(observation, gamestate, *args):
            gain = args[0]
            return gain * observation

        def g(observation, gamestate, *args):
            return random.randint(0, 1) + observation

        mapping = [
            ("task_state", "position", slice(0, 1, 1), None, None, g, ()),
            ("task_state", "targets", slice(0, 2, 1), None, None, None, None),
            ("user_state", "goal", slice(0, 1, 1), f, (2,), None, None),
            ("user_action", "action", slice(0, 1, 1), None, None, None, None),
            ("assistant_action", "action", slice(0, 1, 1), None, None, None, None),
        ]

    .. note::

        It is important to respect the signature of the functions you pass in the mapping (viz. f and g's signatures): they are called as ``func(observation, game_state, *args)``.

    Typing out a mapping may be a bit laborious and hard to comprehend for collaborators; there are some shortcuts that make defining this engine easier.

    Example usage:

    .. code-block:: python

        obs_eng = RuleObservationEngine(
            deterministic_specification=engine_specification,
            extradeterministicrules=extradeterministicrules,
            extraprobabilisticrules=extraprobabilisticrules,
        )

    There are three types of rules:

    1. Deterministic rules, which specify at a high level which states are observable or not, e.g.

    .. code-block:: python

        engine_specification = [
            ("game_info", "all"),
            ("task_state", "targets", slice(0, 1, 1)),
            ("user_state", "all"),
            ("assistant_state", None),
            ("user_action", "all"),
            ("assistant_action", "all"),
        ]

    2. Extra deterministic rules, which add some specific rules to specific substates

    .. code-block:: python

        def f(observation, gamestate, *args):
            gain = args[0]
            return gain * observation

        f_rule = {("user_state", "goal"): (f, (2,))}
        extradeterministicrules = {}
        extradeterministicrules.update(f_rule)

    3. Extra probabilistic rules, which are used to e.g. add noise

    .. code-block:: python

        import random

        def g(observation, gamestate, *args):
            return random.random() + observation

        g_rule = {("task_state", "position"): (g, ())}
        extraprobabilisticrules = {}
        extraprobabilisticrules.update(g_rule)

    .. warning::

        Before a rule function is applied, the observed component is copied with ``copy.copy`` so that the function operates on a copy rather than on the object extracted from the game state. Components to which no function is applied are not copied by this engine. Copying might be slow.

    :param deterministic_specification: deterministic rules, defaults to base_task_engine_specification
    :type deterministic_specification: list(tuples), optional
    :param extradeterministicrules: extra deterministic rules, defaults to None (treated as an empty dict)
    :type extradeterministicrules: dict, optional
    :param extraprobabilisticrules: extra probabilistic rules, defaults to None (treated as an empty dict)
    :type extraprobabilisticrules: dict, optional
    :param mapping: mapping, defaults to None (in which case it is built from the rules on the first call to ``observe``)
    :type mapping: iterable, optional
    """

    def __init__(
        self,
        *args,
        deterministic_specification=base_task_engine_specification,
        extradeterministicrules=None,
        extraprobabilisticrules=None,
        mapping=None,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.deterministic_specification = deterministic_specification
        # A fresh dict per instance: a ``{}`` default in the signature would be
        # a single object shared by every engine created with the defaults.
        self.extradeterministicrules = (
            {} if extradeterministicrules is None else extradeterministicrules
        )
        self.extraprobabilisticrules = (
            {} if extraprobabilisticrules is None else extraprobabilisticrules
        )
        self.mapping = mapping

    # @BaseObservationEngine.get_params
    @BaseObservationEngine.default_value
    def observe(self, game_state=None):
        """observe

        Wrapper around apply_mapping for interfacing with bundle. If no mapping was supplied at construction, one is created from the rules using the first game state that is observed, and then reused for subsequent calls.

        :param game_state: game state
        :type game_state: :py:class:`State <coopihc.base.State.State>`
        :return: (observation, obs reward); the reward is always 0
        :rtype: tuple(:py:class:`State <coopihc.base.State.State>`, float)
        """
        if self.mapping is None:
            self.mapping = self.create_mapping(game_state)
        return self.apply_mapping(game_state), 0

    def apply_mapping(self, game_state):
        """apply_mapping

        Apply the rule mapping: for each item of the mapping, extract the targeted slice of the game state, pass it through the deterministic function and then the probabilistic function (when present), and store the result in the observation.

        :param game_state: game state
        :type game_state: :py:class:`State <coopihc.base.State.State>`
        :return: observation
        :rtype: :py:class:`State <coopihc.base.State.State>`
        """
        observation = State()
        for (
            substate,
            subsubstate,
            _slice,
            _func,
            _args,
            _nfunc,
            _nargs,
        ) in self.mapping:
            if observation.get(substate) is None:
                observation[substate] = State()

            try:
                _obs = game_state[substate][subsubstate][_slice, {"space": True}]
            except IndexError:  # 0-D arrays
                _obs = game_state[substate][subsubstate][..., {"space": True}]
            except KeyError:  # If incomplete state is passed
                continue

            # Deterministic rule first, then probabilistic rule. The component
            # is copied at most once, right before the first rule that applies.
            copied = False
            for rule, rule_args in ((_func, _args), (_nfunc, _nargs)):
                if not rule:
                    continue
                if not copied:
                    _obs = copy.copy(_obs)
                    copied = True
                # Falsy args (None, empty tuple) mean "no extra arguments".
                _obs = rule(_obs, game_state, *(rule_args or ()))

            observation[substate][subsubstate] = _obs

        return observation

    def create_mapping(self, game_state):
        """create_mapping

        Create mapping from the high level rules specified in the Rule Engine.

        Each entry of the deterministic specification is ``(substate, subsubstate)`` or ``(substate, subsubstate, _slice)``:

        * ``subsubstate == "all"``: every component of ``game_state[substate]`` is observed in full;
        * ``subsubstate is None``: the substate is not observed;
        * otherwise only the named component is observed, either in full (no slice given, or ``"all"``) or restricted to the given ``slice``. Any other value for the slice results in the entry being ignored.

        :param game_state: game state
        :type game_state: :py:class:`State <coopihc.base.State.State>`
        :return: Mapping
        :rtype: iterable
        """
        mapping = []
        for substate, *rest in self.deterministic_specification:
            subsubstate = rest[0]

            if subsubstate is None:
                # Substate explicitly marked as unobservable.
                continue

            if subsubstate == "all":
                for key, value in game_state[substate].items():
                    mapping.append(
                        (
                            substate,
                            key,
                            self._full_slice(value),
                            *self._rules_for(substate, key),
                        )
                    )
                continue

            _slice = rest[1] if len(rest) > 1 else "all"
            if _slice == "all":
                _slice = self._full_slice(game_state[substate][subsubstate])
            elif not isinstance(_slice, slice):
                # Unsupported slice specification: entry is ignored.
                continue
            mapping.append(
                (
                    substate,
                    subsubstate,
                    _slice,
                    *self._rules_for(substate, subsubstate),
                )
            )

        return mapping

    def _rules_for(self, substate, key):
        """Return ``(func, args, nfunc, nargs)`` registered for ``(substate, key)``, with ``None`` where no rule exists."""
        func, args = self.extradeterministicrules.get((substate, key)) or (None, None)
        nfunc, nargs = self.extraprobabilisticrules.get((substate, key)) or (
            None,
            None,
        )
        return func, args, nfunc, nargs

    @staticmethod
    def _full_slice(value):
        """Return a slice spanning all of ``value``; values without a length (e.g. ints) are given length 1."""
        try:
            length = len(value)
        except TypeError:
            length = 1
        return slice(0, length, 1)