Event handling and callbacks in ODE solvers

The differential equation solvers in ProbNum are able to handle events. At the moment, an event can either be a set of grid points that must be included in the posterior, or a state for which a condition function

\[ \text{condition}: \mathbb{R}^d \rightarrow \{0, 1\} \]

evaluates to True. This notebook explains how to use events with ProbNum (some examples are taken from https://diffeq.sciml.ai/stable/features/callback_functions/).

Quickstart

What is the easiest way to force events into your ODE solution? Let us define a simple, linear ODE that describes exponential decay.
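For reference, here are the problem and its exact solution, using the values from the code below (t0 = 0, tmax = 5, y0 = 4):

\[ \dot y(t) = f(t, y(t)) = -y(t), \quad y(0) = 4, \quad t \in [0, 5], \qquad y(t) = 4 e^{-t}. \]

The Jacobian is \(\partial f / \partial y = -1\). The function df returns it as a 1-by-1 matrix, and the EK1 linearization of the ODE filter uses it later.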

# Make inline plots vector graphics instead of raster graphics
%matplotlib inline
from matplotlib_inline.backend_inline import set_matplotlib_formats

set_matplotlib_formats("pdf", "svg")

# Plotting
import matplotlib.pyplot as plt

plt.style.use("../../probnum.mplstyle")
from probnum import diffeq, randvars, randprocs, problems
import numpy as np

# For easy modification of the states in the callbacks
import dataclasses
def f(t, y):
    # Right-hand side of the linear decay ODE y' = -y
    return -y


def df(t, y):
    # Jacobian of f with respect to y (used by the EK1 linearization)
    return -1.0 * np.eye(len(y))


t0 = 0.0
tmax = 5.0
y0 = np.array([4.0])

To show off the ability to include a prescribed set of grid points, let us define a dense grid of 50 points on a subset of the integration domain, between 3.5 and 4.0.

time_stops = np.linspace(3.5, 4.0, 50)

To force the ODE solver to include these time stamps, pass them to probsolve_ivp via the time_stops argument. We pick a large relative tolerance (rtol=0.8) so that the posterior samples are visibly spread out. The ODE is so simple that the solver would otherwise solve it very accurately, even on large steps.

probsol = diffeq.probsolve_ivp(
    f=f,
    t0=t0,
    tmax=tmax,
    y0=y0,
    time_stops=time_stops,
    rtol=0.8,
)

# Draw 10 samples from the posterior and plot.
rng = np.random.default_rng(seed=2)
samples = probsol.sample(size=10, rng=rng)
for sample in samples:
    plt.plot(probsol.locations, sample, "o-", color="C0")
plt.show()
Figure: ten samples from the posterior of probsolve_ivp.

Observe the dense cluster of grid points between 3.5 and 4.0. These are our events!
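How can a solver land on every time stop exactly? Below is a minimal sketch of one common approach, assuming a fixed proposed step h. Whenever a step would jump past the next stop, it is shortened so that it ends exactly on that stop. The helpers next_time and build_grid are hypothetical and do not reproduce ProbNum's internals. An adaptive solver would use its own step proposal in place of the fixed h.

```python
import bisect


def next_time(t, h, stops):
    """Return t + h, or the next stop if the step would reach or pass it.

    Hypothetical helper for illustration; `stops` must be sorted.
    """
    # Index of the first stop strictly after the current time t.
    i = bisect.bisect_right(stops, t)
    if i < len(stops) and t + h >= stops[i]:
        # Return the stop itself; computing t + (stop - t) could add round-off.
        return stops[i]
    return t + h


def build_grid(t0, tmax, h, time_stops):
    """Fixed-step grid from t0 to tmax that contains every time stop."""
    stops = sorted(set(time_stops) | {tmax})
    grid = [t0]
    t = t0
    while t < tmax:
        t = next_time(t, h, stops)
        grid.append(t)
    return grid


grid = build_grid(t0=0.0, tmax=5.0, h=1.0, time_stops=[3.5, 3.75, 4.0])
print(grid)  # prints [0.0, 1.0, 2.0, 3.0, 3.5, 3.75, 4.0, 5.0]
```

Returning the stop value directly, instead of adding a shortened step, guarantees that the grid contains the requested time stamps bit for bit.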

The same works for, e.g., perturbsolve_ivp. Let us compute 10 perturbed solutions so that the plot resembles the posterior samples of the probabilistic solver.

# No fixed seed: every call yields a different perturbed solution
rng = np.random.default_rng()


time_stops = np.linspace(3.5, 4.0, 100)
perturbsols = [
    diffeq.perturbsolve_ivp(
        f=f,
        t0=t0,
        tmax=tmax,
        y0=y0,
        rng=rng,
        noise_scale=0.05,
        time_stops=time_stops,
    )
    for _ in range(10)
]

for perturbsol in perturbsols:
    plt.plot(perturbsol.locations, perturbsol.states.mean, "o-", color="C1")
plt.show()
Figure: ten perturbed solutions from perturbsolve_ivp.

Again, observe the many grid points between 3.5 and 4.0.

Discrete callback events

It is also possible to modify the solver state whenever an event happens. The top-level interface functions (e.g., probsolve_ivp) do not support this. Instead, we have to build an ODE solver from scratch (see the respective notebook for an explanation).

# Construct IVP, prior, linearization, diffusion, and initialization
ivp = problems.InitialValueProblem(t0=t0, tmax=tmax, y0=y0, f=f, df=df)
prior_process = randprocs.markov.integrator.IntegratedWienerProcess(
    initarg=ivp.t0,
    num_derivatives=1,
    wiener_process_dimension=ivp.dimension,
    forward_implementation="sqrt",
    backward_implementation="sqrt",
)
diffmodel = randprocs.markov.continuous.PiecewiseConstantDiffusion(t0=t0)
rk_init = diffeq.odefilter.initialization_routines.RungeKuttaInitialization()
ode_residual = diffeq.odefilter.information_operators.ODEResidual(1, ivp.dimension)
ek1 = diffeq.odefilter.approx_strategies.EK1()
firststep = diffeq.stepsize.propose_firststep(ivp)
steprule = diffeq.stepsize.AdaptiveSteps(firststep=firststep, atol=1e-1, rtol=1e-1)

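# Assemble the ODE filter from the components constructed above
solver = diffeq.odefilter.ODEFilter(
    steprule=steprule,
    prior_process=prior_process,
    information_operator=ode_residual,
    approx_strategy=ek1,
    with_smoothing=False,
    initialization_routine=rk_init,
    diffusion_model=diffmodel,
)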

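To describe a discrete event, we define a condition function that checks whether the current time point is either 2.0 or 4.0. At both locations, we reset the solution to \(y = 6\). We also reset its derivative to \(\dot y = -6\), which is consistent with \(\dot y = -y\). Careful! The state of a filter-based solver consists of \([y, \dot y, \ddot y, \dots]\); with num_derivatives=1, it is \([y, \dot y]\). The replacement also sets the covariance to zero, so the reset state is known exactly.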

Let us construct both functions and pass them to a DiscreteCallback. The solver is unlikely to step exactly onto 2.0 or 4.0, so we force these locations into the ODE solver posterior via the stop_at argument of solve.

def condition(state: diffeq.ODESolverState) -> bool:
    return state.t in [2.0, 4.0]


def replace(state: diffeq.ODESolverState) -> diffeq.ODESolverState:
    """Replace an ODE solver state whenever a condition is True."""
    new_mean = np.array([6.0, -6])
    new_rv = randvars.Normal(
        new_mean, cov=0 * state.rv.cov, cov_cholesky=0 * state.rv.cov_cholesky
    )
    return dataclasses.replace(state, rv=new_rv)


callback = diffeq.callbacks.DiscreteCallback(condition=condition, replace=replace)
odesol = solver.solve(ivp=ivp, stop_at=[2.0, 4.0], callbacks=callback)
plt.plot(odesol.locations, odesol.states.mean, "o-")
plt.show()
Figure: posterior mean of the ODE filter solution with resets at t = 2.0 and t = 4.0.
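The discrete callback above follows a condition/replace pattern. Below is a minimal, library-free sketch of that pattern, assuming a plain explicit Euler method in place of the ODE filter. The names EulerState and euler_with_callback are hypothetical, and this is not how ProbNum implements callbacks internally. After each step, the solver lands exactly on the next stop_at time if it would otherwise pass it. It then checks the condition and, if the condition holds, swaps the state for the replaced one.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class EulerState:
    """Hypothetical minimal solver state: current time t and value y."""

    t: float
    y: float


def condition(state: EulerState) -> bool:
    # Same condition as in the notebook: fire exactly at t = 2.0 and t = 4.0.
    return state.t in (2.0, 4.0)


def replace(state: EulerState) -> EulerState:
    # Reset the solution value; a plain Euler state has no derivative entry.
    return dataclasses.replace(state, y=6.0)


def euler_with_callback(f, t0, tmax, y0, h, stop_at, condition, replace):
    """Explicit Euler that lands on every stop_at time and applies the callback."""
    stops = sorted(set(stop_at) | {tmax})
    state = EulerState(t=t0, y=y0)
    states = [state]
    while state.t < tmax:
        # Never step past the next forced stop; land on it exactly instead.
        next_stop = min(s for s in stops if s > state.t)
        t_new = min(state.t + h, next_stop)
        y_new = state.y + (t_new - state.t) * f(state.t, state.y)
        state = EulerState(t=t_new, y=y_new)
        # Discrete callback: check the condition after each step.
        if condition(state):
            state = replace(state)
        states.append(state)
    return states


states = euler_with_callback(
    f=lambda t, y: -y,
    t0=0.0,
    tmax=5.0,
    y0=4.0,
    h=0.3,
    stop_at=[2.0, 4.0],
    condition=condition,
    replace=replace,
)
```

With h = 0.3, the Euler steps alone would never reach 2.0 or 4.0 exactly, so the exact-equality condition only fires because of the forced stops. The notebook's replace also resets \(\dot y\) and zeroes the covariance. This sketch resets only y, because a plain Euler state carries no derivative or uncertainty.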