# Source code for probnum.filtsmooth.gaussfiltsmooth.kalman

"""
Kalman filtering and (Rauch-Tung-Striebel) smoothing for
continuous-discrete and discrete-discrete state space models.
"""

import numpy as np

from probnum.filtsmooth.gaussfiltsmooth._utils import is_cont_disc, is_disc_disc
from probnum.filtsmooth.gaussfiltsmooth.gaussfiltsmooth import (
    GaussFiltSmooth,
    linear_discrete_update,
)
from probnum.filtsmooth.statespace import DiscreteGaussianLinearModel, LinearSDEModel
from probnum.random_variables import Normal


class Kalman(GaussFiltSmooth):
    """
    Kalman filtering and smoothing for continuous-discrete and
    discrete-discrete state space models.
    """

    def __new__(cls, dynamod, measmod, initrv, **kwargs):
        """
        Factory method for Kalman filtering and smoothing.

        Depending on whether the dynamic model is continuous or discrete,
        either a continuous-discrete Kalman object or a discrete-discrete
        Kalman object is created.

        Parameters
        ----------
        dynamod
            Dynamic (transition) model.
        measmod
            Measurement model.
        initrv
            Initial random variable.
        **kwargs
            Forwarded to the continuous-discrete variant only
            (e.g. ``cke_nsteps``).

        Raises
        ------
        ValueError
            If the model combination is neither continuous-discrete nor
            discrete-discrete.
        """
        if cls is not Kalman:
            # A concrete subclass is being constructed directly (this is also
            # the path taken by the nested constructor calls below).
            return super().__new__(cls)

        # NOTE: the object returned here is an instance of ``Kalman``, so
        # Python will call its ``__init__`` a second time with the original
        # arguments (including ``**kwargs``) after ``__new__`` returns.
        if is_cont_disc(dynamod, measmod):
            return _ContDiscKalman(dynamod, measmod, initrv, **kwargs)
        if is_disc_disc(dynamod, measmod):
            return _DiscDiscKalman(dynamod, measmod, initrv)

        errmsg = (
            "Cannot instantiate Kalman object with given "
            "dynamic model and measurement model."
        )
        raise ValueError(errmsg)
class _ContDiscKalman(Kalman):
    """
    Provides predict() and update() methods for Kalman filtering and
    smoothing on continuous-discrete state space models.
    """

    def __init__(self, dynamod, measmod, initrv, **kwargs):
        """
        Check that dynamod and measmod are linear, then initialise.

        Parameters
        ----------
        dynamod : LinearSDEModel
            Continuous-time linear dynamic model.
        measmod : DiscreteGaussianLinearModel
            Discrete linear Gaussian measurement model.
        initrv
            Initial random variable.
        **kwargs
            ``cke_nsteps`` (default 1): number of equal sub-steps into which
            ``[start, stop]`` is divided when calling ``transition_rv`` in
            :meth:`predict`. Other keys are ignored.

        Raises
        ------
        ValueError
            If either model is not of the required linear type.
        """
        if not isinstance(dynamod, LinearSDEModel):
            raise ValueError(
                "ContinuousDiscreteKalman requires a linear dynamic model."
            )
        if not isinstance(measmod, DiscreteGaussianLinearModel):
            raise ValueError(
                "ContinuousDiscreteKalman requires a linear measurement model."
            )
        self.cke_nsteps = kwargs.get("cke_nsteps", 1)
        super().__init__(dynamod, measmod, initrv)

    def predict(self, start, stop, randvar, **kwargs):
        """
        Prediction step: propagate ``randvar`` from ``start`` to ``stop``
        through the dynamic model's ``transition_rv`` using a step size of
        ``(stop - start) / self.cke_nsteps``.
        """
        # NOTE(review): **kwargs is not forwarded to transition_rv; confirm
        # this is intended.
        step = (stop - start) / self.cke_nsteps
        return self.dynamicmodel.transition_rv(
            rv=randvar, start=start, stop=stop, step=step
        )

    def update(self, time, randvar, data, **kwargs):
        """Update step: condition ``randvar`` on ``data`` observed at ``time``."""
        return _discrete_kalman_update(
            time, randvar, data, self.measurementmodel, **kwargs
        )


