import numpy
from coopihc.inference.BaseInferenceEngine import BaseInferenceEngine


class ContinuousKalmanUpdate(BaseInferenceEngine):
    """An inference engine that estimates the new state with a continuous Kalman filter.

    The state-transition dynamics and the Kalman gain are provided externally, through
    :meth:`set_forward_model_dynamics` and :meth:`set_K`.
    """

    @property
    def action(self):
        """First element of the action held by the parent inference engine."""
        return super().action[0]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Record whether the dynamics and the gain have been supplied yet
        self.fmd_flag = False
        self.K_flag = False

    def set_forward_model_dynamics(self, A, B, C):
        """Set the forward model dynamics.

        Call this externally to supply the linear dynamic matrices that describe the
        deterministic part of the state transitions:

        .. math::

            d\\hat{x} = A\\hat{x}\\,dt + Bu\\,dt \\\\
            dy = C\\hat{x}\\,dt

        :param A: state matrix, see equation above
        :type A: numpy.ndarray
        :param B: input (control) matrix, see equation above
        :type B: numpy.ndarray
        :param C: observation matrix, see equation above
        :type C: numpy.ndarray
        """
        self.fmd_flag = True
        self.A = A
        self.B = B
        self.C = C

    def set_K(self, K):
        """Set the Kalman gain.

        :param K: Kalman gain, applied to the innovation :math:`dy - C\\hat{x}\\,dt`
        :type K: numpy.ndarray
        """
        self.K_flag = True
        self.K = K

    def infer(self, agent_observation=None):
        """Infer the state based on the observation.

        :return: (new state, reward)
        :rtype: tuple(:py:class:`State<coopihc.base.State.State>`, float)
        """
        if not self.fmd_flag:
            raise RuntimeError(
                "You have to set the forward model dynamics by calling the set_forward_model_dynamics() method of inference engine {} before using it".format(
                    ...
                )
            )
        if not self.K_flag:
            raise RuntimeError(
                "You have to set the K matrix by calling the set_K() method of inference engine {} before using it".format(
                    ...
                )
            )
        observation = self.observation
        # Observed output increment over one time step (dy = x dt)
        dy = observation["task_state"]["x"] * ...
        if isinstance(dy, list):
            dy = dy[0]
        if not isinstance(dy, numpy.ndarray):
            raise TypeError(
                "The observed task_state['x'] of {} is expected to be of type numpy.ndarray".format(
                    ...
                )
            )
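        state = observation["{}_state".format(...)]
        u = self.action.view(numpy.ndarray)
        xhat = state["xhat"].view(numpy.ndarray)

        # Work with column vectors so that the matrix products are well defined
        xhat = xhat.reshape(-1, 1)
        u = u.reshape(-1, 1)

        # Euler step of the continuous Kalman filter:
        #   dxhat = (A xhat + B u) dt + K (dy - C xhat dt)
        deltaxhat = (self.A @ xhat + self.B @ u) * ... + self.K @ (
            dy - self.C @ xhat * ...
        )
        xhat += deltaxhat
        state["xhat"] = xhat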

        # ====================== Rewards ======================
        # This uses the classical definition of rewards in the LQG setup, which
        # requires the true value of the state; this may or may not be realistic.
        x = ...["x"].view(numpy.ndarray)
        # Quadratic estimation-error cost, weighted by a cost matrix
        reward = (x - xhat).T @ ... @ (x - xhat)

        return state, reward
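

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module or of the coopihc API.
# The update in ContinuousKalmanUpdate.infer() is one Euler step of a
# continuous Kalman filter with an externally supplied gain K:
#     xhat <- xhat + (A xhat + B u) dt + K (dy - C xhat dt)
# The hypothetical helpers below reproduce that step and the quadratic
# estimation-error cost as plain functions of NumPy arrays, so the arithmetic
# can be checked without a coopihc bundle. The explicit ``dt`` argument and the
# weight matrix ``W`` are assumptions standing in for values the engine obtains
# elsewhere.
# ---------------------------------------------------------------------------


def kalman_euler_step(xhat, u, dy, A, B, C, K, dt):
    """Hypothetical helper: return xhat + (A xhat + B u) dt + K (dy - C xhat dt).

    Unlike infer(), this returns a new column vector instead of updating xhat in place.
    """
    # Reshape to column vectors, as infer() does for xhat and u
    xhat = numpy.asarray(xhat, dtype=float).reshape(-1, 1)
    u = numpy.asarray(u, dtype=float).reshape(-1, 1)
    dy = numpy.asarray(dy, dtype=float).reshape(-1, 1)
    prediction = (A @ xhat + B @ u) * dt  # deterministic model drift over dt
    innovation = dy - C @ xhat * dt  # observed minus predicted output increment
    return xhat + prediction + K @ innovation


def estimation_cost(x, xhat, W):
    """Hypothetical helper: return the scalar (x - xhat)^T W (x - xhat)."""
    err = (numpy.asarray(x, dtype=float) - numpy.asarray(xhat, dtype=float)).reshape(-1, 1)
    return (err.T @ W @ err).item()


# Usage on a scalar system, starting from xhat = 0 with true state x = 1 and dt = 0.01:
#   A, B, C, K = (numpy.array([[v]]) for v in (-1.0, 1.0, 1.0, 0.5))
#   xhat = kalman_euler_step([0.0], [0.0], [0.01], A, B, C, K, dt=0.01)
#   # xhat is array([[0.005]]); estimation_cost([1.0], xhat, numpy.eye(1)) is about 0.990025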