from random import random
from coopihc.base.State import State
from coopihc.base.elements import discrete_array_element, array_element, cat_element
import numpy
import yaml
import matplotlib.pyplot as plt
import copy
class BaseBundle:
"""Main class for bundles.
This class is subclassed by Bundle, which defines the interface with which to interact.
A bundle combines a task with a user and an assistant. The bundle creates the ``game_state`` by combining the task, user and assistant states with the turn index and both agents' actions.
The bundle takes care of all the messaging between classes, making sure the game state and all individual states are synchronized at all times.
The bundle implements a forced reset mechanism, where each state of the bundle can be forced to a particular value via a dictionary mechanism (see :py:func:`reset`).
The bundle also takes care of rendering each of the three components in a single place.
:param task: (:py:class:`coopihc.interactiontask.InteractionTask.InteractionTask`) A task that inherits from ``InteractionTask``
:param user: (:py:class:`coopihc.agents.BaseAgent.BaseAgent`) a user which inherits from ``BaseAgent``
:param assistant: (:py:class:`coopihc.agents.BaseAgent.BaseAgent`) an assistant which inherits from ``BaseAgent``
:meta public:
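A minimal usage sketch (``MyTask``, ``MyUser`` and ``MyAssistant`` are placeholder subclasses of ``InteractionTask`` and ``BaseAgent``, not part of this module):

.. code-block:: python

    bundle = Bundle(task=MyTask(), user=MyUser(), assistant=MyAssistant())
    bundle.reset()
    while True:
        state, rewards, is_done = bundle.step()
        if is_done:
            break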
"""
turn_dict = {
"after_assistant_action": 0,
"before_user_action": 1,
"after_user_action": 2,
"before_assistant_action": 3,
}
def __init__(
self,
task,
user,
assistant,
*args,
reset_random=False,
reset_start_after=-1,
reset_go_to=0,
**kwargs,
):
self._reset_random = reset_random
self._reset_start_after = reset_start_after
self._reset_go_to = reset_go_to
self.kwargs = kwargs
self.task = task
self.task.bundle = self
self.user = user
self.user.bundle = self
self.assistant = assistant
self.assistant.bundle = self
# Form complete game state
self.game_state = State()
turn_index = cat_element(
N=4, init=0, out_of_bounds_mode="raw", dtype=numpy.int8
)
round_index = discrete_array_element(
init=0, low=0, high=numpy.iinfo(numpy.int64).max, out_of_bounds_mode="raw"
)
self.game_state["game_info"] = State()
self.game_state["game_info"]["turn_index"] = turn_index
self.game_state["game_info"]["round_index"] = round_index
self.game_state["task_state"] = task.state
self.game_state["user_state"] = user.state
self.game_state["assistant_state"] = assistant.state
# Caveat: action states cannot be accessed in the game_state during finit();
# they have to be accessed through the agent instead. This is due to the
# current way the game_state is assembled.
self.task.finit()
self.user.finit()
self.assistant.finit()
if user.policy is not None:
self.game_state["user_action"] = user.policy.action_state
else:
self.game_state["user_action"] = State()
self.game_state["user_action"]["action"] = array_element()
if assistant.policy is not None:
self.game_state["assistant_action"] = assistant.policy.action_state
else:
self.game_state["assistant_action"] = State()
self.game_state["assistant_action"]["action"] = array_element()
# Note: calling finit() here (after the action states are added) does not always
# work, which is why it is called before the action states above.
# self.task.finit()
# self.user.finit()
# self.assistant.finit()
# Needed for render
self.active_render_figure = None
self.figure_layout = [211, 223, 224]
self.rendered_mode = None
self.render_perm = False
self.playspeed = 0.1
def __repr__(self):
"""__repr__
Pretty representation for Bundles.
:return: pretty bundle print
:rtype: string
"""
return "{}\n".format(self.__class__.__name__) + yaml.safe_dump(
self.__content__()
)
def __content__(self):
"""__content__
Custom class representation
:return: class repr
:rtype: dictionary
"""
return {
"Task": self.task.__content__(),
"User": self.user.__content__(),
"Assistant": self.assistant.__content__(),
}
@property
def parameters(self):
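"""parameters
Joint dictionary of the task, user and assistant parameters.
:return: bundle parameters
:rtype: dict
"""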
return {
**self.task._parameters,
**self.user._parameters,
**self.assistant._parameters,
}
@property
def turn_number(self):
"""turn_number
The turn number in the game (0 to 3)
:return: turn number
:rtype: numpy.ndarray
"""
return self.game_state["game_info"]["turn_index"]
@turn_number.setter
def turn_number(self, value):
self._turn_number = value
self.game_state["game_info"]["turn_index"] = value
@property
def round_number(self):
"""round_number
The round number in the game (0 to N)
:return: round number
:rtype: numpy.ndarray
"""
return self.game_state["game_info"]["round_index"]
@round_number.setter
def round_number(self, value):
self._round_number = value
self.game_state["game_info"]["round_index"] = value
@property
def state(self):
return self.game_state
def reset(
self,
go_to=None,
start_after=None,
task=True,
user=True,
assistant=True,
dic={},
random_reset=False,
):
"""Reset bundle.
1. Reset the game and start at a specific turn number.
2. Select which components to reset.
3. Force components into a specific state using the dictionary mechanism.
Example:
.. code-block:: python

    new_target_value = self.game_state["task_state"]["targets"]
    new_fixation_value = self.game_state["task_state"]["fixation"]
    reset_dic = {
        "task_state": {
            "targets": new_target_value,
            "fixation": new_fixation_value,
        }
    }
    self.reset(dic=reset_dic, go_to=1)

This will set the substates "targets" and "fixation" of the "task_state" to the given values.
.. note ::
If subclassing BaseBundle, make sure to call super().reset() in the new reset method.
:param go_to: turn number (or turn name) at which the reset ends. Can also be set globally at the bundle level via the ``reset_go_to`` keyword argument, defaults to 0
:type go_to: int or str, optional
:param start_after: turn after which the reset starts playing half-steps (allows skipping some turns during reset). Can also be set globally via the ``reset_start_after`` keyword argument, defaults to -1
:type start_after: int or str, optional
:param task: reset task?, defaults to True
:type task: bool, optional
:param user: reset user?, defaults to True
:type user: bool, optional
:param assistant: reset assistant?, defaults to True
:type assistant: bool, optional
:param dic: reset_dic, defaults to {}
:type dic: dict, optional
:param random_reset: whether values not set by the reset dic should be randomized during reset, defaults to False
:type random_reset: bool, optional
:return: new game state
:rtype: :py:class:`State<coopihc.base.State.State>`
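Resetting to a specific point in the round is done with ``go_to`` and ``start_after`` (a sketch; ``bundle`` is an already constructed ``Bundle``):

.. code-block:: python

    # reset, then play the user observation/inference half-step
    bundle.reset(go_to="before_user_action")
    # reset, skip the user observation/inference, but still play the user action
    bundle.reset(start_after="before_user_action", go_to="after_user_action")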
"""
if go_to is None:
go_to = self._reset_go_to
if start_after is None:
start_after = self._reset_start_after
random_reset = self._reset_random or random_reset
if task:
task_dic = dic.get("task_state")
self.task._base_reset(
dic=task_dic,
random=random_reset,
)
if user:
user_dic = dic.get("user_state")
self.user._base_reset(
dic=user_dic,
random=random_reset,
)
if assistant:
assistant_dic = dic.get("assistant_state")
self.assistant._base_reset(
dic=assistant_dic,
random=random_reset,
)
self.round_number = 0
if not isinstance(go_to, (numpy.integer, int)):
go_to = self.turn_dict[go_to]
if not isinstance(start_after, (numpy.integer, int)):
start_after = self.turn_dict[start_after]
self.turn_number = go_to
if go_to == 0 and start_after + 1 == 0:
return self.game_state
if start_after <= go_to:
if go_to >= 1 and start_after + 1 <= 1:
self._user_first_half_step()
if go_to >= 2 and start_after + 1 <= 2:
user_action, _ = self.user.take_action(increment_turn=False)
self.user.action = user_action
self._user_second_half_step(user_action)
if go_to >= 3 and start_after + 1 <= 3:
self._assistant_first_half_step()
else:
raise ValueError(
f"start_after ({start_after}) can not be after go_to ({go_to}). You can likely use a combination of reset and step to achieve what you are looking for"
)
return self.game_state
def quarter_step(self, user_action=None, assistant_action=None, **kwargs):
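"""quarter_step
Play a single turn: advance the bundle by one turn from the current turn index, by calling ``step`` with ``go_to`` set to the next turn.
:return: game state, rewards, game finished flag
:rtype: tuple(:py:class:`State<coopihc.base.State.State>`, dict, boolean)
"""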
return self.step(
user_action=user_action,
assistant_action=assistant_action,
go_to=(int(self.turn_number) + 1) % 4,
)
def step(self, user_action=None, assistant_action=None, go_to=None, **kwargs):
"""Play a round
Play a round of the game. A round consists of 4 turns. If go_to is not None, the round is only played up to that turn.
If a user action and/or assistant action is passed as an argument, it is used as-is; otherwise the action is sampled from the corresponding agent's policy.
:param user_action: user action, defaults to None
:type user_action: any, optional
:param assistant_action: assistant action, defaults to None
:type assistant_action: any, optional
:param go_to: turn (index or name) at which the round stops, defaults to None
:type go_to: int or str, optional
:return: game state, rewards, game finished flag
:rtype: tuple(:py:class:`State<coopihc.base.State.State>`, dict, boolean)
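A minimal usage sketch (``bundle`` is an already constructed and reset ``Bundle``; the forced user action value is arbitrary):

.. code-block:: python

    # play a full round, sampling both actions from the agents' policies
    state, rewards, is_done = bundle.step()
    # play a round with a forced user action, stopping before the assistant acts
    state, rewards, is_done = bundle.step(user_action=1, go_to="before_assistant_action")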
"""
if go_to is None:
go_to = int(self.turn_number)
if not isinstance(go_to, (numpy.integer, int)):
go_to = self.turn_dict[go_to]
_started = False
rewards = {}
rewards["user_observation_reward"] = 0
rewards["user_inference_reward"] = 0
rewards["user_policy_reward"] = 0
rewards["first_task_reward"] = 0
rewards["assistant_observation_reward"] = 0
rewards["assistant_inference_reward"] = 0
rewards["assistant_policy_reward"] = 0
rewards["second_task_reward"] = 0
while self.turn_number != go_to or (not _started):
_started = True
# User observes and infers
if self.turn_number == 0 and "no-user" != self.kwargs.get("name"):
(
user_obs_reward,
user_infer_reward,
) = self._user_first_half_step()
(
rewards["user_observation_reward"],
rewards["user_inference_reward"],
) = (user_obs_reward, user_infer_reward)
# User takes action and receives reward from task
elif self.turn_number == 1 and "no-user" != self.kwargs.get("name"):
if user_action is None:
user_action, user_policy_reward = self.user.take_action(
increment_turn=False
)
else:
self.user.action = user_action
user_policy_reward = 0
task_reward, is_done = self._user_second_half_step(user_action)
rewards["user_policy_reward"] = user_policy_reward
rewards["first_task_reward"] = task_reward
if is_done:
return self.game_state, rewards, is_done
elif self.turn_number == 2 and "no-assistant" == self.kwargs.get("name"):
self.round_number = self.round_number + 1
# Assistant observes and infers
elif self.turn_number == 2 and "no-assistant" != self.kwargs.get("name"):
(
assistant_obs_reward,
assistant_infer_reward,
) = self._assistant_first_half_step()
(
rewards["assistant_observation_reward"],
rewards["assistant_inference_reward"],
) = (assistant_obs_reward, assistant_infer_reward)
# Assistant takes action and receives reward from task
elif self.turn_number == 3 and "no-assistant" != self.kwargs.get("name"):
if assistant_action is None:
(
assistant_action,
assistant_policy_reward,
) = self.assistant.take_action(increment_turn=False)
else:
self.assistant.action = assistant_action
assistant_policy_reward = 0
task_reward, is_done = self._assistant_second_half_step(
assistant_action
)
rewards["assistant_policy_reward"] = assistant_policy_reward
rewards["second_task_reward"] = task_reward
if is_done:
return self.game_state, rewards, is_done
self.round_number = self.round_number + 1
self.turn_number = (self.turn_number + 1) % 4
return self.game_state, rewards, False
def render(self, mode, *args, **kwargs):
"""render
Combines all render methods.
:param mode: "text" or "plot"
:param type: string
:meta public:
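A minimal sketch (``bundle`` is an already constructed ``Bundle``):

.. code-block:: python

    bundle.render("plot")  # draw task, user and assistant panels in a matplotlib figure
    bundle.render("text")  # print textual renderings to stdout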
"""
self.rendered_mode = mode
if "text" in mode:
print("\n")
print("Round number {}".format(self.round_number.tolist()))
print("Task Render")
self.task.render(mode="text", *args, **kwargs)
print("User Render")
self.user.render(mode="text", *args, **kwargs)
print("Assistant Render")
self.assistant.render(mode="text", *args, **kwargs)
if "log" in mode:
self.task.render(mode="log", *args, **kwargs)
self.user.render(mode="log", *args, **kwargs)
self.assistant.render(mode="log", *args, **kwargs)
if "plot" in mode:
if self.active_render_figure:
plt.pause(self.playspeed)
self.task.render(
ax_task=self.axtask,
ax_user=self.axuser,
ax_assistant=self.axassistant,
mode="plot",
**kwargs,
)
self.user.render(
ax_task=self.axtask,
ax_user=self.axuser,
ax_assistant=self.axassistant,
mode="plot",
**kwargs,
)
self.assistant.render(
ax_task=self.axtask,
ax_user=self.axuser,
ax_assistant=self.axassistant,
mode="plot",
**kwargs,
)
self.fig.canvas.draw()
else:
self.active_render_figure = True
self.fig = plt.figure()
self.axtask = self.fig.add_subplot(self.figure_layout[0])
self.axtask.set_title("Task State")
self.axuser = self.fig.add_subplot(self.figure_layout[1])
self.axuser.set_title("User State")
self.axassistant = self.fig.add_subplot(self.figure_layout[2])
self.axassistant.set_title("Assistant State")
self.task.render(
ax_task=self.axtask,
ax_user=self.axuser,
ax_assistant=self.axassistant,
mode="plot",
**kwargs,
)
self.user.render(
ax_task=self.axtask,
ax_user=self.axuser,
ax_assistant=self.axassistant,
mode="plot",
**kwargs,
)
self.assistant.render(
ax_task=self.axtask,
ax_user=self.axuser,
ax_assistant=self.axassistant,
mode="plot",
**kwargs,
)
self.fig.show()
plt.tight_layout()
if not ("plot" in mode or "text" in mode):
self.task.render(None, mode=mode, *args, **kwargs)
self.user.render(None, mode=mode, *args, **kwargs)
self.assistant.render(None, mode=mode, *args, **kwargs)
def close(self):
"""close
Close the bundle once the game is finished.
"""
if self.active_render_figure:
plt.close(self.fig)
# self.active_render_figure = None
def _user_first_half_step(self):
"""_user_first_half_step
Turn 1, where the user observes the game state and updates its state via inference.
:return: user observation and inference reward
:rtype: tuple(float, float)
"""
if not self.kwargs.get("onreset_deterministic_first_half_step"):
user_obs_reward, user_infer_reward = self.user._agent_step()
else:
# Store the probabilistic rules
store = self.user.observation_engine.extraprobabilisticrules
# Remove the probabilistic rules
self.user.observation_engine.extraprobabilisticrules = {}
# Generate an observation without generating an inference
user_obs_reward, user_infer_reward = self.user._agent_step(infer=False)
# Reposition the probabilistic rules, and reset mapping
self.user.observation_engine.extraprobabilisticrules = store
self.user.observation_engine.mapping = None
self.kwargs["onreset_deterministic_first_half_step"] = False
return user_obs_reward, user_infer_reward
def _user_second_half_step(self, user_action):
"""_user_second_half_step
Turn 2, where the user takes an action.
:param user_action: user action
:type user_action: Any
:return: task reward, task done?
:rtype: tuple(float, boolean)
"""
# Play user's turn in the task
task_state, task_reward, is_done = self.task.base_on_user_action(
user_action=user_action
)
return task_reward, is_done
def _assistant_first_half_step(self):
"""_assistant_first_half_step
Turn 3, where the assistant observes the game state and updates its state via inference.
:return: assistant observation and inference reward
:rtype: tuple(float, float)
"""
(
assistant_obs_reward,
assistant_infer_reward,
) = self.assistant._agent_step()
return assistant_obs_reward, assistant_infer_reward
def _assistant_second_half_step(self, assistant_action):
"""_assistant_second_half_step
Turn 4, where the assistant takes an action.
:param assistant_action: assistant action
:type assistant_action: Any
:return: task reward, task done?
:rtype: tuple(float, boolean)
"""
# Play assistant's turn in the task
task_state, task_reward, is_done = self.task.base_on_assistant_action(
assistant_action=assistant_action
)
return task_reward, is_done
def _on_user_action(self, *args):
"""Turns 1 and 2
:param \*args: either provide the user action or not. If no action is provided the action is determined by the agent's policy using sample()
:type \*args: tuple, optional
:return: user observation, inference, policy and task rewards, game is done flag
:rtype: tuple(float, float, float, float, bool)
"""
user_obs_reward, user_infer_reward = self._user_first_half_step()
try:
# If an action is provided directly, use it (no policy reward in that case)
user_action = args[0]
user_policy_reward = 0
except IndexError:
# else sample from policy
user_action, user_policy_reward = self.user.take_action(
increment_turn=False
)
self.user.action = user_action
task_reward, is_done = self._user_second_half_step(user_action)
return (
user_obs_reward,
user_infer_reward,
user_policy_reward,
task_reward,
is_done,
)
def _on_assistant_action(self, *args):
"""Turns 3 and 4
:param \*args: either provide the assistant action or not. If no action is provided the action is determined by the agent's policy using sample()
:type \*args: tuple, optional
:return: assistant observation, inference, policy and task rewards, game is done flag
:rtype: tuple(float, float, float, float, bool)
"""
(
assistant_obs_reward,
assistant_infer_reward,
) = self._assistant_first_half_step()
try:
# If an action is provided directly, use it (no policy reward in that case)
assistant_action = args[0]
assistant_policy_reward = 0
except IndexError:
# else sample from policy
(
assistant_action,
assistant_policy_reward,
) = self.assistant.take_action(increment_turn=False)
self.assistant.action = assistant_action
task_reward, is_done = self._assistant_second_half_step(assistant_action)
return (
assistant_obs_reward,
assistant_infer_reward,
assistant_policy_reward,
task_reward,
is_done,
)