class _DiscDiscKalman(Kalman):
    """
    Provides predict() and update() methods for Kalman filtering and
    smoothing on discrete-discrete state space models.
    """

    def __init__(self, dynamod, measmod, initrv, **kwargs):
        """
        Check that dynamod and measmod are linear, then initialise.

        Parameters
        ----------
        dynamod : DiscreteGaussianLinearModel
            Discrete linear Gaussian dynamic model.
        measmod : DiscreteGaussianLinearModel
            Discrete linear Gaussian measurement model.
        initrv
            Initial random variable.
        **kwargs
            Accepted and ignored. ``Kalman(...)`` returns an instance of this
            class from ``__new__``, so Python re-invokes ``__init__`` with the
            caller's original keyword arguments. Without this parameter, any
            keyword passed to ``Kalman`` (e.g. ``cke_nsteps``) would raise
            ``TypeError``.

        Raises
        ------
        ValueError
            If either model is not a ``DiscreteGaussianLinearModel``.
        """
        if not isinstance(dynamod, DiscreteGaussianLinearModel):
            raise ValueError(
                "DiscreteDiscreteKalman requires a linear dynamic model."
            )
        if not isinstance(measmod, DiscreteGaussianLinearModel):
            raise ValueError(
                "DiscreteDiscreteKalman requires a linear measurement model."
            )
        super().__init__(dynamod, measmod, initrv)

    def predict(self, start, stop, randvar, **kwargs):
        """
        Prediction step for discrete-discrete Kalman filtering.

        The model matrices are evaluated at ``start``; ``stop`` is unused.

        Returns
        -------
        tuple
            ``(Normal(mpred, cpred), {"crosscov": cov @ A.T})``, where ``A``
            is the dynamics matrix.
        """
        mean, covar = _promote_scalar_moments(randvar.mean, randvar.cov)
        dynamat = self.dynamicmodel.dynamicsmatrix(start, **kwargs)
        forcevec = self.dynamicmodel.forcevector(start, **kwargs)
        diffmat = self.dynamicmodel.diffusionmatrix(start, **kwargs)
        mpred = dynamat @ mean + forcevec
        ccpred = covar @ dynamat.T
        cpred = dynamat @ ccpred + diffmat
        return Normal(mpred, cpred), {"crosscov": ccpred}

    def update(self, time, randvar, data, **kwargs):
        """Update step of discrete Kalman filtering."""
        return _discrete_kalman_update(
            time, randvar, data, self.measurementmodel, **kwargs
        )


def _promote_scalar_moments(mean, cov):
    """
    Return ``(mean, cov)`` as shape-(1,) and shape-(1, 1) arrays if both are
    scalars, so that matrix products with ``@`` work; otherwise return them
    unchanged.
    """
    if np.isscalar(mean) and np.isscalar(cov):
        return mean * np.ones(1), cov * np.eye(1)
    return mean, cov


def _discrete_kalman_update(time, randvar, data, measurementmodel, **kwargs):
    """
    Discrete Kalman update.

    Evaluates the measurement model's matrices at ``time``, forms the
    predicted measurement mean ``H @ mpred``, and delegates the actual
    conditioning to ``linear_discrete_update``, whose result is returned.
    """
    mpred, cpred = _promote_scalar_moments(randvar.mean, randvar.cov)
    measmat = measurementmodel.dynamicsmatrix(time, **kwargs)
    meascov = measurementmodel.diffusionmatrix(time, **kwargs)
    meanest = measmat @ mpred
    return linear_discrete_update(meanest, cpred, data, meascov, measmat, mpred)