Source code for probnum.diffeq.odefilter._odefilter_solution

"""ODE solutions returned by Gaussian ODE filtering."""

from typing import Optional

import numpy as np

from probnum import filtsmooth, randvars, utils
from probnum.diffeq import _odesolution
from probnum.typing import ArrayLike, FloatLike, IntLike, ShapeLike


class ODEFilterSolution(_odesolution.ODESolution):
    """Probabilistic ODE solution corresponding to the :class:`ODEFilter`.

    Recall that in ProbNum, Gaussian filtering and smoothing is generally named "Kalman".

    Parameters
    ----------
    kalman_posterior
        Gauss-Markov posterior over the ODE solver state space model.
        Therefore, it assumes that the dynamics model is an :class:`Integrator`.

    Attributes
    ----------
    kalman_posterior
        The posterior passed to the constructor, stored unchanged.
    proj_to_y
        Projection matrix ``transition.proj2coord(coord=0)``; used to project
        state-space quantities onto the ODE solution.
    proj_to_dy
        Projection matrix ``transition.proj2coord(coord=1)``; used to project
        state-space quantities onto the derivative of the ODE solution.

    See Also
    --------
    ODEFilter : ODE solver that behaves like a Gaussian filter.
    KalmanPosterior : Posterior over states after Gaussian filtering/smoothing.

    Examples
    --------
    >>> import numpy as np
    >>> from probnum.diffeq import probsolve_ivp
    >>> from probnum import randvars
    >>>
    >>> def f(t, x):
    ...     return 4*x*(1-x)
    >>>
    >>> y0 = np.array([0.15])
    >>> t0, tmax = 0., 1.5
    >>> solution = probsolve_ivp(f, t0, tmax, y0, step=0.1, adaptive=False)
    >>> # Mean of the discrete-time solution
    >>> print(np.round(solution.states.mean, 2))
    [[0.15]
     [0.21]
     [0.28]
     [0.37]
     [0.47]
     [0.57]
     [0.66]
     [0.74]
     [0.81]
     [0.87]
     [0.91]
     [0.94]
     [0.96]
     [0.97]
     [0.98]
     [0.99]]

    >>> # Times of the discrete-time solution
    >>> print(solution.locations)
    [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.  1.1 1.2 1.3 1.4 1.5]
    >>> # Individual entries of the discrete-time solution can be accessed with
    >>> print(solution[5])
    <Normal with shape=(1,), dtype=float64>
    >>> print(np.round(solution[5].mean, 2))
    [0.56]
    >>> # Evaluate the continuous-time solution at a new time point t=0.65
    >>> print(np.round(solution(0.65).mean, 2))
    [0.70]
    """

    def __init__(self, kalman_posterior: filtsmooth.gaussian.KalmanPosterior):
        self.kalman_posterior = kalman_posterior

        # Pre-compute projection matrices.
        # The prior must be an integrator, if not, an error is thrown in 'ODEFilter'.
        self.proj_to_y = self.kalman_posterior.transition.proj2coord(coord=0)
        self.proj_to_dy = self.kalman_posterior.transition.proj2coord(coord=1)

        # Project every state of the Kalman posterior onto the ODE solution
        # (coordinate 0) and onto its derivative (coordinate 1).
        states = randvars._RandomVariableList(
            [_project_rv(self.proj_to_y, rv) for rv in self.kalman_posterior.states]
        )
        derivatives = randvars._RandomVariableList(
            [_project_rv(self.proj_to_dy, rv) for rv in self.kalman_posterior.states]
        )
        super().__init__(
            locations=kalman_posterior.locations, states=states, derivatives=derivatives
        )

    def interpolate(
        self,
        t: FloatLike,
        previous_index: Optional[IntLike] = None,
        next_index: Optional[IntLike] = None,
    ) -> randvars.RandomVariable:
        """Evaluate the ODE solution at a single time point ``t``.

        The call is delegated to ``kalman_posterior.interpolate`` and the
        resulting random variable is projected with ``proj_to_y``.

        Parameters
        ----------
        t
            Time point at which to evaluate.
        previous_index
            Forwarded unchanged to ``kalman_posterior.interpolate``.
        next_index
            Forwarded unchanged to ``kalman_posterior.interpolate``.

        Returns
        -------
        randvars.RandomVariable
            Projection of the interpolated state onto the ODE solution.
        """
        out_rv = self.kalman_posterior.interpolate(
            t, previous_index=previous_index, next_index=next_index
        )
        return _project_rv(self.proj_to_y, out_rv)

    def sample(
        self,
        rng: np.random.Generator,
        t: Optional[ArrayLike] = None,
        size: Optional[ShapeLike] = (),
    ) -> np.ndarray:
        """Sample from the ODE solution.

        Samples are drawn via ``kalman_posterior.sample`` and then projected
        with ``proj_to_y`` along their last axis.

        Parameters
        ----------
        rng
            Random number generator; forwarded to ``kalman_posterior.sample``.
        t
            Forwarded unchanged to ``kalman_posterior.sample``.
        size
            Forwarded unchanged to ``kalman_posterior.sample``.

        Returns
        -------
        np.ndarray
            The projected samples; all leading axes of the Kalman-posterior
            samples are preserved.
        """
        samples = self.kalman_posterior.sample(rng=rng, t=t, size=size)

        # Project the samples down to the "true" ODEFilterSolution dimensions
        # (which are a subset of the KalmanPosterior dimensions)
        ode_samples = np.einsum("dq,...q->...d", self.proj_to_y, samples)
        return ode_samples

    def transform_base_measure_realizations(
        self,
        base_measure_realizations: np.ndarray,
        t: Optional[ArrayLike] = None,
    ) -> np.ndarray:
        """Not implemented for :class:`ODEFilterSolution`.

        Raises
        ------
        NotImplementedError
            Always. Use
            ``ODEFilterSolution.kalman_posterior.transform_base_measure_realizations``
            instead.
        """
        # The two literals are concatenated implicitly; the separating space
        # after the first sentence must be part of the literal.
        errormsg = (
            "The ODEFilterSolution does not implement transformation of realizations of a base measure. "
            "Try `ODEFilterSolution.kalman_posterior.transform_base_measure_realizations` instead."
        )
        raise NotImplementedError(errormsg)

    @property
    def filtering_solution(self):
        """ODE solution based on the filtering posterior.

        Returns ``self`` if ``kalman_posterior`` is already a
        ``FilteringPosterior``; otherwise a new :class:`ODEFilterSolution`
        built from ``kalman_posterior.filtering_posterior``.
        """
        if isinstance(self.kalman_posterior, filtsmooth.gaussian.FilteringPosterior):
            return self

        # else: self.kalman_posterior is a SmoothingPosterior object,
        # which has the field filtering_posterior.
        return ODEFilterSolution(
            kalman_posterior=self.kalman_posterior.filtering_posterior
        )


def _project_rv(projmat, rv):
    """Project a Normal random variable through the linear map ``projmat``.

    Parameters
    ----------
    projmat
        Projection matrix applied to the mean, covariance and Cholesky factor.
    rv
        Random variable providing ``mean``, ``cov`` and ``cov_cholesky``.

    Returns
    -------
    randvars.Normal
        Normal random variable with mean ``projmat @ rv.mean``, covariance
        ``projmat @ rv.cov @ projmat.T`` and an updated Cholesky factor.
    """
    # There is no way of checking whether `rv` has its Cholesky factor computed
    # already or not. Therefore, since we need to update the Cholesky factor for
    # square-root filtering, we also update the Cholesky factor for
    # non-square-root algorithms here, which implies additional cost.
    # See Issues #319 and #329.
    # When they are resolved, this function here will hopefully be superfluous.
    new_mean = projmat @ rv.mean
    new_cov = projmat @ rv.cov @ projmat.T
    new_cov_cholesky = utils.linalg.cholesky_update(projmat @ rv.cov_cholesky)
    return randvars.Normal(new_mean, new_cov, cov_cholesky=new_cov_cholesky)