# Source code for probnum._pnmethod._stopping_criterion

"""Stopping criterion for a probabilistic numerical method."""

import abc
from typing import Any, Callable

class StoppingCriterion(abc.ABC):
    """Stopping criterion of a probabilistic numerical method.

    Checks whether quantities tracked by the probabilistic numerical
    method meet a desired terminal condition.

    Stopping criteria support binary arithmetic, which makes them easy to combine.
    Take the following example, where we define a custom solver state.

    >>> import dataclasses
    >>> from probnum import StoppingCriterion

    >>> @dataclasses.dataclass
    ... class SolverState:
    ...     iters: int = 50
    ...     atol: float = 1e-12
    ...     rtol: float = 1e-3

    >>> state = SolverState()

    Next we implement a few stopping criteria.

    >>> class MaxIterations(StoppingCriterion):
    ...     def __init__(self, maxiters):
    ...         self.maxiters = maxiters
    ...     def __call__(self, solver_state) -> bool:
    ...         return solver_state.iters >= self.maxiters

    >>> class AbsoluteResidualTolerance(StoppingCriterion):
    ...     def __init__(self, atol=1e-6):
    ...         self.atol = atol
    ...     def __call__(self, solver_state) -> bool:
    ...         return solver_state.atol < self.atol

    >>> class RelativeResidualTolerance(StoppingCriterion):
    ...     def __init__(self, rtol=1e-6):
    ...         self.rtol = rtol
    ...     def __call__(self, solver_state) -> bool:
    ...         return solver_state.rtol < self.rtol

    Now let's combine them by stopping when the solver has reached an absolute and
    relative tolerance, or a maximum number of iterations.

    >>> stopcrit = MaxIterations(maxiters=100) | (
    ...     AbsoluteResidualTolerance(atol=1e-6)
    ...     & RelativeResidualTolerance(rtol=1e-6)
    ... )

    >>> stopcrit(state)
    False

    Now let's modify the state such that the solver has reached a maximum number of
    iterations.

    >>> state.iters = 1000
    >>> stopcrit(state)
    True

    See Also
    --------
    LambdaStoppingCriterion : Stopping criterion defined via an anonymous function.
    ~probnum.linalg.solvers.stopping_criteria.LinearSolverStoppingCriterion : Stopping criterion of a probabilistic linear solver.
    ~probnum.filtsmooth.optim.FiltSmoothStoppingCriterion : Stopping criterion of filters and smoothers.
    """

    @abc.abstractmethod
    def __call__(self, *args, **kwargs) -> bool:
        """Check whether tracked quantities meet a desired terminal condition."""
        raise NotImplementedError

    def __and__(self, other):
        """Combine with ``other`` into a criterion that stops only if both are met.

        Both criteria receive the same positional and keyword arguments. Evaluation
        short-circuits: ``other`` is not called when ``self`` returns a falsy value.

        Returns ``NotImplemented`` for a non-callable ``other`` so that Python raises
        ``TypeError`` at combination time instead of when the criterion is evaluated.
        """
        if not callable(other):
            return NotImplemented
        return LambdaStoppingCriterion(
            stopcrit=lambda *args, **kwargs: self(*args, **kwargs)
            and other(*args, **kwargs)
        )

    def __or__(self, other):
        """Combine with ``other`` into a criterion that stops if either is met.

        Both criteria receive the same positional and keyword arguments. Evaluation
        short-circuits: ``other`` is not called when ``self`` returns a truthy value.

        Returns ``NotImplemented`` for a non-callable ``other`` so that Python raises
        ``TypeError`` at combination time instead of when the criterion is evaluated.
        """
        if not callable(other):
            return NotImplemented
        return LambdaStoppingCriterion(
            stopcrit=lambda *args, **kwargs: self(*args, **kwargs)
            or other(*args, **kwargs)
        )

    def __invert__(self):
        """Return the logical negation of this criterion."""
        return LambdaStoppingCriterion(
            stopcrit=lambda *args, **kwargs: not self(*args, **kwargs)
        )


class LambdaStoppingCriterion(StoppingCriterion):
    """Define a stopping criterion via an anonymous function.

    Defines a stopping criterion from a lambda function. This allows quick
    definition of stopping criteria for prototyping.

    Parameters
    ----------
    stopcrit
        Callable returning whether to stop or not.

    Examples
    --------
    >>> from probnum import LambdaStoppingCriterion
    >>> stopcrit = LambdaStoppingCriterion(lambda iters: iters >= 100)
    >>> stopcrit(101)
    True
    """

    def __init__(self, stopcrit: Callable[[Any], bool]) -> None:
        self._stopcrit = stopcrit

    def __call__(self, *args, **kwargs) -> bool:
        """Forward all arguments to the wrapped callable and return its result."""
        return self._stopcrit(*args, **kwargs)