# Source code for probnum.diffeq.perturbed.step._perturbedstepsolution

"""Output of PerturbedStepSolver."""

from typing import List, Optional

import numpy as np
from scipy.integrate._ivp import rk

from probnum import randvars
from probnum.diffeq import _odesolution
from probnum.typing import FloatLike


class PerturbedStepSolution(_odesolution.ODESolution):
    """Probabilistic ODE solution corresponding to the :class:`PerturbedStepSolver`.

    Parameters
    ----------
    scales
        One factor per step. Before ``interpolants[i]`` is evaluated, the time
        offset into step ``i`` is multiplied by ``scales[i]`` (presumably the
        ratio of the perturbed to the nominal step size).
    locations
        Time grid of the solution. It starts with the initial time, so it is
        expected to hold one more entry than ``interpolants``.
    states
        Solution values at ``locations``.
    interpolants
        Dense outputs of the underlying Runge-Kutta solver. ``interpolants[i]``
        covers the step that starts at ``locations[i]``.
    """

    def __init__(
        self,
        scales: List[float],
        locations: np.ndarray,
        states: randvars._RandomVariableList,
        interpolants: List[rk.DenseOutput],
    ):
        self.scales = scales
        self.interpolants = interpolants
        super().__init__(locations, states)

    def interpolate(
        self,
        t: FloatLike,
        previous_index: Optional[int] = None,
        next_index: Optional[int] = None,
    ):
        """Evaluate the solution at time ``t`` using the stored dense outputs.

        Parameters
        ----------
        t
            Time at which to evaluate the solution.
        previous_index
            Index of the step whose dense output is used, i.e. the step that
            starts at ``locations[previous_index]``. Negative values count
            backwards over the steps, so ``-1`` selects the last step (used for
            extrapolation to the right). If ``None``, the step is located via
            :func:`numpy.searchsorted`. Times left of the grid then use the
            first step, and times right of it use the last step.
        next_index
            Unused. It is accepted only for signature compatibility with
            :class:`ODESolution`.

        Returns
        -------
        randvars.RandomVariable
            ``states[0]`` or ``states[-1]`` if ``t`` coincides with the first
            or last grid location. Otherwise, a :class:`randvars.Constant`
            wrapping the interpolant's value.
        """
        # At the ends of the grid the stored state is returned directly;
        # no interpolation has to be performed.
        if t == self.locations[0]:
            return self.states[0]
        if t == self.locations[-1]:
            return self.states[-1]

        num_steps = len(self.interpolants)
        if previous_index is None:
            # No step supplied (e.g. extrapolation to the left). Pick the step
            # whose start precedes ``t``, clamped to the existing steps.
            previous_index = int(np.searchsorted(self.locations, t, side="left")) - 1
            previous_index = min(max(previous_index, 0), num_steps - 1)
        elif previous_index < 0:
            # ``locations`` has one more entry than ``interpolants``. A negative
            # index must therefore be resolved against the steps first.
            # Otherwise ``-1`` would pair the last interpolant with the *end*
            # of its step rather than its start, and the perturbed time
            # mapping below would be anchored at the wrong point.
            previous_index += num_steps

        interpolant = self.interpolants[previous_index]
        step_start = self.locations[previous_index]
        # Map the nominal offset into the step onto the interpolant's
        # (rescaled) time axis, which starts at ``step_start``.
        scaled_offset = (t - step_start) * self.scales[previous_index]
        evaluation = interpolant(step_start + scaled_offset)
        return randvars.Constant(evaluation)