initialize_odefilter_with_taylormode
- probnum.diffeq.initialize_odefilter_with_taylormode(f, y0, t0, prior_process)
Initialize an ODE filter with Taylor-mode automatic differentiation.
This requires JAX. For an explanation of what happens under the hood, see [1].
References
- [1] Krämer, N. and Hennig, P., Stable implementation of probabilistic ODE solvers, arXiv:2012.10106, 2020.
The implementation is inspired by https://github.com/jacobjinkelly/easy-neural-ode/blob/master/latent_ode.py
- Parameters
f – ODE vector field.
y0 – Initial value.
t0 – Initial time point.
prior_process – Prior Gauss-Markov process used for the ODE solver, for instance an integrated Brownian motion prior (IBM).
- Returns
Estimated initial random variable. Compatible with the specified prior.
- Return type
Examples
>>> import sys, pytest
>>> if not sys.platform.startswith('linux'):
...     pytest.skip()
>>> from dataclasses import astuple
>>> import numpy as np
>>> from probnum.diffeq import initialize_odefilter_with_taylormode
>>> from probnum.randvars import Normal
>>> from probnum.problems.zoo.diffeq import threebody_jax, vanderpol_jax
>>> from probnum.statespace import IBM
>>> from probnum.randprocs import MarkovProcess
Compute the initial values of the restricted three-body problem as follows:
>>> f, t0, tmax, y0, df, *_ = astuple(threebody_jax())
>>> print(y0)
[ 0.994 0. 0. -2.00158511]
>>> prior = IBM(ordint=3, spatialdim=4)
>>> initrv = Normal(mean=np.zeros(prior.dimension), cov=np.eye(prior.dimension))
>>> prior_process = MarkovProcess(transition=prior, initrv=initrv, initarg=t0)
>>> improved_initrv = initialize_odefilter_with_taylormode(f, y0, t0, prior_process=prior_process)
>>> print(prior_process.transition.proj2coord(0) @ improved_initrv.mean)
[ 0.994 0. 0. -2.00158511]
>>> print(improved_initrv.mean)
[ 9.94000000e-01 0.00000000e+00 -3.15543023e+02 0.00000000e+00
  0.00000000e+00 -2.00158511e+00 0.00000000e+00 9.99720945e+04
  0.00000000e+00 -3.15543023e+02 0.00000000e+00 6.39028111e+07
 -2.00158511e+00 0.00000000e+00 9.99720945e+04 0.00000000e+00]
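The printed mean above suggests that the state stacks, for each of the four ODE dimensions in turn, the value followed by its first three derivatives. The following is a minimal NumPy sketch of a projection that picks one derivative order out of such a vector. The layout is an assumption inferred from the printed numbers, and `projection_to_derivative` is a hypothetical helper for illustration, not part of ProbNum.

```python
import numpy as np

# Mean vector copied from the three-body example output above.
mean = np.array([
    9.94000000e-01, 0.0, -3.15543023e+02, 0.0,
    0.0, -2.00158511e+00, 0.0, 9.99720945e+04,
    0.0, -3.15543023e+02, 0.0, 6.39028111e+07,
    -2.00158511e+00, 0.0, 9.99720945e+04, 0.0,
])


def projection_to_derivative(order, num_derivatives, spatialdim):
    """Hypothetical helper: matrix selecting the `order`-th derivative.

    Assumes the state is laid out as
    [y_1, y_1', ..., y_1^(q), y_2, y_2', ..., y_2^(q), ...].
    """
    selector = np.zeros((1, num_derivatives + 1))
    selector[0, order] = 1.0
    # One copy of the selector row per ODE dimension.
    return np.kron(np.eye(spatialdim), selector)


proj0 = projection_to_derivative(order=0, num_derivatives=3, spatialdim=4)
proj1 = projection_to_derivative(order=1, num_derivatives=3, spatialdim=4)

y0_recovered = proj0 @ mean   # zeroth derivative: the initial value
dy0_recovered = proj1 @ mean  # first derivative at the initial time
print(y0_recovered)
print(dy0_recovered)
```

Under this assumed layout, the zeroth-order projection reproduces the initial value `y0` printed in the example.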
Compute the initial values of the van der Pol oscillator as follows:
>>> f, t0, tmax, y0, df, *_ = astuple(vanderpol_jax())
>>> print(y0)
[2. 0.]
>>> prior = IBM(ordint=3, spatialdim=2)
>>> initrv = Normal(mean=np.zeros(prior.dimension), cov=np.eye(prior.dimension))
>>> prior_process = MarkovProcess(transition=prior, initrv=initrv, initarg=t0)
>>> improved_initrv = initialize_odefilter_with_taylormode(f, y0, t0, prior_process=prior_process)
>>> print(prior_process.transition.proj2coord(0) @ improved_initrv.mean)
[2. 0.]
>>> print(improved_initrv.mean)
[ 2. 0. -2. 60. 0. -2. 60. -1798.]
>>> print(improved_initrv.std)
[0. 0. 0. 0. 0. 0. 0. 0.]
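To illustrate the idea behind Taylor-mode initialization without JAX, here is a dependency-free sketch that propagates truncated power series through a van der Pol vector field by hand. It is not the ProbNum implementation. The vector field `y1' = y2, y2' = mu * (1 - y1**2) * y2 - y1` and the stiffness constant `mu = 10` are assumptions chosen because they are consistent with the mean printed above, and all function names are hypothetical.

```python
from fractions import Fraction
from math import factorial


def product_coefficient(a, b, k):
    """Return the k-th coefficient of the product of two power series."""
    return sum(a[i] * b[k - i] for i in range(k + 1))


def vanderpol_taylor_series(y0, mu, num_derivatives):
    """Taylor coefficients of the van der Pol solution at the initial time.

    Assumed vector field (a stand-in, not taken from ProbNum):
        y1' = y2,    y2' = mu * (1 - y1**2) * y2 - y1
    """
    y1 = [Fraction(y0[0])]
    y2 = [Fraction(y0[1])]
    for k in range(num_derivatives):
        # Coefficients 0..k of the series 1 - y1(t)**2.
        one_minus_square = [
            (1 if j == 0 else 0) - product_coefficient(y1, y1, j)
            for j in range(k + 1)
        ]
        # k-th Taylor coefficient of the vector field along the solution.
        f1_k = y2[k]
        f2_k = mu * product_coefficient(one_minus_square, y2, k) - y1[k]
        # y' = f(y) implies (k + 1) * y_{k+1} = f_k.
        y1.append(f1_k / (k + 1))
        y2.append(f2_k / (k + 1))
    return y1, y2


def to_derivatives(coefficients):
    """Convert Taylor coefficients c_k into derivatives k! * c_k."""
    return [factorial(k) * c for k, c in enumerate(coefficients)]


y1, y2 = vanderpol_taylor_series(y0=[2, 0], mu=10, num_derivatives=3)
stacked = to_derivatives(y1) + to_derivatives(y2)
print([int(d) for d in stacked])  # prints [2, 0, -2, 60, 0, -2, 60, -1798]
```

Exact rational arithmetic (`Fraction`) is used so that the result can be compared digit for digit with the printed mean. An automatic-differentiation library would apply the same recursion to an arbitrary vector field rather than a hand-expanded one.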