# Source code for probnum.diffeq.perturbed.scipy_wrapper._wrapped_scipy_solver

"""Wrapper class for the RK23 and RK45 solvers of scipy.integrate.

Dense output cannot be used for DOP853. If you use other RK methods, make sure that
the current implementation works for them.
"""
import numpy as np
from scipy.integrate._ivp import rk
from scipy.integrate._ivp.common import OdeSolution

from probnum import randvars
from probnum.diffeq import _odesolver, _odesolver_state
from probnum.diffeq.perturbed.scipy_wrapper import _wrapped_scipy_odesolution
from probnum.typing import FloatLike


class WrappedScipyRungeKutta(_odesolver.ODESolver):
    """Wrapper for Runge-Kutta methods from SciPy.

    Parameters
    ----------
    solver_type
        Runge-Kutta solver class from ``scipy.integrate._ivp.rk`` (the class
        itself, not an instance). It is instantiated in :meth:`initialize`.
    steprule
        Step-size rule; forwarded unchanged to the base class.

    Raises
    ------
    TypeError
        If ``solver_type`` is ``rk.DOP853`` or a subclass of it.
    """

    def __init__(self, solver_type: rk.RungeKutta, steprule):
        self.solver_type = solver_type
        self.interpolants = None

        # Filled in later
        self.solver = None
        self.ivp = None

        # Dopri853 as implemented in SciPy computes the dense output differently.
        if issubclass(solver_type, rk.DOP853):
            raise TypeError(
                "Dense output interpolation of DOP853 is currently not supported. "
                "Choose a different RK-method."
            )

        super().__init__(steprule=steprule, order=solver_type.order)

    def initialize(self, ivp):
        """Create the SciPy solver for ``ivp`` and return the initial solver state.

        The SciPy solver is reset explicitly (``y_old``, ``t``, ``y`` and ``f`` are
        set to their initial values) so that solving the ODE multiple times with
        the same wrapper starts from a clean state each time. The list of stored
        interpolants is emptied as well.

        Parameters
        ----------
        ivp
            Initial value problem. The attributes ``f``, ``t0``, ``y0`` and ``tmax``
            are read.

        Returns
        -------
        _odesolver_state.ODESolverState
            Initial state at time ``ivp.t0`` whose random variable is
            ``randvars.Constant(ivp.y0)``; error estimate and reference state are
            ``None``.
        """
        self.solver = self.solver_type(ivp.f, ivp.t0, ivp.y0, ivp.tmax)
        self.ivp = ivp
        self.interpolants = []

        # Explicit reset of the quantities that attempt_step() and dense_output()
        # rely on.
        self.solver.y_old = None
        self.solver.t = self.ivp.t0
        self.solver.y = self.ivp.y0
        self.solver.f = self.solver.fun(self.solver.t, self.solver.y)

        state = _odesolver_state.ODESolverState(
            ivp=ivp,
            rv=randvars.Constant(self.ivp.y0),
            t=self.ivp.t0,
            error_estimate=None,
            reference_state=None,
        )
        return state

    def attempt_step(self, state: _odesolver_state.ODESolverState, dt: FloatLike):
        """Perform one ODE-step from start to stop and set variables to the
        corresponding values.

        To specify start and stop directly, rk_step() and not _step_impl() is used.

        Parameters
        ----------
        state
            Current state of the ODE solver.
        dt
            Step-size.

        Returns
        -------
        _odesolver_state.ODESolverState
            New state at time ``state.t + dt``. Its ``error_estimate`` is the
            unnormalized error estimate and its ``reference_state`` is the mean of
            the incoming state.
        """
        t_current = state.t
        y_current = state.rv.mean

        # rk_step() takes the derivative at (t_current, y_current) as its first
        # stage. The cached self.solver.f belongs to the end point of the most
        # recent attempt, so it may only be reused if that end point is exactly the
        # state we are stepping from. Otherwise (e.g. a second attempt from the
        # same state, or a state that was modified after the last attempt) the
        # derivative is re-evaluated.
        cache_matches_state = self.solver.t == t_current and np.array_equal(
            self.solver.y, y_current
        )
        if cache_matches_state:
            f_current = self.solver.f
        else:
            f_current = self.solver.fun(t_current, y_current)

        y_new, f_new = rk.rk_step(
            self.solver.fun,
            t_current,
            y_current,
            f_current,
            dt,
            self.solver.A,
            self.solver.B,
            self.solver.C,
            self.solver.K,
        )

        # Unnormalized error estimation is used as the error estimation is
        # normalized in solve().
        error_estimation = self.solver._estimate_error(self.solver.K, dt)
        y_new_as_rv = randvars.Constant(y_new)
        new_state = _odesolver_state.ODESolverState(
            ivp=state.ivp,
            rv=y_new_as_rv,
            t=t_current + dt,
            error_estimate=error_estimation,
            reference_state=y_current,
        )

        # Update the solver settings. This part is copied from scipy's _step_impl().
        self.solver.h_previous = dt
        self.solver.y_old = y_current
        self.solver.t_old = t_current
        self.solver.t = t_current + dt
        self.solver.y = y_new
        # h_abs is stored as a magnitude; identical to dt for positive steps.
        self.solver.h_abs = np.abs(dt)
        self.solver.f = f_new
        return new_state

    def rvlist_to_odesol(self, times: np.ndarray, rvs: np.ndarray):
        """Create a ScipyODESolution object which is a subclass of
        diffeq.ODESolution.

        Parameters
        ----------
        times
            Time points of the solution; passed to SciPy's ``OdeSolution`` together
            with the interpolants collected so far.
        rvs
            Random variables of the solution; passed on to the wrapped solution.

        Returns
        -------
        _wrapped_scipy_odesolution.WrappedScipyODESolution
            Solution object wrapping the SciPy ``OdeSolution`` and ``rvs``.
        """
        scipy_solution = OdeSolution(times, self.interpolants)
        probnum_solution = _wrapped_scipy_odesolution.WrappedScipyODESolution(
            scipy_solution, rvs
        )
        return probnum_solution

    def method_callback(self, state):
        """Call dense output after each step and store the interpolants.

        ``state`` is not used; the interpolant is built from the quantities stored
        on the wrapped SciPy solver.
        """
        dense = self.dense_output()
        self.interpolants.append(dense)

    def dense_output(self):
        """Compute the interpolant after each step.

        Returns
        -------
        sol : rk.RkDenseOutput
            Interpolant between the last and current location.
        """
        Q = self.solver.K.T.dot(self.solver.P)
        sol = rk.RkDenseOutput(self.solver.t_old, self.solver.t, self.solver.y_old, Q)
        return sol