
class probnum.diffeq.odefilter.initialization_routines.TaylorModeInitialization

Bases: probnum.diffeq.odefilter.initialization_routines.InitializationRoutine

Initialize a probabilistic ODE solver with Taylor-mode automatic differentiation.

This routine requires JAX. For an explanation of what happens under the hood, see [1].

The implementation is inspired by the one at https://github.com/jacobjinkelly/easy-neural-ode/blob/master/latent_ode.py (see also [2]).
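The idea can be summarized by the recursion below. It is a sketch for an autonomous ODE y' = f(y), y(t_0) = y_0; a time-dependent vector field can be handled by appending t to the state. The notation c_k, [f(y)]_k and q is introduced here for illustration and is not taken from [1]. The key step is that the k-th Taylor coefficient [f(y)]_k of t ↦ f(y(t)) depends only on c_0, ..., c_k, so each pass yields exactly one new coefficient.

```latex
y(t_0 + h) \approx \sum_{k=0}^{q} c_k h^k, \qquad c_k = \frac{y^{(k)}(t_0)}{k!}, \qquad c_0 = y_0,

\sum_{k \ge 0} (k+1)\, c_{k+1} h^k = y'(t_0 + h) = f\big(y(t_0 + h)\big) = \sum_{k \ge 0} [f(y)]_k\, h^k
\quad\Longrightarrow\quad
c_{k+1} = \frac{[f(y)]_k}{k+1}, \qquad k = 0, \dots, q - 1.
```

In particular, y'(t_0) = f(y_0) and y''(t_0) = J_f(y_0) f(y_0), where J_f is the Jacobian of f. Taylor-mode automatic differentiation computes each [f(y)]_k by propagating truncated Taylor polynomials through f, without forming Jacobians or higher-order derivative tensors explicitly.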

References

[1] Krämer, N. and Hennig, P., Stable implementation of probabilistic ODE solvers, arXiv:2012.10106, 2020.

[2] Kelly, J., Bettencourt, J., Johnson, M. and Duvenaud, D., Learning differential equations that are easy to solve, NeurIPS 2020.


>>> import sys, pytest
>>> if not sys.platform.startswith('linux'):
...     pytest.skip()
>>> import numpy as np
>>> from probnum.randvars import Normal
>>> from probnum.problems.zoo.diffeq import threebody_jax, vanderpol_jax
>>> from probnum.randprocs.markov.integrator import IntegratedWienerProcess

To compute the initial values of the restricted three-body problem, first set up the IVP:

>>> ivp = threebody_jax()
>>> print(ivp.y0)
[ 0.994       0.          0.         -2.00158511]

Construct the prior process.

>>> prior_process = IntegratedWienerProcess(initarg=ivp.t0, wiener_process_dimension=4, num_derivatives=3)

Initialize with Taylor-mode autodiff.

>>> taylor_init = TaylorModeInitialization()
>>> improved_initrv = taylor_init(ivp=ivp, prior_process=prior_process)

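Print the results. Projecting the mean onto the zeroth derivative with proj2coord(0) recovers the initial value y0; the full mean additionally contains the first three derivatives of each of the four coordinates.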

>>> print(prior_process.transition.proj2coord(0) @ improved_initrv.mean)
[ 0.994       0.          0.         -2.00158511]
>>> print(improved_initrv.mean)
[ 9.94000000e-01  0.00000000e+00 -3.15543023e+02  0.00000000e+00
  0.00000000e+00 -2.00158511e+00  0.00000000e+00  9.99720945e+04
  0.00000000e+00 -3.15543023e+02  0.00000000e+00  6.39028111e+07
 -2.00158511e+00  0.00000000e+00  9.99720945e+04  0.00000000e+00]
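The mean is a flat vector of length wiener_process_dimension * (num_derivatives + 1) = 4 * 4 = 16. The printed values suggest a coordinate-major layout: for each coordinate, its value followed by its first three derivatives. Assuming that layout, the sketch below builds a selection matrix that should act like proj2coord(0) on this vector; projection_to_derivative is a hypothetical helper, not probnum's implementation. Under the additional assumption that state components 0-1 are positions and 2-3 the corresponding velocities, it also checks that derivative k+1 of each position equals derivative k of its velocity.

```python
import numpy as np

# Mean printed in the three-body example above (values as printed).
mean = np.array([
    9.94000000e-01, 0.0, -3.15543023e+02, 0.0,
    0.0, -2.00158511e+00, 0.0, 9.99720945e+04,
    0.0, -3.15543023e+02, 0.0, 6.39028111e+07,
    -2.00158511e+00, 0.0, 9.99720945e+04, 0.0,
])
d, q = 4, 3  # wiener_process_dimension and num_derivatives from the example

# Assumed coordinate-major layout: row i holds coordinate i and its derivatives 0..q.
table = mean.reshape(d, q + 1)


def projection_to_derivative(k, d=d, q=q):
    """Hypothetical helper: (d, d*(q+1)) matrix selecting derivative k of every coordinate."""
    e_k = np.zeros((1, q + 1))
    e_k[0, k] = 1.0
    return np.kron(np.eye(d), e_k)


P0 = projection_to_derivative(0)
print(P0 @ mean)    # the initial value y0
print(table[:, 0])  # the same values, read off the reshaped table

# Assumption: components 0-1 are positions and 2-3 their velocities, so
# derivative k+1 of each position should equal derivative k of its velocity.
print(np.allclose(table[:2, 1:], table[2:, :-1]))
```

The Kronecker product with an identity matrix is a compact way to apply the same per-coordinate selection to every coordinate block.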

Compute the initial values of the van der Pol oscillator in the same way. First, set up the IVP and the prior process.

>>> ivp = vanderpol_jax()
>>> print(ivp.y0)
[2. 0.]
>>> prior_process = IntegratedWienerProcess(initarg=ivp.t0, wiener_process_dimension=2, num_derivatives=3)
>>> taylor_init = TaylorModeInitialization()
>>> improved_initrv = taylor_init(ivp=ivp, prior_process=prior_process)

Print the results.

>>> print(prior_process.transition.proj2coord(0) @ improved_initrv.mean)
[2. 0.]
>>> print(improved_initrv.mean)
[    2.     0.    -2.    60.     0.    -2.    60. -1798.]
>>> print(improved_initrv.std)
[0. 0. 0. 0. 0. 0. 0. 0.]
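To see where these numbers come from, here is a minimal pure-Python sketch of the idea behind Taylor-mode initialization: propagate truncated Taylor polynomials through the vector field and obtain one new Taylor coefficient per pass. It is not probnum's implementation, which uses JAX. It assumes the autonomous van der Pol vector field y1' = y2, y2' = mu (1 - y1^2) y2 - y1 with mu = 10, a parameter value chosen here because it reproduces the printed mean; Taylor, taylor_derivatives and vanderpol are hypothetical names. Exact rational arithmetic via fractions.Fraction avoids rounding, which matches the zero standard deviations printed above for an exact computation.

```python
from fractions import Fraction
from math import factorial


class Taylor:
    """Truncated Taylor polynomial: coeffs[k] is the coefficient of h**k."""

    def __init__(self, coeffs):
        self.coeffs = list(coeffs)

    def _lift(self, other):
        # Promote a scalar to a constant polynomial of the same length.
        if isinstance(other, Taylor):
            return other
        return Taylor([other] + [0] * (len(self.coeffs) - 1))

    def __add__(self, other):
        other = self._lift(other)
        return Taylor(a + b for a, b in zip(self.coeffs, other.coeffs))

    __radd__ = __add__

    def __sub__(self, other):
        other = self._lift(other)
        return Taylor(a - b for a, b in zip(self.coeffs, other.coeffs))

    def __rsub__(self, other):
        return self._lift(other) - self

    def __mul__(self, other):
        other = self._lift(other)
        # Cauchy product, truncated to the same number of coefficients.
        return Taylor(
            sum(self.coeffs[i] * other.coeffs[k - i] for i in range(k + 1))
            for k in range(len(self.coeffs))
        )

    __rmul__ = __mul__


def taylor_derivatives(f, y0, num_derivatives):
    """Return [[y_i(t0), y_i'(t0), ...] for each component i] of y' = f(y)."""
    coeffs = [[y] + [0] * num_derivatives for y in y0]
    for k in range(num_derivatives):
        # Coefficient k of f(y(t)) only needs coefficients 0..k of y, already known.
        fy = f([Taylor(c) for c in coeffs])
        for i, fi in enumerate(fy):
            coeffs[i][k + 1] = fi.coeffs[k] / (k + 1)
    # Convert Taylor coefficients to derivatives: y^(k)(t0) = k! * c_k.
    return [[factorial(k) * c[k] for k in range(num_derivatives + 1)] for c in coeffs]


def vanderpol(y, mu=10):
    """Van der Pol vector field; mu = 10 is an assumption, not read from probnum."""
    y1, y2 = y
    return [y2, mu * (1 - y1 * y1) * y2 - y1]


derivs = taylor_derivatives(vanderpol, [Fraction(2), Fraction(0)], num_derivatives=3)
print([int(v) for v in derivs[0]] + [int(v) for v in derivs[1]])
# [2, 0, -2, 60, 0, -2, 60, -1798]
```

Concatenating the derivatives of y1 and y2 gives the same ordering as improved_initrv.mean above.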

Attributes Summary


Exactness of the computed initial values.


Whether the implementation of the routine relies on JAX.

Methods Summary

__call__(ivp, prior_process)

Call self as a function.

Attributes Documentation


Exactness of the computed initial values.

Some initialization routines yield the exact initial derivatives; others only yield approximations.

Return type


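To illustrate the difference between exact and approximate initial derivatives, here is a hedged sketch for the scalar ODE y' = y^2, y(0) = 1 (an example chosen here, not taken from probnum), whose solution 1/(1 - t) has y''(0) = 2. A hypothetical approximate routine that estimates y''(0) by a difference quotient of f along one explicit Euler step incurs an error of exactly h for this problem, whereas the chain rule gives the exact value.

```python
def f(y):
    """Vector field of the scalar ODE y' = y**2 (example chosen for illustration)."""
    return y * y


def second_derivative_exact(y0):
    # Chain rule: y'' = f'(y) * y' = 2y * y**2 = 2 * y**3.
    return 2.0 * y0**3


def second_derivative_forward_difference(y0, h):
    """Hypothetical approximate routine: difference quotient of f along one Euler step."""
    return (f(y0 + h * f(y0)) - f(y0)) / h


y0 = 1.0
for h in (1e-1, 1e-2, 1e-3):
    error = abs(second_derivative_forward_difference(y0, h) - second_derivative_exact(y0))
    print(f"h = {h:g}: error = {error:.2e}")  # shrinks proportionally to h
```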

Whether the implementation of the routine relies on JAX.

Return type


Methods Documentation

__call__(ivp, prior_process)

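Call self as a function: compute the initial random variable for ivp under prior_process. Its mean stacks, for each coordinate, the initial value and its derivatives (see the examples above).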

Return type
