Source code for coopihc.bundle.wrappers.Train

from coopihc.base.StateElement import StateElement
from coopihc.base.Space import Numeric, CatSet

import numpy
import gym
from collections import OrderedDict
from abc import ABC, abstractmethod


class TrainGym2SB3ActionWrapper(gym.ActionWrapper):
    """Expose the wrapped environment's action space as a single flat space.

    Wrapper that flattens all spaces to boxes, using one-hot encoding for
    discrete spaces. While this wrapper will likely work for all cases, it may
    sometimes be more effective to code your own action wrapper to avoid
    one-hot encoding.

    :param env: environment whose action space should be flattened
    :type env: gym.Env
    """

    def __init__(self, env):
        super().__init__(env)
        # The wrapper advertises the flattened version of the inner space.
        self.action_space = gym.spaces.utils.flatten_space(self.env.action_space)

    def action(self, action):
        """Convert a flat action back to the structure of the wrapped env.

        :param action: action expressed in the flattened action space
        :return: the same action, unflattened against ``self.env.action_space``
        """
        return gym.spaces.utils.unflatten(self.env.action_space, action)

    def reverse_action(self, action):
        """Convert a structured action of the wrapped env to its flat form.

        :param action: action expressed in ``self.env.action_space``
        :return: the flattened action
        """
        # ``flatten`` needs the space that describes ``action`` as its first
        # argument; calling it with the action alone raises a TypeError.
        return gym.spaces.utils.flatten(self.env.action_space, action)
