import copy
from coopihc.base.State import State
from coopihc.policy.BasePolicy import BasePolicy
from coopihc.bundle.wrappers.Train import TrainGym
# ======================= RL Policy
class RLPolicy(BasePolicy):
"""Wrap a trained net as a CoopIHC policy.
A policy object compatible with CoopIHC that wraps a policy that was trained via Deep Reinforcement learning.
Example code:

.. code-block:: python

    # action_state
    action_state = State()
    action_state["action"] = StateElement(0, autospace([-5 + i for i in range(11)]))

    # env
    env = TrainGym(
        bundle,
        train_user=True,
        train_assistant=False,
    )

    # Using PPO from stable_baselines3, with some wrappers
    model_path = "saved_model.zip"
    learning_algorithm = "PPO"
    wrappers = {
        "observation_wrappers": [MyObservationWrapper],
        "action_wrappers": [MyActionWrapper],
    }
    library = "stable_baselines3"
    trained_policy = RLPolicy(
        action_state, model_path, learning_algorithm, env, wrappers, library
    )
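
The entries of ``wrappers`` are ordinary Gym wrapper classes. A minimal sketch of what they might look like (``MyObservationWrapper`` and ``MyActionWrapper`` are illustrative placeholders, not part of CoopIHC):

.. code-block:: python

    import gym

    class MyObservationWrapper(gym.ObservationWrapper):
        def observation(self, observation):
            # e.g. cast or normalize the raw observation
            return observation

    class MyActionWrapper(gym.ActionWrapper):
        def action(self, action):
            # e.g. map the network's output back to the action expected by the bundle
            return action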

.. note::

    Currently only supports policies obtained via stable baselines 3.

:param action_state: see ``BasePolicy``
:type action_state: see ``BasePolicy``
:param model_path: path to the saved model
:type model_path: string
:param learning_algorithm: name of the learning algorithm
:type learning_algorithm: string
:param env: environment before any wrappers were applied
:type env: gym.Env
:param wrappers: observation and action wrappers
:type wrappers: dictionary
:param library: name of the training library. Currently, only stable_baselines3 is supported.
:type library: string
"""
def __init__(
self,
action_state,
model_path,
learning_algorithm,
env,
wrappers,
library,
*args,
**kwargs
):
self.model_path = model_path
self.learning_algorithm = learning_algorithm
self.env = env
self.obs_wraps = wrappers["observation_wrappers"]
self.act_wraps = wrappers["action_wrappers"]
self.library = library
if library != "stable_baselines3":
raise NotImplementedError(
"The Reinforcement Learning Policy currently only supports policies obtained via stable baselines 3."
)
import stable_baselines3

# Resolve the algorithm class by name (e.g. "PPO") and load the trained model
algorithm_class = getattr(stable_baselines3, learning_algorithm)
self.model = algorithm_class.load(model_path)

# Recovering action space
super().__init__(*args, action_state=action_state, **kwargs)
@BasePolicy.default_value
def sample(self, agent_observation=None, agent_state=None):
"""sample
Get action by using model.predict(deterministic = True), applying the necessary wrappers.
:param observation: see ``BasePolicy``
:type observation: see ``BasePolicy``, optional
:return: see ``BasePolicy``
:rtype: see ``BasePolicy``
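
A minimal usage sketch (``obs`` is assumed to be a full game-state observation produced by the bundle):

.. code-block:: python

    action, reward = trained_policy.sample(agent_observation=obs)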
"""
# Convert the CoopIHC game state into a Gym-style observation via the TrainGym convertor
agent_observation = self.env._convertor.filter_gamestate(agent_observation)
# Apply observation wrappers in order, rewrapping the environment so each
# wrapper sees the spaces produced by the previous one
env = self.env
for w in self.obs_wraps:
    env = w(env)
    agent_observation = env.observation(agent_observation)
# With deterministic=True, take the mean of the action distribution instead
# of sampling from it; predict() returns (action, hidden state)
action = self.model.predict(agent_observation, deterministic=True)[0]
# Apply action wrappers in order
for w in self.act_wraps:
    env = w(env)
    action = env.action(action)
# CoopIHC policies return an (action, reward) pair
action = action.values()
return action, 0