initialize_odefilter_with_taylormode

probnum.diffeq.initialize_odefilter_with_taylormode(f, y0, t0, prior_process)

Initialize an ODE filter by computing the time derivatives of the solution at t0 with Taylor-mode automatic differentiation.

This requires JAX. For an explanation of what happens under the hood, see [1].
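The quantity being computed is the stack of time derivatives y(t0), y'(t0), ..., y^(q)(t0) of the ODE solution. Write y(t) = c_0 + c_1 (t - t0) + c_2 (t - t0)^2 + ... and match coefficients in y'(t) = f(y(t)). This gives (k + 1) c_{k+1} = [f(y)]_k, where [f(y)]_k is the k-th Taylor coefficient of f evaluated along the truncated series, and y^(k)(t0) = k! c_k. The following is a minimal pure-NumPy sketch of this recursion, not the JAX-based implementation used by this function. It assumes an autonomous vector field written by hand so that it acts on truncated Taylor series: the hypothetical helper `vanderpol_series`, with parameter mu = 10, which reproduces the van der Pol numbers printed in the examples. The names `series_mul` and `taylor_derivatives` are likewise illustrative.

```python
import math

import numpy as np


def series_mul(a, b):
    """Truncated Cauchy product of two Taylor-coefficient arrays of equal length."""
    return np.convolve(a, b)[: len(a)]


def vanderpol_series(y, mu=10.0):
    """Hypothetical van der Pol vector field acting on truncated Taylor series.

    ``y`` is a list [y1, y2] of 1D coefficient arrays; the result has the same form.
    """
    y1, y2 = y
    one = np.zeros_like(y1)
    one[0] = 1.0  # the constant series 1
    dy1 = y2
    dy2 = mu * series_mul(one - series_mul(y1, y1), y2) - y1
    return [dy1, dy2]


def taylor_derivatives(f_series, y0, order):
    """Return an array of shape (order + 1, dim) whose k-th row is d^k y / dt^k at t0."""
    y0 = np.asarray(y0, dtype=float)
    dim = y0.shape[0]
    coeffs = np.zeros((order + 1, dim))  # Taylor coefficients c_0, ..., c_order
    coeffs[0] = y0
    for k in range(order):
        # The k-th coefficient of f(y(t)) depends only on c_0, ..., c_k,
        # so evaluating f on the series truncated to k + 1 terms suffices.
        y_series = [coeffs[: k + 1, i].copy() for i in range(dim)]
        fy = f_series(y_series)
        for i in range(dim):
            coeffs[k + 1, i] = fy[i][k] / (k + 1)  # (k + 1) c_{k+1} = [f(y)]_k
    # Convert Taylor coefficients to derivatives: y^(k)(t0) = k! c_k.
    factorials = np.array([math.factorial(k) for k in range(order + 1)], dtype=float)
    return coeffs * factorials[:, None]


derivs = taylor_derivatives(vanderpol_series, [2.0, 0.0], order=3)
print(derivs)
print(derivs.T.ravel())  # stacked coordinate by coordinate
```

Up to floating-point rounding, `derivs.T.ravel()` equals [2, 0, -2, 60, 0, -2, 60, -1798], the same vector as the mean printed in the van der Pol example. In the actual function, JAX's Taylor-mode machinery takes the place of the hand-written series arithmetic, so f does not have to be rewritten.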

References

[1] Krämer, N. and Hennig, P., Stable implementation of probabilistic ODE solvers, arXiv:2012.10106, 2020.

The implementation is inspired by the one in https://github.com/jacobjinkelly/easy-neural-ode/blob/master/latent_ode.py

Parameters
  • f – ODE vector field.

  • y0 – Initial value.

  • t0 – Initial time point.

  • prior_process – Prior Gauss-Markov process used for the ODE solver. For instance, an integrated Brownian motion (IBM) prior.

Returns

Estimated initial random variable, compatible with the specified prior.

Return type

Normal
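Judging from the printed means in the examples, the returned mean stacks the derivatives coordinate by coordinate. For spatial dimension d and integration order q it reads [y_1, y_1', ..., y_1^(q), y_2, ..., y_d^(q)], and `proj2coord(0)` picks y(t0) out of this state. Assuming this ordering, a projection of that kind can be sketched as a Kronecker product. The function `projection_matrix` below is hypothetical and not part of probnum.

```python
import numpy as np


def projection_matrix(coord, ordint, spatialdim):
    """Hypothetical stand-in for ``proj2coord``: select derivative ``coord`` of each dimension.

    Assumes the state is ordered coordinate by coordinate, with ordint + 1 entries
    (y_i, y_i', ..., y_i^(ordint)) per spatial dimension.
    """
    unit = np.zeros((1, ordint + 1))
    unit[0, coord] = 1.0
    return np.kron(np.eye(spatialdim), unit)


# Mean printed in the van der Pol example (ordint=3, spatialdim=2).
mean = np.array([2.0, 0.0, -2.0, 60.0, 0.0, -2.0, 60.0, -1798.0])
print(projection_matrix(0, ordint=3, spatialdim=2) @ mean)  # y(t0)
print(projection_matrix(1, ordint=3, spatialdim=2) @ mean)  # y'(t0)
```

Applied to the van der Pol mean, the coord=0 projection returns [2, 0], matching the `proj2coord(0)` output in that example; coord=1 returns the first derivative [0, -2].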

Examples

>>> import sys, pytest
>>> if not sys.platform.startswith('linux'):
...     pytest.skip()
>>> import numpy as np
>>> from dataclasses import astuple
>>> from probnum.randvars import Normal
>>> from probnum.problems.zoo.diffeq import threebody_jax, vanderpol_jax
>>> from probnum.statespace import IBM
>>> from probnum.randprocs import MarkovProcess
>>> from probnum.diffeq import initialize_odefilter_with_taylormode

Compute the initial values of the restricted three-body problem as follows:

>>> f, t0, tmax, y0, df, *_ = astuple(threebody_jax())
>>> print(y0)
[ 0.994       0.          0.         -2.00158511]
>>> prior = IBM(ordint=3, spatialdim=4)
>>> initrv = Normal(mean=np.zeros(prior.dimension), cov=np.eye(prior.dimension))
>>> prior_process = MarkovProcess(transition=prior, initrv=initrv, initarg=t0)
>>> improved_initrv = initialize_odefilter_with_taylormode(f, y0, t0, prior_process=prior_process)
>>> print(prior_process.transition.proj2coord(0) @ improved_initrv.mean)
[ 0.994       0.          0.         -2.00158511]
>>> print(improved_initrv.mean)
[ 9.94000000e-01  0.00000000e+00 -3.15543023e+02  0.00000000e+00
  0.00000000e+00 -2.00158511e+00  0.00000000e+00  9.99720945e+04
  0.00000000e+00 -3.15543023e+02  0.00000000e+00  6.39028111e+07
 -2.00158511e+00  0.00000000e+00  9.99720945e+04  0.00000000e+00]

Compute the initial values of the van der Pol oscillator as follows:

>>> f, t0, tmax, y0, df, *_ = astuple(vanderpol_jax())
>>> print(y0)
[2. 0.]
>>> prior = IBM(ordint=3, spatialdim=2)
>>> initrv = Normal(mean=np.zeros(prior.dimension), cov=np.eye(prior.dimension))
>>> prior_process = MarkovProcess(transition=prior, initrv=initrv, initarg=t0)
>>> improved_initrv = initialize_odefilter_with_taylormode(f, y0, t0, prior_process=prior_process)
>>> print(prior_process.transition.proj2coord(0) @ improved_initrv.mean)
[2. 0.]
>>> print(improved_initrv.mean)
[    2.     0.    -2.    60.     0.    -2.    60. -1798.]
>>> print(improved_initrv.std)
[0. 0. 0. 0. 0. 0. 0. 0.]