class TrainGym(gym.Env):
    """Generic wrapper to make bundles compatible with gym.Env

    This is a wrapper to make a Bundle compatible with gym.Env. Read more on
    the Train class.

    :param bundle: bundle to convert to a gym.Env
    :type bundle: `Bundle <coopihc.bundle.Bundle.Bundle>`
    :param train_user: whether to train the user, defaults to False
    :type train_user: bool, optional
    :param train_assistant: whether to train the assistant, defaults to False
    :type train_assistant: bool, optional
    :param observation_dict: stored on the instance as ``observation_dict``
        but not otherwise read by this class; use ``filter_observation`` to
        actually filter observations. Defaults to None.
    :type observation_dict: collections.OrderedDict, optional
    :param reset_dic: During training, the bundle will be repeatedly reset.
        Pass the reset_dic here if needed (see Bundle reset mechanism),
        defaults to None, which is treated as an empty dict
    :type reset_dic: dict, optional
    :param reset_turn: During training, the bundle will be repeatedly reset.
        Pass the reset_turn here (see Bundle reset_turn mechanism), defaults
        to None, which selects 1 if only the user is trained, else 3
    :type reset_turn: int, optional
    :param filter_observation: to filter out observations, you can apply a
        dictionary, defaults to None (no filtering). It is passed as
        ``filterdict`` to the observation's ``filter`` method, e.g.:

        .. code-block:: python

            filterdict = OrderedDict(
                {
                    "user_state": OrderedDict({"goal": 0}),
                    "task_state": OrderedDict({"x": 0}),
                }
            )

        You can always filter out observations later using an
        ObservationWrapper. Difference in performance between the two
        approaches is unknown.
    :type filter_observation: collections.OrderedDict, optional
    """

    def __init__(
        self,
        bundle,
        *args,
        train_user=False,
        train_assistant=False,
        observation_dict=None,
        reset_dic=None,
        reset_turn=None,
        filter_observation=None,
        **kwargs,
    ):
        self.train_user = train_user
        self.train_assistant = train_assistant
        self.bundle = bundle
        self.observation_dict = observation_dict
        # A fresh dict per instance: a ``{}`` default in the signature would be
        # one object shared (and possibly mutated) across all instances.
        self.reset_dic = {} if reset_dic is None else reset_dic
        self.filter_observation = filter_observation

        if reset_turn is None:
            # Turn 1 when only the user is trained. Turn 3 when the assistant
            # is trained (it takes precedence over the user) and, as the
            # documented fallback, when no agent is trained -- previously
            # ``reset_turn`` was left unset in that case and the constructor
            # failed with an AttributeError.
            if self.train_user and not train_assistant:
                self.reset_turn = 1
            else:
                self.reset_turn = 3
        else:
            self.reset_turn = reset_turn

        self._convertor = GymConvertor(filter_observation=filter_observation)

        # The asymmetry of these two should be resolved. Currently, some
        # fiddling is needed due to Issue # 58
        # https://github.com/jgori-ouistiti/CoopIHC/issues/58 . It is expected
        # that when issue 58 is resolved, this code can be cleaned up.
        self.action_space = self.get_action_space()
        self.observation_space = self.get_observation_space()

    # Below: Using ordereddict here is forced due to Open AI gym's behavior:
    # when initializing the Dict space, it tries to order the dict by keys,
    # which may change the order of the dict entries. This is actually useless
    # since Python 3.7 because dicts are ordered by default.

    def get_action_space(self):
        """get_action_space

        Create a gym.spaces.Dict out of the action states of the Bundle.

        :return: Dict space with one entry per action of each trained agent
        :rtype: gym.spaces.Dict
        """
        # ----- Init Action space -------
        action_dict = OrderedDict({})
        if self.train_user:
            self._add_agent_actions(action_dict, "user")
        if self.train_assistant:
            self._add_agent_actions(action_dict, "assistant")
        return gym.spaces.Dict(action_dict)

    def _add_agent_actions(self, action_dict, agent):
        """Add the converted action space(s) of ``agent`` to ``action_dict``.

        Iterable actions yield keys ``"<agent>_action_<i>"``; an action that
        cannot be iterated yields the single key ``"<agent>_action"``.
        """
        agent_action = getattr(self.bundle, agent).action
        try:
            for i, _action in enumerate(agent_action):
                action_dict[f"{agent}_action_{i}"] = self.convert_space(_action)
        except TypeError:
            # Catch single actions
            action_dict[f"{agent}_action"] = self.convert_space(agent_action)

    def get_observation_space(self):
        """get_observation_space

        Same as get_action_space for observations. The bundle is first reset
        to ``reset_turn`` and the observation of the trained agent is then
        converted.

        :raises NotImplementedError: if both user and assistant are trained
        :return: observation space of the trained agent, or None when no agent
            is trained
        """
        self.bundle.reset(go_to=self.reset_turn)
        # ------- Init Observation space
        if self.train_user and self.train_assistant:
            raise NotImplementedError(
                "Currently this wrapper can not deal with simultaneous training of users and assistants."
            )
        if self.train_user:
            return self.get_agent_observation_space("user")
        if self.train_assistant:
            return self.get_agent_observation_space("assistant")

    def get_agent_observation_space(self, agent):
        """Build the gym Dict observation space of one agent.

        :param agent: attribute name of the agent on the bundle
            ("user" or "assistant")
        :type agent: str
        :return: flat Dict space keyed by substate names
        :rtype: gym.spaces.Dict
        """
        observation_dict = OrderedDict({})
        observation = getattr(self.bundle, agent).observation
        if self.filter_observation is not None:
            observation = observation.filter(
                mode="stateelement", filterdict=self.filter_observation
            )

        for key, value in observation.items():
            if key == "user_action" or key == "assistant_action":
                # Action states are exposed under the state's own name
                # rather than under their inner "action" key.
                observation_dict.update({key: self.convert_space(value["action"])})
            else:
                # Other states are flattened: each substate becomes a
                # top-level entry of the observation space.
                observation_dict.update(
                    {
                        sub_key: self.convert_space(sub_value)
                        for sub_key, sub_value in value.items()
                    }
                )
        return gym.spaces.Dict(observation_dict)

    def reset(self):
        """Reset the bundle and return the trained agent's observation.

        :raises NotImplementedError: if both user and assistant are trained
        :return: observation converted by the GymConvertor, or None when no
            agent is trained
        """
        self.bundle.reset(dic=self.reset_dic, go_to=self.reset_turn)
        if self.train_user and self.train_assistant:
            raise NotImplementedError
        if self.train_user:
            return self._convertor.filter_gamestate(self.bundle.user.observation)
        if self.train_assistant:
            return self._convertor.filter_gamestate(self.bundle.assistant.observation)

    def step(self, action):
        """Play one step of the bundle with the given action.

        :param action: mapping that may hold "user_action" and/or
            "assistant_action"; a missing key is passed to the bundle as None
        :raises NotImplementedError: if both user and assistant are trained
        :return: (observation, summed reward as float, flag returned by the
            bundle, info dict)
        :rtype: tuple
        """
        # NOTE(review): get_action_space names iterable actions
        # "<agent>_action_<i>", but only "<agent>_action" is read here, so
        # indexed actions are forwarded to the bundle as None -- confirm the
        # intended handling of multi-part actions.
        user_action = action.get("user_action", None)
        assistant_action = action.get("assistant_action", None)

        obs, rewards, flag = self.bundle.step(
            user_action=user_action, assistant_action=assistant_action
        )

        if self.train_user and self.train_assistant:
            raise NotImplementedError
        if self.train_user:
            obs = self._convertor.filter_gamestate(self.bundle.user.observation)
        if self.train_assistant:
            obs = self._convertor.filter_gamestate(self.bundle.assistant.observation)

        return (
            obs,
            float(sum(rewards.values())),
            flag,
            {"name": "CoopIHC Bundle {}".format(str(self.bundle))},
        )

    def convert_space(self, object):
        """Convert a space, or the space of a StateElement, to a gym space."""
        if isinstance(object, StateElement):
            object = object.space
        return self._convertor.convert_space(object)

    def render(self, mode):
        """See Bundle and gym API

        :meta public:
        """
        self.bundle.render(mode)

    def close(self):
        """See Bundle and gym API

        :meta public:
        """
        self.bundle.close()
