# Source code for coopihc.policy.RLPolicy

from coopihc.base.State import State
from coopihc.policy.BasePolicy import BasePolicy
from coopihc.bundle.wrappers.Train import TrainGym

# ======================= RL Policy


class RLPolicy(BasePolicy):
    """Wrap a trained model as a CoopIHC policy.

    A policy object compatible with CoopIHC that wraps a policy trained via
    deep reinforcement learning.

    Example code:

    .. code-block:: python

        # action_state
        action_state = State()
        action_state["action"] = StateElement(
            0, autospace([-5 + i for i in range(11)])
        )

        # env
        env = TrainGym(
            bundle,
            train_user=True,
            train_assistant=False,
        )

        # Using PPO from stable_baselines3, with some wrappers
        model_path = "saved_model.zip"
        learning_algorithm = "PPO"
        wrappers = {
            "observation_wrappers": [MyObservationWrapper],
            "action_wrappers": [MyActionWrapper],
        }
        library = "stable_baselines3"
        trained_policy = RLPolicy(
            action_state, model_path, learning_algorithm, env, wrappers, library
        )

    .. note::

        Currently only supports policies obtained via stable_baselines3.

    :param action_state: see ``BasePolicy``
    :type action_state: see ``BasePolicy``
    :param model_path: path to the saved model
    :type model_path: string
    :param learning_algorithm: name of the learning algorithm
    :type learning_algorithm: string
    :param env: environment before any wrappers were applied
    :type env: gym.Env
    :param wrappers: observation and action wrappers
    :type wrappers: dictionary
    :param library: name of the training library. Currently, only
        stable_baselines3 is supported.
    :type library: string
    """

    def __init__(
        self,
        action_state,
        model_path,
        learning_algorithm,
        env,
        wrappers,
        library,
        *args,
        **kwargs
    ):
        self.model_path = model_path
        self.learning_algorithm = learning_algorithm
        self.env = env
        self.obs_wraps = wrappers["observation_wrappers"]
        self.act_wraps = wrappers["action_wrappers"]
        self.library = library

        if library != "stable_baselines3":
            raise NotImplementedError(
                "The Reinforcement Learning Policy currently only supports "
                "policies obtained via stable_baselines3."
            )

        import stable_baselines3

        # Resolve the algorithm class (e.g. PPO) by name and load the saved model
        learning_algorithm = getattr(stable_baselines3, learning_algorithm)
        self.model = learning_algorithm.load(model_path)

        # Recovering action space
        super().__init__(*args, action_state=action_state, **kwargs)

    @BasePolicy.default_value
    def sample(self, agent_observation=None, agent_state=None):
        """sample

        Get an action via ``model.predict(deterministic=True)``, applying the
        necessary observation and action wrappers.

        :param agent_observation: see ``BasePolicy``
        :type agent_observation: see ``BasePolicy``, optional
        :param agent_state: see ``BasePolicy``
        :type agent_state: see ``BasePolicy``, optional
        :return: see ``BasePolicy``
        :rtype: see ``BasePolicy``
        """
        # Convert the CoopIHC observation to a gym-style observation via the
        # TrainGym convertor
        agent_observation = self.env._convertor.filter_gamestate(agent_observation)

        # Apply observation wrappers, re-wrapping the env at each step so that
        # wrappers stack in the same order as during training
        env = self.env
        for w in self.obs_wraps:
            agent_observation = w(env).observation(agent_observation)
            env = w(env)

        # With deterministic=True, don't sample from the Gaussian but just
        # take its mean
        action = self.model.predict(agent_observation, deterministic=True)[0]

        # Apply action wrappers
        for w in self.act_wraps:
            action = w(env).action(action)
            env = w(env)

        action = action.values()

        return action, 0
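

# ---------------------------------------------------------------------------
# Standalone sketch (not part of coopihc) of the wrapper-chaining pattern used
# in ``RLPolicy.sample`` above: each observation wrapper's ``observation()``
# transform is applied in sequence, and the env is re-wrapped at every step so
# later wrappers stack on top of earlier ones, mirroring the wrapper order
# used during training. The wrapper classes below (ScaleObservation,
# ShiftObservation) are hypothetical illustrations, not part of coopihc.

if __name__ == "__main__":
    import gym
    import numpy as np

    class ScaleObservation(gym.ObservationWrapper):
        """Multiply observations by a fixed factor (hypothetical example)."""

        def observation(self, observation):
            return 2.0 * np.asarray(observation)

    class ShiftObservation(gym.ObservationWrapper):
        """Add a fixed offset to observations (hypothetical example)."""

        def observation(self, observation):
            return np.asarray(observation) + 1.0

    base_env = gym.make("CartPole-v1")
    obs_wraps = [ScaleObservation, ShiftObservation]

    obs = np.asarray([1.0, 0.0, -1.0, 0.5])
    env = base_env
    for w in obs_wraps:
        obs = w(env).observation(obs)  # apply this wrapper's transform
        env = w(env)  # re-wrap so the next wrapper stacks on top

    # obs is now (2 * original) + 1.0, matching the wrapper order
    print(obs)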