Source code for probnum.diffeq.odefiltsmooth.kalman_odesolution

"""ODE solutions returned by Gaussian ODE filtering."""

import typing

import numpy as np

import probnum._randomvariablelist as pnrv_list
import probnum.filtsmooth as pnfs
import probnum.type
import probnum.utils
from probnum import randvars
from probnum.utils.linalg import cholesky_update

from ..odesolution import ODESolution

try:
    # functools.cached_property is only available in Python >=3.8
    from functools import cached_property
except ImportError:
    from cached_property import cached_property


class KalmanODESolution(ODESolution):
    """Gaussian IVP filtering solution of an ODE problem.

    Recall that in ProbNum, Gaussian filtering and smoothing is generally named
    "Kalman".

    Parameters
    ----------
    kalman_posterior
        Gauss-Markov posterior over the ODE solver state space model. Its
        dynamics model is assumed to be an :class:`Integrator`, because the
        solution and its derivative are read off the state via
        ``transition.proj2coord``.

    See Also
    --------
    GaussianIVPFilter : ODE solver that behaves like a Gaussian filter.
    KalmanPosterior : Posterior over states after Gaussian filtering/smoothing.

    Examples
    --------
    >>> from probnum.diffeq import logistic, probsolve_ivp
    >>> from probnum import randvars
    >>>
    >>> def f(t, x):
    ...     return 4*x*(1-x)
    >>>
    >>> y0 = np.array([0.15])
    >>> t0, tmax = 0., 1.5
    >>> solution = probsolve_ivp(f, t0, tmax, y0, step=0.1, adaptive=False)
    >>> # Mean of the discrete-time solution
    >>> print(np.round(solution.y.mean, 2))
    [[0.15]
     [0.21]
     [0.28]
     [0.37]
     [0.47]
     [0.57]
     [0.66]
     [0.74]
     [0.81]
     [0.87]
     [0.91]
     [0.94]
     [0.96]
     [0.97]
     [0.98]
     [0.99]]
    >>> # Times of the discrete-time solution
    >>> print(solution.t)
    [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.  1.1 1.2 1.3 1.4 1.5]
    >>> # Individual entries of the discrete-time solution can be accessed with
    >>> print(solution[5])
    <Normal with shape=(1,), dtype=float64>
    >>> print(np.round(solution[5].mean, 2))
    [0.56]
    >>> # Evaluate the continuous-time solution at a new time point t=0.65
    >>> print(np.round(solution(0.65).mean, 2))
    [0.70]
    """

    def __init__(self, kalman_posterior: pnfs.KalmanPosterior):
        self.kalman_posterior = kalman_posterior

        # Pre-compute projection matrices.
        # The prior must be an integrator, if not, an error is thrown in 'GaussianIVPFilter'.
        self.proj_to_y = kalman_posterior.transition.proj2coord(coord=0)
        self.proj_to_dy = kalman_posterior.transition.proj2coord(coord=1)

    @property
    def t(self) -> np.ndarray:
        """Time points of the discrete-time solution (the posterior's locations)."""
        return self.kalman_posterior.locations

    @cached_property
    def y(self) -> pnrv_list._RandomVariableList:
        """Discrete-time solution: every posterior state projected with ``proj_to_y``."""
        return self._project_states(self.proj_to_y)

    @cached_property
    def dy(self) -> pnrv_list._RandomVariableList:
        """Discrete-time derivative: every posterior state projected with ``proj_to_dy``."""
        return self._project_states(self.proj_to_dy)

    def _project_states(self, projmat) -> pnrv_list._RandomVariableList:
        """Project each of the posterior's ``state_rvs`` with ``projmat``."""
        return pnrv_list._RandomVariableList(
            [_project_rv(projmat, rv) for rv in self.kalman_posterior.state_rvs]
        )

    def __call__(
        self, t: float
    ) -> typing.Union[randvars.RandomVariable, pnrv_list._RandomVariableList]:
        """Evaluate the continuous-time solution at ``t``.

        Parameters
        ----------
        t
            Time point(s) at which to evaluate the posterior.

        Returns
        -------
        For scalar ``t``, a single projected random variable. Otherwise, a
        ``_RandomVariableList`` with one projected random variable per entry
        of the posterior's output.
        """
        out_rv = self.kalman_posterior(t)

        if np.isscalar(t):
            return _project_rv(self.proj_to_y, out_rv)

        return pnrv_list._RandomVariableList(
            [_project_rv(self.proj_to_y, rv) for rv in out_rv]
        )

    def sample(
        self,
        t: typing.Optional[float] = None,
        size: typing.Optional[probnum.type.ShapeArgType] = (),
    ) -> np.ndarray:
        """Sample from the Gaussian filtering ODE solution by sampling from the Gauss-
        Markov posterior.

        Parameters
        ----------
        t
            Locations at which to sample; forwarded as ``locations`` to the
            posterior's ``sample`` method (``None`` is passed through).
        size
            Shape of the batch of samples. Each entry adds a leading axis to
            the returned array.

        Returns
        -------
        np.ndarray
            Sampled states projected with ``proj_to_y``.
        """
        size = probnum.utils.as_shape(size)

        # Implement only single samples, the rest via recursion.
        # We cannot 'steal' the recursion from self.kalman_posterior.sample,
        # because we need to project the respective states out of each sample.
        if size != ():
            return np.array([self.sample(t=t, size=size[1:]) for _ in range(size[0])])

        samples = np.asarray(self.kalman_posterior.sample(locations=t, size=size))

        # Project along the trailing (state) axis in one product. For a 2-D
        # array of states this equals projecting row by row; unlike a per-row
        # loop it also handles a single 1-D state vector.
        return samples @ self.proj_to_y.T

    @property
    def filtering_solution(self):
        """Solution based on the filtering posterior only.

        Returns ``self`` if the posterior already is a ``FilteringPosterior``.
        Otherwise a new ``KalmanODESolution`` is built from the posterior's
        ``filtering_posterior`` attribute.
        """
        if isinstance(self.kalman_posterior, pnfs.FilteringPosterior):
            return self

        # Otherwise, self.kalman_posterior is a SmoothingPosterior object,
        # which has the attribute `filtering_posterior`.
        return KalmanODESolution(
            kalman_posterior=self.kalman_posterior.filtering_posterior
        )
def _project_rv(projmat, rv):
    """Map a Gaussian random variable through the linear projection ``projmat``.

    Parameters
    ----------
    projmat
        Projection matrix applied from the left.
    rv
        Random variable providing ``mean``, ``cov`` and ``cov_cholesky``.

    Returns
    -------
    randvars.Normal
        Normal with mean ``projmat @ rv.mean``, covariance
        ``projmat @ rv.cov @ projmat.T`` and a Cholesky factor obtained by
        passing ``projmat @ rv.cov_cholesky`` to ``cholesky_update``.
    """
    # Whether `rv` already carries a computed Cholesky factor cannot be checked.
    # Square-root filtering needs an updated factor, so it is recomputed here
    # for the non-square-root algorithms as well, at additional cost.
    # See Issues #319 and #329; once they are resolved this helper should
    # hopefully become superfluous.
    projected_mean = projmat @ rv.mean
    projected_cov = projmat @ rv.cov @ projmat.T
    projected_factor = cholesky_update(projmat @ rv.cov_cholesky)

    return randvars.Normal(
        projected_mean,
        projected_cov,
        cov_cholesky=projected_factor,
    )