class RLConvertor(ABC):
    """RLConvertor

    Abstract helper, meant to be subclassed, that converts spaces from
    Bundles to another library.

    :param interface: API target for conversion, defaults to "gym"
    :type interface: str, optional
    """

    def __init__(self, interface="gym", **kwargs):
        self.interface = interface
        # "gym" is the only conversion target available so far.
        targets_gym = self.interface == "gym"
        if not targets_gym:
            raise NotImplementedError

    @abstractmethod
    def convert_space(self, space):
        """Convert a Bundle space to the target library; set by subclasses."""

    @abstractmethod
    def filter_gamestate(self, gamestate, observation_mapping):
        """Convert a game state to a target-library observation; set by subclasses."""
class GymConvertor(RLConvertor):
    """GymConvertor

    Convertor to convert spaces from Bundle to Gym.

    .. note::

        Code is a little messy. Refactoring together with Train and TrainGym
        would be beneficial.

    :param filter_observation: filter passed as ``filterdict`` when
        observations are filtered in ``filter_gamestate``, defaults to None
    :type filter_observation: collections.OrderedDict, optional
    """

    def __init__(self, filter_observation=None, **kwargs):
        super().__init__(interface="gym", **kwargs)
        self._filter_observation = filter_observation

    def convert_space(self, space):
        """Convert a CoopIHC space to the matching gym space.

        :param space: ``Numeric`` or ``CatSet`` space
        :raises NotImplementedError: if ``space`` is of any other type
        :return: ``gym.spaces.Box`` for Numeric, ``gym.spaces.Discrete`` for
            CatSet
        """
        if isinstance(space, Numeric):
            # Box bounds are given as arrays of at least one dimension.
            return gym.spaces.Box(
                low=numpy.atleast_1d(space.low),
                high=numpy.atleast_1d(space.high),
                dtype=space.dtype,
            )
        if isinstance(space, CatSet):
            return gym.spaces.Discrete(space.N)
        # Fail here with a clear message instead of returning None, which
        # would only surface later as an obscure error when building the
        # enclosing gym space.
        raise NotImplementedError(
            "Cannot convert a space of type {} to a gym space; only Numeric "
            "and CatSet are supported.".format(type(space).__name__)
        )

    def filter_gamestate(self, gamestate):
        """filter_gamestate

        Converts a CoopIHC observation to a valid Gym observation.

        :param gamestate: observation exposing ``filter(mode=..., filterdict=...)``
        :return: flat mapping of substate names to values
        :rtype: collections.OrderedDict
        """
        dic = OrderedDict({})
        filtered = gamestate.filter(
            mode="array-Gym", filterdict=self._filter_observation
        )
        for k, v in filtered.items():
            # Hack, see Issue # 58
            # https://github.com/jgori-ouistiti/CoopIHC/issues/58
            # States left empty by the filter contribute nothing.
            if not v:
                continue
            if k == "user_action" or k == "assistant_action":
                # Expose the action under the state's name ("user_action" /
                # "assistant_action") rather than the generic "action" key.
                v[k] = v.pop("action")
            dic.update(v)
        return dic