TaylorMode
class probnum.diffeq.odefilter.init_routines.TaylorMode
Bases: _AutoDiffBase
Initialize a probabilistic ODE solver with Taylor-mode automatic differentiation, which computes the time derivatives of the ODE solution at the initial time t0 directly from the vector field.
This requires JAX. For an explanation of what happens under the hood, see [1].
References
[1] Krämer, N. and Hennig, P., Stable implementation of probabilistic ODE solvers, arXiv:2012.10106, 2020.
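Roughly speaking, Taylor-mode differentiation propagates truncated Taylor series through the vector field. If y' = f(y) and y(t) = sum_k c_k (t - t0)^k, then c_{k+1} = [f(y)]_k / (k + 1), where [.]_k denotes the k-th Taylor coefficient, and the k-th derivative at t0 is k! c_k. The sketch below implements this recursion with plain Python lists rather than JAX, for the van der Pol vector field f(y) = (y2, mu (1 - y1^2) y2 - y1). The value mu = 10 is an assumption, chosen because it reproduces the van der Pol output in the examples; the sketch illustrates the recursion and is not probnum's implementation.

```python
from math import factorial

MU = 10.0  # assumed van der Pol parameter (not read from probnum)


def series_mul(a, b):
    """Truncated Cauchy product of two Taylor-coefficient lists."""
    n = min(len(a), len(b))
    return [sum(a[i] * b[k - i] for i in range(k + 1)) for k in range(n)]


def vdp_rhs(y1, y2, mu=MU):
    """Apply the van der Pol vector field to truncated Taylor series."""
    y1_sq = series_mul(y1, y1)
    # Series of (1 - y1^2): only the constant coefficient picks up the 1.
    one_minus_sq = [1.0 - y1_sq[0]] + [-c for c in y1_sq[1:]]
    damping = series_mul(one_minus_sq, y2)
    dy1 = list(y2)
    dy2 = [mu * p - q for p, q in zip(damping, y1)]
    return dy1, dy2


def taylor_derivatives(y0, num_derivatives, mu=MU):
    """Return (y1, y1', ..., y2, y2', ...) at t0, coordinate by coordinate."""
    c1, c2 = [float(y0[0])], [float(y0[1])]
    for k in range(num_derivatives):
        # Coefficient k of f(y(t)) only needs coefficients 0..k of y.
        f1, f2 = vdp_rhs(c1, c2, mu)
        c1.append(f1[k] / (k + 1))
        c2.append(f2[k] / (k + 1))
    # Convert normalized Taylor coefficients into derivatives.
    d1 = [factorial(k) * c for k, c in enumerate(c1)]
    d2 = [factorial(k) * c for k, c in enumerate(c2)]
    return d1 + d2


print(taylor_derivatives([2.0, 0.0], num_derivatives=3))
```

Up to floating-point rounding, the printed list agrees with the van der Pol mean shown in the examples. Each pass through the loop re-evaluates the vector field on a series that is one coefficient longer, which is the same "one more derivative per step" structure that Taylor-mode autodiff exploits.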
Examples
>>> import sys, pytest
>>> if not sys.platform.startswith('linux'):
...     pytest.skip()
>>> import numpy as np
>>> from probnum.randvars import Normal
>>> from probnum.problems.zoo.diffeq import threebody_jax, vanderpol_jax
>>> from probnum.randprocs.markov.integrator import IntegratedWienerProcess
Compute the initial values of the restricted three-body problem as follows.
>>> ivp = threebody_jax()
>>> print(ivp.y0)
[ 0.994 0. 0. -2.00158511]
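As a consistency check, the first derivative returned by an exact initialization must equal the vector field evaluated at y0. The sketch below assumes the standard Arenstorf-orbit form of the restricted three-body problem, with mass ratio mu = 0.012277471 and state ordering (y1, y2, y1', y2'); both are assumptions and are not read from probnum. Evaluated at the printed y0, it gives approximately [0, -2.00158511, -315.543, 0], which matches the second entry of each coordinate block in the improved mean printed in this example.

```python
# Assumed Arenstorf-orbit parameters (not taken from probnum).
MU = 0.012277471
MP = 1.0 - MU


def threebody_rhs(y):
    """Restricted three-body vector field in first-order form (y1, y2, y1', y2')."""
    y1, y2, v1, v2 = y
    d1 = ((y1 + MU) ** 2 + y2 ** 2) ** 1.5
    d2 = ((y1 - MP) ** 2 + y2 ** 2) ** 1.5
    a1 = y1 + 2.0 * v2 - MP * (y1 + MU) / d1 - MU * (y1 - MP) / d2
    a2 = y2 - 2.0 * v1 - MP * y2 / d1 - MU * y2 / d2
    return [v1, v2, a1, a2]


y0 = [0.994, 0.0, 0.0, -2.00158511]
f0 = threebody_rhs(y0)
print(f0)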
Construct the prior process.
>>> prior_process = IntegratedWienerProcess(
...     initarg=ivp.t0, wiener_process_dimension=4, num_derivatives=3
... )
Initialize with Taylor-mode autodiff.
>>> taylor_init = TaylorMode()
>>> improved_initrv = taylor_init(ivp=ivp, prior_process=prior_process)
Print the results.
>>> print(prior_process.transition.proj2coord(0) @ improved_initrv.mean)
[ 0.994 0. 0. -2.00158511]
>>> print(improved_initrv.mean)
[ 9.94000000e-01 0.00000000e+00 -3.15543023e+02 0.00000000e+00 0.00000000e+00 -2.00158511e+00 0.00000000e+00 9.99720945e+04 0.00000000e+00 -3.15543023e+02 0.00000000e+00 6.39028111e+07 -2.00158511e+00 0.00000000e+00 9.99720945e+04 0.00000000e+00]
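The printed mean has wiener_process_dimension * (num_derivatives + 1) = 4 * 4 = 16 entries. Judging from the output above (this layout is inferred, not stated on this page), it is ordered coordinate by coordinate: (y1, y1', y1'', y1''', y2, y2', ...). Under that assumption, projecting to the zeroth derivative is a Kronecker product of an identity with a unit row vector, and a reshape recovers a per-coordinate table of derivatives. The NumPy sketch below copies the mean from the output and reproduces what proj2coord(0) printed; it does not call probnum.

```python
import numpy as np

d, nu = 4, 3  # wiener_process_dimension and num_derivatives from the example
mean = np.array([
    9.94000000e-01, 0.0, -3.15543023e02, 0.0,
    0.0, -2.00158511e00, 0.0, 9.99720945e04,
    0.0, -3.15543023e02, 0.0, 6.39028111e07,
    -2.00158511e00, 0.0, 9.99720945e04, 0.0,
])
assert mean.size == d * (nu + 1)

# Unit row vector selecting the zeroth derivative inside one coordinate block.
e0 = np.zeros((1, nu + 1))
e0[0, 0] = 1.0
proj0 = np.kron(np.eye(d), e0)  # shape (d, d * (nu + 1)), assumes coordinate-major layout
y0 = proj0 @ mean

# Row i holds (y_i, y_i', y_i'', y_i''') at t0.
table = mean.reshape(d, nu + 1)
print(y0)
print(table[:, 1])  # first derivatives, i.e. the vector field at y0
```

Column 0 of the table equals y0, and column 1 holds the first derivatives [0, -2.00158511, -315.543023, 0].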
Compute the initial values of the van der Pol oscillator as follows. First, set up the IVP and the prior process.
>>> ivp = vanderpol_jax()
>>> print(ivp.y0)
[2. 0.]
>>> prior_process = IntegratedWienerProcess(
...     initarg=ivp.t0, wiener_process_dimension=2, num_derivatives=3
... )
>>> taylor_init = TaylorMode()
>>> improved_initrv = taylor_init(ivp=ivp, prior_process=prior_process)
Print the results.
>>> print(prior_process.transition.proj2coord(0) @ improved_initrv.mean)
[2. 0.]
>>> print(improved_initrv.mean)
[ 2. 0. -2. 60. 0. -2. 60. -1798.]
>>> print(improved_initrv.std)
[0. 0. 0. 0. 0. 0. 0. 0.]
Attributes Summary
is_exact: Exactness of the computed initial values.
requires_jax: Whether the implementation of the routine relies on JAX.
Methods Summary
__call__(*, ivp, prior_process): Call self as a function.
Attributes Documentation
- is_exact
Exactness of the computed initial values.
Some initialization routines yield the exact initial derivatives; others yield only approximations.
- requires_jax
Whether the implementation of the routine relies on JAX.
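The two flags and the keyword-only call signature can be pictured with a small abstract interface. The sketch below is hypothetical: InitRoutineSketch and TaylorModeSketch are illustrative names, not probnum classes. That TaylorMode reports requires_jax as True follows from "This requires JAX" above; that it reports is_exact as True is an assumption, consistent with the zero standard deviation in the van der Pol example.

```python
import abc


class InitRoutineSketch(abc.ABC):
    """Hypothetical interface for an initialization routine (illustration only)."""

    @property
    @abc.abstractmethod
    def is_exact(self) -> bool:
        """True if the routine returns exact initial derivatives."""

    @property
    @abc.abstractmethod
    def requires_jax(self) -> bool:
        """True if the routine needs JAX to run."""

    @abc.abstractmethod
    def __call__(self, *, ivp, prior_process):
        """Return an initial random variable for the solver state."""


class TaylorModeSketch(InitRoutineSketch):
    """Toy stand-in whose flags mirror the (assumed) values for TaylorMode."""

    @property
    def is_exact(self) -> bool:
        return True

    @property
    def requires_jax(self) -> bool:
        return True

    def __call__(self, *, ivp, prior_process):
        # Placeholder: the real routine performs Taylor-mode autodiff with JAX.
        raise NotImplementedError("illustrative sketch only")


routine = TaylorModeSketch()
print(routine.is_exact, routine.requires_jax)
```

The bare * in __call__ makes ivp and prior_process keyword-only, so a positional call such as routine(ivp, prior) raises a TypeError; this matches the keyword style used in the examples.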
Methods Documentation
- __call__(*, ivp, prior_process)
Call self as a function.
- Parameters
ivp (InitialValueProblem)
prior_process (MarkovProcess)
- Return type