from coopihc.base.State import State
from coopihc.base.elements import cat_element
from coopihc.policy.BasePolicy import BasePolicy
# ============== General Policies ===============
class DualPolicy(BasePolicy):
    """Policy that delegates to one of two wrapped policies.

    While :py:attr:`mode` is ``"primary"`` every sampling call is forwarded to
    the primary policy; in ``"dual"`` mode it is forwarded to the dual policy.
    ``action_state`` is read from whichever policy is active, and assigning
    ``host`` propagates the host to both wrapped policies.

    :param primary_policy: policy used in ``"primary"`` mode, given either as
        an instance or as a class (instantiated with ``primary_kwargs``)
    :param dual_policy: policy used in ``"dual"`` mode, given either as an
        instance or as a class (instantiated with ``dual_kwargs``)
    :param primary_kwargs: keyword arguments used to instantiate
        ``primary_policy`` when it is a class; ``None`` means no arguments
    :type primary_kwargs: dict, optional
    :param dual_kwargs: keyword arguments used to instantiate ``dual_policy``
        when it is a class; ``None`` means no arguments
    :type dual_kwargs: dict, optional
    :param kwargs: accepted for interface compatibility; currently ignored
        (not forwarded to :py:class:`BasePolicy`)
    """

    def __init__(
        self, primary_policy, dual_policy, primary_kwargs=None, dual_kwargs=None, **kwargs
    ):
        # Build the sub-policies first: the ``host`` setter below forwards to
        # them, so they must exist before anything assigns ``self.host``.
        self.primary_policy = self._instantiate(primary_policy, primary_kwargs)
        self.dual_policy = self._instantiate(dual_policy, dual_kwargs)
        super().__init__()
        self._host = None
        self._action_state = None
        # ``action_state`` and ``host`` are properties of this class. Calling
        # ``delattr`` on them would invoke the property's (missing) deleter and
        # raise AttributeError, so only discard plain instance attributes of
        # the same name, if the base initializer left any behind.
        for name in ("action_state", "host"):
            self.__dict__.pop(name, None)
        self._mode = "primary"

    @staticmethod
    def _instantiate(policy, policy_kwargs):
        """Return ``policy`` unchanged, or an instance of it if it is a class."""
        # isinstance(..., type) also recognizes classes created with a custom
        # metaclass (e.g. abc.ABCMeta), whose type name is not "type".
        if isinstance(policy, type):
            return policy(**(policy_kwargs or {}))
        return policy

    @property
    def mode(self):
        """Current mode: ``"primary"`` or ``"dual"``."""
        return self._mode

    @mode.setter
    def mode(self, value):
        if value not in ("primary", "dual"):
            raise ValueError(
                f"mode must be 'primary' or 'dual', got {value!r}"
            )
        self._mode = value

    @property
    def _active_policy(self):
        """The wrapped policy selected by the current mode."""
        return self.primary_policy if self._mode == "primary" else self.dual_policy

    @property
    def host(self):
        """Host shared by both wrapped policies (``None`` until assigned)."""
        return self._host

    @host.setter
    def host(self, value):
        self.primary_policy.host = value
        self.dual_policy.host = value
        # Remember the host locally too; otherwise the getter would keep
        # returning None after assignment.
        self._host = value

    @property
    def action_state(self):
        """Action state of the currently active wrapped policy."""
        return self._active_policy.action_state

    def _base_sample(self, agent_observation=None, agent_state=None):
        """Sample from the active policy and record the resulting action.

        :param agent_observation: observation to sample from; ``None`` lets
            the ``default_value`` decorator on :py:meth:`sample` choose one
        :param agent_state: agent state to sample from; ``None`` lets the
            decorator choose one
        :return: ``(action, reward)`` where ``action`` is read back through
            ``self.action`` after assignment
        """
        action, reward = self.sample(
            agent_observation=agent_observation, agent_state=agent_state
        )
        self.action = action
        return self.action, reward

    @BasePolicy.default_value
    def sample(self, agent_observation=None, agent_state=None):
        """Forward sampling to the policy selected by :py:attr:`mode`.

        :return: whatever the active wrapped policy's ``sample`` returns
        """
        return self._active_policy.sample(
            agent_observation=agent_observation, agent_state=agent_state
        )