Source code for coopihc.interactiontask.PipeTaskWrapper

from abc import ABC, abstractmethod
from coopihc.interactiontask.InteractionTask import InteractionTask


class PipeTaskWrapper(InteractionTask, ABC):
    """PipeTaskWrapper

    A wrapper for tasks so that messages are passed through a pipe.
    Subclass this task to use tasks defined externally (e.g. tasks that pass
    messages via websockets to a server, which forwards the messages to a
    task via a pipe).

    The wrapper shares its attribute dictionary with the wrapped task.
    Attribute reads that normal lookup cannot resolve are forwarded to the
    task. Attribute writes (other than ``task`` and ``__dict__``) are
    forwarded to the task as well.

    Messages used by this class:

    - sent over the pipe:

      - ``{"type": "init", "parameters": ...}`` on construction,
      - ``{"type": "user_action" | "assistant_action", "value": ...}`` after
        an action,
      - ``{"type": "reset", "reset_dic": ...}`` on reset;

    - received from the pipe:

      - state messages ``{"type": "task_state" | "user_state", ...}``
        (during the init handshake and after a reset),
      - after an action, ``{"state": <state message>, "reward": ...,
        "is_done": ...}``.

    :param task: task to wrap
    :type task: :py:class:`InteractionTask<coopihc.interactiontask.InteractionTask.InteractionTask>`
    :param pipe: one end of a pipe providing ``send``, ``recv`` and
        ``poll``, e.g. a connection returned by ``multiprocessing.Pipe()``
    :type pipe: multiprocessing.connection.Connection (or compatible)
    """

    def __init__(self, task, pipe):
        # Share the wrapped task's attribute dictionary: the wrapper and the
        # task read and write the same attributes. InteractionTask.__init__
        # is not called here; the attributes come from the existing task.
        self.__dict__ = task.__dict__
        self.task = task
        # Routed through __setattr__ to the task, i.e. into the shared dict.
        self.pipe = pipe
        self.pipe.send({"type": "init", "parameters": self.parameters})
        # Consume state messages until a task_state message arrives.
        # NOTE: this assumes the final message sent by the client is a
        # task_state message; other orderings are not supported yet.
        while True:
            received_state = self._pipe_receive()
            # Read the type first: update_state deletes the "type" key.
            is_final = received_state["type"] == "task_state"
            self.update_state(received_state)
            if is_final:
                break

    def __getattr__(self, attr):
        # Only invoked when normal lookup fails. ``__dict__`` is a data
        # descriptor on the type, so reading it here cannot recurse.
        try:
            task = self.__dict__["task"]
        except KeyError:
            # No wrapped task (e.g. an instance created via __new__ during
            # copy/unpickling): honour the attribute protocol so that
            # hasattr()/getattr(obj, name, default) behave correctly,
            # instead of silently returning None.
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            ) from None
        return getattr(task, attr)

    def __setattr__(self, name, value):
        # The shared dict and the task reference live on the wrapper itself.
        if name in ("__dict__", "task"):
            super().__setattr__(name, value)
            return
        try:
            task = self.__dict__["task"]
        except KeyError:
            # No wrapped task yet: store the value normally rather than
            # silently dropping the assignment.
            super().__setattr__(name, value)
        else:
            setattr(task, name, value)

    def update_state(self, state):
        """update_state

        Remove the ``'type'`` entry from the state dictionary (in place) and
        dispatch the remainder to :py:meth:`update_task_state` or
        :py:meth:`update_user_state`. Messages of any other type are ignored
        and left unmodified.

        :param state: state received via pipe
        :type state: dictionary
        """
        state_type = state["type"]
        if state_type == "task_state":
            del state["type"]
            self.update_task_state(state)
        elif state_type == "user_state":
            del state["type"]
            self.update_user_state(state)

    @abstractmethod
    def update_task_state(self, state):
        """update_task_state

        Redefine this. Example `here <https://jgori-ouistiti.github.io/CoopIHC-zoo/_modules/coopihczoo/pointing/envs.html#DiscretePointingTaskPipeWrapper>`_

        :param state: state received via pipe, with its ``'type'`` entry
            already removed
        :type state: dictionary
        """
        pass

    @abstractmethod
    def update_user_state(self, state):
        """update_user_state

        See update_task_state

        :param state: state received via pipe, with its ``'type'`` entry
            already removed
        :type state: dictionary
        """
        pass

    def _pipe_receive(self):
        """Wait for a message on the pipe (``poll(None)``) and return it."""
        self.pipe.poll(None)
        return self.pipe.recv()

    def _pipe_exchange_action(self, action_type):
        """Send the bundle's current action over the pipe and apply the reply.

        :param action_type: ``"user_action"`` or ``"assistant_action"``; used
            both as the message type and as the ``game_state`` key
        :type action_type: str
        :return: (task state, task reward, is_done flag, {})
        :rtype: tuple
        """
        action_msg = {
            "type": action_type,
            "value": self.bundle.game_state[action_type]["action"].serialize(),
        }
        self.pipe.send(action_msg)
        received_dic = self._pipe_receive()
        self.update_state(received_dic["state"])
        return self.state, received_dic["reward"], received_dic["is_done"], {}

    def on_user_action(self, *args, **kwargs):
        """on_user_action

        1. Transform user action into dictionary with appropriate interface
        2. Send message over pipe
        3. Wait for pipe message
        4. Update state and return

        :return: (task state, task reward, is_done flag, {})
        :rtype: tuple(:py:class:`State<coopihc.base.State.State>`, float, boolean, dictionary)
        """
        super().on_user_action(*args, **kwargs)
        return self._pipe_exchange_action("user_action")

    def on_assistant_action(self, *args, **kwargs):
        """on_assistant_action

        Same as on_user_action, for the assistant's action.

        :return: (task state, task reward, is_done flag, {})
        :rtype: tuple(:py:class:`State<coopihc.base.State.State>`, float, boolean, dictionary)
        """
        super().on_assistant_action(*args, **kwargs)
        return self._pipe_exchange_action("assistant_action")

    def reset(self, dic=None):
        """reset

        1. Send reset dic over pipe
        2. Wait for pipe message
        3. Update state and return

        .. note ::

            verify the dic=None signature

        :param dic: reset dic, defaults to None
        :type dic: dictionary, optional
        :return: Task state
        :rtype: :py:class:`State<coopihc.base.State.State>`
        """
        super().reset(dic=dic)
        self.pipe.send({"type": "reset", "reset_dic": dic})
        self.update_state(self._pipe_receive())
        # NOTE(review): presumably resets the rest of the bundle without
        # resetting this task again -- confirm against Bundle.reset.
        self.bundle.reset(task=False)
        return self.state