# Source code for probnum.diffeq.perturbed.scipy_wrapper._wrapped_scipy_odesolution

"""Make a ProbNum ODE solution out of a scipy ODE solution."""
import numpy as np
from scipy.integrate._ivp.common import OdeSolution

from probnum import randvars
from probnum.diffeq import _odesolution
from probnum.filtsmooth._timeseriesposterior import DenseOutputValueType
from probnum.typing import ArrayLike


class WrappedScipyODESolution(_odesolution.ODESolution):
    """ODE solution corresponding to the :class:`WrappedScipyRungeKutta`.

    Wraps a scipy :class:`~scipy.integrate._ivp.common.OdeSolution` (scipy's
    dense-output object) and exposes it through the ProbNum ``ODESolution``
    interface.

    Parameters
    ----------
    scipy_solution
        Dense-output solution object produced by scipy. Its time grid ``ts``
        becomes the ``locations`` of this solution, and the object itself is
        called to evaluate the solution at arbitrary times.
    rvs
        Random variables for the states at the locations
        ``scipy_solution.ts``; wrapped into a ``randvars._RandomVariableList``
        and passed on as ``states``.
    """

    def __init__(self, scipy_solution: OdeSolution, rvs: list):
        self.scipy_solution = scipy_solution

        # rvs is of the type `list` of `RandomVariable` and can therefore be
        # directly transformed into a _RandomVariableList
        rv_states = randvars._RandomVariableList(rvs)
        super().__init__(locations=scipy_solution.ts, states=rv_states)

    def __call__(self, t: ArrayLike) -> DenseOutputValueType:
        """Evaluate the time-continuous solution at time t.

        Parameters
        ----------
        t
            Location / time at which to evaluate the continuous ODE solution.
            Either a scalar (a Python/NumPy scalar or a 0-d array) or a
            one-dimensional array of locations.

        Returns
        -------
        randvars.RandomVariable or randvars._RandomVariableList
            Estimate of the states at time ``t``, taken from scipy's dense
            output. A single ``randvars.Constant`` if ``t`` is scalar,
            otherwise a ``randvars._RandomVariableList`` holding one
            ``randvars.Constant`` per location in ``t``.
        """
        # scipy returns shape (n_states,) for a scalar ``t`` and
        # (n_states, n_points) for a 1-d ``t``; transposing makes each row
        # correspond to one evaluation point.
        states = self.scipy_solution(t).T

        # scipy treats any input with ``np.asarray(t).ndim == 0`` as a scalar.
        # Use the same criterion here: ``np.isscalar`` is False for 0-d arrays,
        # which would wrongly split a single state vector into its components.
        if np.ndim(t) == 0:
            return randvars.Constant(states)
        return randvars._RandomVariableList(
            [randvars.Constant(state) for state in states]
        )