Source code for probnum.problems.zoo.diffeq._ivp_examples_jax

"""IVP examples that use jax."""

import numpy as np

from probnum.problems import InitialValueProblem

# Public API of this module; ``_import_jax`` below is a private helper.
__all__ = ["threebody_jax", "vanderpol_jax"]
# jax is only imported inside ``_import_jax`` (and only when an IVP is
# requested), so this module itself can be imported without jax installed.
# pylint: disable=import-outside-toplevel

def threebody_jax(tmax=17.0652165601579625588917206249):
    r"""Initial value problem (IVP) based on a three-body problem.

    The state is :math:`y = (y_1, y_2, \dot{y}_1, \dot{y}_2)^T`. The
    second-order three-body problem is written as a system of first-order
    ODEs [1]_,

    .. math::

        f(t, y) =
        \begin{pmatrix}
            \dot{y}_1 \\
            \dot{y}_2 \\
            y_1 + 2 \dot{y}_2 - \frac{(1 - \mu) (y_1 + \mu)}{d_1}
            - \frac{\mu (y_1 - (1 - \mu))}{d_2} \\
            y_2 - 2 \dot{y}_1 - \frac{(1 - \mu) y_2}{d_1} - \frac{\mu y_2}{d_2}
        \end{pmatrix}

    where

    .. math::

        d_1 &= ((y_1 + \mu)^2 + y_2^2)^{\frac{3}{2}} \\
        d_2 &= ((y_1 - (1 - \mu))^2 + y_2^2)^{\frac{3}{2}}

    and the constant :math:`\mu = 0.012277471` is the standardised moon mass.

    The Jacobian (forward-mode) and the Hessian (reverse-over-forward mode) of
    the vector field are derived with jax, and all three functions are
    just-in-time compiled.

    Parameters
    ----------
    tmax
        Final time of the integration domain.

    Returns
    -------
    IVP
        IVP object describing a three-body problem IVP that starts at
        :math:`t_0 = 0` with the prescribed final time.

    References
    ----------
    .. [1] Hairer, E., Norsett, S. and Wanner, G..
       Solving Ordinary Differential Equations I.
       Springer Series in Computational Mathematics, 1993.
    """
    jax, jnp = _import_jax()

    def vector_field(state):
        # state = [y1, y2, y1', y2']
        mu = 0.012277471  # standardised moon mass
        one_minus_mu = 1 - mu
        pos1, pos2 = state[0], state[1]
        vel1, vel2 = state[2], state[3]
        dist1 = ((pos1 + mu) ** 2 + pos2**2) ** (3 / 2)
        dist2 = ((pos1 - one_minus_mu) ** 2 + pos2**2) ** (3 / 2)
        acc1 = (
            pos1
            + 2 * vel2
            - one_minus_mu * (pos1 + mu) / dist1
            - mu * (pos1 - one_minus_mu) / dist2
        )
        acc2 = pos2 - 2 * vel1 - one_minus_mu * pos2 / dist1 - mu * pos2 / dist2
        return jnp.array([vel1, vel2, acc1, acc2])

    jacobian = jax.jacfwd(vector_field)
    hessian = jax.jacrev(jacobian)

    # The ODE is autonomous: ``t`` is accepted for the (t, y) signature only.
    def rhs(t, y):
        return vector_field(y)

    def jac(t, y):
        return jacobian(y)

    def hess(t, y):
        return hessian(y)

    initial_state = np.array([0.994, 0, 0, -2.00158510637908252240537862224])
    return InitialValueProblem(
        f=jax.jit(rhs),
        t0=0.0,
        tmax=tmax,
        y0=initial_state,
        df=jax.jit(jac),
        ddf=jax.jit(hess),
    )

def _import_jax():
errormsg = (
"IVP instantiation requires jax. "
"Try using the pure numpy versions instead, "
"or install jax via pip install jax jaxlib"
)

try:
import jax
from jax.config import config
import jax.numpy as jnp

config.update("jax_enable_x64", True)
return jax, jnp

except ImportError as err:
raise ImportError(errormsg) from err

def vanderpol_jax(t0=0.0, tmax=30, y0=None, params=1e1):
    r"""Initial value problem (IVP) based on the Van der Pol oscillator, in jax.

    This function implements the second-order Van der Pol oscillator as a
    system of first-order ODEs, defined as

    .. math::

        f(t, y) =
        \begin{pmatrix}
            y_2 \\
            \mu \cdot (1 - y_1^2) y_2 - y_1
        \end{pmatrix}

    for a constant parameter :math:`\mu`.
    :math:`\mu` determines the stiffness of the problem: the larger
    :math:`\mu` is chosen, the stiffer the problem becomes.
    The default is :math:`\mu = 10`.
    This implementation also provides the Jacobian :math:`J_f` and the Hessian
    of :math:`f`. Both are computed with jax automatic differentiation, and all
    three functions are just-in-time compiled.

    Parameters
    ----------
    t0 : float
        Initial time point. Leftmost point of the integration domain.
    tmax : float
        Final time point. Rightmost point of the integration domain.
    y0 : np.ndarray, optional
        *(shape=(2, ))* -- Initial value of the problem.
        Defaults to ``np.array([2.0, 0.0])``.
    params : float or sequence of one float, optional
        Parameter :math:`\mu` of the Van der Pol equations, given either as a
        scalar or as a one-element sequence ``(mu,)``.
        Default is :math:`\mu = 10`.

    Returns
    -------
    IVP
        IVP object describing the Van der Pol Oscillator IVP with the prescribed
        configuration.
    """
    jax, jnp = _import_jax()

    # Accept any scalar (Python int/float, numpy scalar, 0-d array) as well as
    # a one-element sequence. The former ``isinstance(params, float)`` check
    # sent integers such as ``params=5`` into the tuple-unpacking branch,
    # which raised ``TypeError``.
    if np.ndim(params) == 0:
        mu = params
    else:
        (mu,) = params

    if y0 is None:
        y0 = np.array([2.0, 0.0])

    def vanderpol_rhs(Y):
        return jnp.array([Y[1], mu * (1.0 - Y[0] ** 2) * Y[1] - Y[0]])

    df = jax.jacfwd(vanderpol_rhs)
    ddf = jax.jacrev(df)

    # The ODE is autonomous: ``t`` is accepted for the (t, y) signature only.
    @jax.jit
    def rhs(t, y):
        return vanderpol_rhs(Y=y)

    @jax.jit
    def jac(t, y):
        return df(y)

    @jax.jit
    def hess(t, y):
        return ddf(y)

    return InitialValueProblem(f=rhs, t0=t0, tmax=tmax, y0=y0, df=jac, ddf=hess)