TaylorMode

class probnum.diffeq.odefilter.init_routines.TaylorMode

Bases: _AutoDiffBase

Initialize a probabilistic ODE solver with Taylor-mode automatic differentiation.

This requires JAX. For an explanation of what happens under the hood, see [1].
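Taylor-mode differentiation pushes truncated Taylor series, rather than single derivatives, through the vector field. For an autonomous ODE y' = f(y), the Taylor coefficients c_k = y^(k)(t0) / k! satisfy (k + 1) c_{k+1} = [f(y)]_k, where [f(y)]_k is the k-th Taylor coefficient of f evaluated along the current truncated series. Each pass therefore yields one new coefficient. The pure-Python sketch below illustrates this recursion with a minimal, hypothetical `Series` class. It is not ProbNum's implementation (which relies on JAX), it handles only autonomous vector fields, and the van der Pol parameter mu = 10 is an assumption chosen because it reproduces the values printed in the examples below.

```python
from math import factorial


class Series:
    """Hypothetical helper: truncated Taylor series c_0 + c_1 h + ... + c_{n-1} h^(n-1)."""

    def __init__(self, coeffs):
        self.c = [float(x) for x in coeffs]

    def _coerce(self, other):
        if isinstance(other, Series):
            return other
        # Promote a scalar constant to a series of the same length
        return Series([other] + [0.0] * (len(self.c) - 1))

    def __add__(self, other):
        o = self._coerce(other)
        return Series([a + b for a, b in zip(self.c, o.c)])

    __radd__ = __add__

    def __neg__(self):
        return Series([-a for a in self.c])

    def __sub__(self, other):
        return self + (-self._coerce(other))

    def __rsub__(self, other):
        return self._coerce(other) - self

    def __mul__(self, other):
        o = self._coerce(other)
        n = len(self.c)
        # Cauchy product, truncated to the first n coefficients
        return Series([sum(self.c[i] * o.c[k - i] for i in range(k + 1)) for k in range(n)])

    __rmul__ = __mul__


def taylor_mode_init(f, y0, num_derivatives):
    """Return [[y_i, y_i', ..., y_i^(num_derivatives)] for each coordinate i] at t0.

    `f` maps a list of Series to a list of Series (autonomous ODE y' = f(y)).
    """
    coeffs = [[float(v)] for v in y0]  # coeffs[i][k] = y_i^(k)(t0) / k!
    for k in range(num_derivatives):
        fy = f([Series(c) for c in coeffs])
        for i, fi in enumerate(fy):
            # y' = f(y) implies (k + 1) * c_{k+1} = k-th coefficient of f(y)
            coeffs[i].append(fi.c[k] / (k + 1))
    # Convert Taylor coefficients back to derivatives: y^(k) = k! * c_k
    return [[factorial(k) * c for k, c in enumerate(ci)] for ci in coeffs]


def vanderpol(y, mu=10.0):
    """Van der Pol vector field; mu = 10 is an assumption (see lead-in)."""
    y1, y2 = y
    return [y2, mu * (1.0 - y1 * y1) * y2 - y1]


derivs = taylor_mode_init(vanderpol, [2.0, 0.0], num_derivatives=3)
print(derivs)
```

Up to floating-point rounding, this returns the rows [2, 0, -2, 60] and [0, -2, 60, -1798], which match the van der Pol mean printed in the examples. Recomputing the whole series on every pass keeps the sketch short; it is quadratic in the number of derivatives.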

References

[1] Krämer, N. and Hennig, P. Stable implementation of probabilistic ODE solvers. arXiv:2012.10106, 2020.

Examples

>>> import sys, pytest
>>> if not sys.platform.startswith('linux'):
...     pytest.skip()
>>> import numpy as np
>>> from probnum.randvars import Normal
>>> from probnum.problems.zoo.diffeq import threebody_jax, vanderpol_jax
>>> from probnum.randprocs.markov.integrator import IntegratedWienerProcess

Compute the initial values of the restricted three-body problem as follows. First, set up the IVP.

>>> ivp = threebody_jax()
>>> print(ivp.y0)
[ 0.994       0.          0.         -2.00158511]

Construct the prior process.

>>> prior_process = IntegratedWienerProcess(
...     initarg=ivp.t0, wiener_process_dimension=4, num_derivatives=3
... )
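With wiener_process_dimension=4 and num_derivatives=3, the prior state has 4 × (3 + 1) = 16 entries. The means printed below are consistent with a coordinate-major layout, in which all derivatives of the first ODE coordinate come first, then those of the second, and so on. Under that assumption, extracting the derivatives of a given order is a Kronecker-structured projection. The helper `derivative_projection` below is hypothetical; it only mimics what `proj2coord` appears to do in these examples and is not ProbNum's code.

```python
import numpy as np


def derivative_projection(dimension, num_derivatives, order):
    """Hypothetical helper: matrix selecting the `order`-th derivative of every
    coordinate from a coordinate-major state [y1, y1', ..., y2, y2', ...]."""
    selector = np.zeros((1, num_derivatives + 1))
    selector[0, order] = 1.0  # picks one entry within each coordinate block
    return np.kron(np.eye(dimension), selector)


# Three-body setting: 4 coordinates, 3 derivatives -> 16-dimensional state
P0 = derivative_projection(dimension=4, num_derivatives=3, order=0)
print(P0.shape)  # (4, 16)
```

Applied to the van der Pol mean printed further below, derivative_projection(2, 3, 0) recovers [2., 0.] and derivative_projection(2, 3, 1) recovers [0., -2.].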

Initialize with Taylor-mode autodiff.

>>> taylor_init = TaylorMode()
>>> improved_initrv = taylor_init(ivp=ivp, prior_process=prior_process)

Print the results.

>>> print(prior_process.transition.proj2coord(0) @ improved_initrv.mean)
[ 0.994       0.          0.         -2.00158511]
>>> print(improved_initrv.mean)
[ 9.94000000e-01  0.00000000e+00 -3.15543023e+02  0.00000000e+00
  0.00000000e+00 -2.00158511e+00  0.00000000e+00  9.99720945e+04
  0.00000000e+00 -3.15543023e+02  0.00000000e+00  6.39028111e+07
 -2.00158511e+00  0.00000000e+00  9.99720945e+04  0.00000000e+00]
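The flat mean is easier to read after reshaping it into one row per ODE coordinate and one column per derivative order (again assuming the coordinate-major layout). The numbers below are copied from the output above. In that form, the higher derivatives of the first two coordinates are shifted copies of the lower derivatives of the last two, which is consistent with the first two components of the three-body vector field being the velocities y3 and y4. This check is an illustration, not part of the ProbNum API.

```python
import numpy as np

# Mean vector copied from the doctest output above (restricted three-body problem)
mean = np.array([
    9.94000000e-01, 0.00000000e+00, -3.15543023e+02, 0.00000000e+00,
    0.00000000e+00, -2.00158511e+00, 0.00000000e+00, 9.99720945e+04,
    0.00000000e+00, -3.15543023e+02, 0.00000000e+00, 6.39028111e+07,
    -2.00158511e+00, 0.00000000e+00, 9.99720945e+04, 0.00000000e+00,
])
d, nu = 4, 3
derivs = mean.reshape(d, nu + 1)  # row i holds [y_i, y_i', y_i'', y_i''']
print(derivs[:, 0])  # column 0 is the initial value y0
# y1' = y3 and y2' = y4, so each derivative of y1, y2 equals the next-lower one of y3, y4
print(np.allclose(derivs[:2, 1:], derivs[2:, :-1]))  # True
```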

Compute the initial values of the van der Pol oscillator in the same way. First, set up the IVP and the prior process.

>>> ivp = vanderpol_jax()
>>> print(ivp.y0)
[2. 0.]
>>> prior_process = IntegratedWienerProcess(
...     initarg=ivp.t0, wiener_process_dimension=2, num_derivatives=3
... )
>>> taylor_init = TaylorMode()
>>> improved_initrv = taylor_init(ivp=ivp, prior_process=prior_process)

Print the results.

>>> print(prior_process.transition.proj2coord(0) @ improved_initrv.mean)
[2. 0.]
>>> print(improved_initrv.mean)
[    2.     0.    -2.    60.     0.    -2.    60. -1798.]
>>> print(improved_initrv.std)
[0. 0. 0. 0. 0. 0. 0. 0.]
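The van der Pol mean can be checked by hand. Assuming the standard first-order form y1' = y2, y2' = mu (1 - y1^2) y2 - y1 with mu = 10 (an assumption: the parameter is not stated above, but this value reproduces the printed mean), repeated application of the chain rule gives:

```latex
\begin{aligned}
y_1' &= y_2, & y_2' &= \mu (1 - y_1^2)\, y_2 - y_1, \\
y_1'' &= y_2', & y_2'' &= \mu \left[ (1 - y_1^2)\, y_2' - 2 y_1 y_1' y_2 \right] - y_1', \\
y_1''' &= y_2'', & y_2''' &= \mu \left[ (1 - y_1^2)\, y_2'' - 4 y_1 y_1' y_2' - 2 \left( y_1'^2 + y_1 y_1'' \right) y_2 \right] - y_1''.
\end{aligned}
\\[4pt]
\text{At } (y_1, y_2) = (2, 0),\ \mu = 10:\quad
y_1' = 0,\ y_2' = -2,\ y_1'' = -2,\ y_2'' = 10 \cdot (-3)(-2) = 60,\ y_1''' = 60,\ y_2''' = 10 \cdot (-3)(60) + 2 = -1798.
```

Ordered as [y1, y1', y1'', y1''', y2, y2', y2'', y2'''], these are exactly the entries of improved_initrv.mean.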

Attributes Summary

is_exact        Exactness of the computed initial values.
requires_jax    Whether the implementation of the routine relies on JAX.

Methods Summary

__call__(*, ivp, prior_process)    Compute the initial random variable for the given IVP and prior process.

Attributes Documentation

is_exact

Exactness of the computed initial values.

Some initialization routines yield the exact initial derivatives; others only yield approximations.

requires_jax

Whether the implementation of the routine relies on JAX.
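The two attributes can drive routine selection: a caller might skip routines with requires_jax set when JAX is not installed, and otherwise prefer routines with is_exact set. The helper `choose_routine` below is hypothetical and not part of ProbNum; it only reads the two documented attributes and uses the standard library's importlib to check whether JAX is importable.

```python
import importlib.util


def jax_available():
    """True if the `jax` package can be imported in this environment."""
    return importlib.util.find_spec("jax") is not None


def choose_routine(candidates, have_jax=None):
    """Hypothetical helper: return the first usable routine, preferring exact ones.

    `candidates` are objects exposing the boolean attributes
    `requires_jax` and `is_exact`.
    """
    if have_jax is None:
        have_jax = jax_available()
    # Drop routines that need JAX when JAX is missing
    usable = [r for r in candidates if have_jax or not r.requires_jax]
    if not usable:
        raise RuntimeError("no initialization routine is usable without JAX")
    # sorted() is stable: exact routines come first, ties keep the input order
    return sorted(usable, key=lambda r: not r.is_exact)[0]
```

Passing have_jax explicitly makes the choice deterministic, for example in tests; leaving it as None probes the environment.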

Methods Documentation

__call__(*, ivp, prior_process)

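Compute the initial random variable for the given IVP and prior process. As in the examples above, its mean stacks the initial value and its first num_derivatives derivatives for each ODE coordinate. Both arguments are keyword-only.

Parameters

ivp: Initial value problem, for example threebody_jax() or vanderpol_jax().
prior_process: Prior process, for example an IntegratedWienerProcess; its num_derivatives determines how many derivatives are computed.

Return type

RandomVariable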