Source code for probnum.diffeq.odesolver

"""Abstract ODESolver class.

Interface for Runge-Kutta, ODEFilter.
"""

import warnings
from abc import ABC, abstractmethod

from probnum.diffeq.odesolution import ODESolution


class ODESolver(ABC):
    """Interface for ODESolver."""

    def __init__(self, ivp):
        self.ivp = ivp
        self.num_steps = 0

[docs] def solve(self, firststep, steprule, **kwargs): """Solve an IVP. Parameters ---------- firststep : float First step for adaptive step-size rule. steprule : :class:`StepRule` Step-size selection rule, e.g. constant steps or adaptive steps. """ t, current_rv = self.initialise() times, rvs = [t], [current_rv] stepsize = firststep while t < self.ivp.tmax: t_new = t + stepsize proposed_rv, errorest = self.step(t, t_new, current_rv, **kwargs) if steprule.is_accepted(stepsize, errorest): self.num_steps += 1 self.method_callback( time=t_new, current_guess=proposed_rv, current_error=errorest ) t = t_new current_rv = proposed_rv times.append(t) rvs.append(current_rv) suggested_stepsize = self._suggest_step(stepsize, errorest, steprule) stepsize = min(suggested_stepsize, self.ivp.tmax - t) odesol = self.postprocess(times=times, rvs=rvs) return odesol
def _suggest_step(self, step, errorest, steprule): """Suggests step according to steprule and warns if step is extremely small. Raises ------ RuntimeWarning If suggested step is smaller than :math:`10^{-15}`. """ step = steprule.suggest(step, errorest) if step < 1e-15: warnmsg = "Stepsize is num. zero (%.1e)" % step warnings.warn(message=warnmsg, category=RuntimeWarning) return step
[docs] @abstractmethod def initialise(self): """Returns t0 and y0 (for the solver, which might be different to ivp.y0)""" raise NotImplementedError
[docs] @abstractmethod def step(self, start, stop, current, **kwargs): """Every ODE solver needs a step() method that returns a new random variable and an error estimate.""" raise NotImplementedError
[docs] def postprocess(self, times, rvs): """Turn list of random variables into an ODE solution object and potentially do more. Overwrite for instance via. >>> def postprocess(self, times, rvs): >>> # do something with times and rvs >>> odesol = super().postprocess(times, rvs) >>> # do something with odesol >>> return odesol """ return ODESolution(times, rvs, self)
[docs] def method_callback(self, time, current_guess, current_error): """Optional callback. Can be overwritten. Do this as soon as it is clear that the current guess is accepted, but before storing it. No return. For example: tune hyperparameters (sigma). """ pass