# Source code for probnum.linalg.solvers.stopping_criteria._posterior_contraction

"""Stopping criterion based on the uncertainty about a quantity of interest."""

import numpy as np

import probnum  # pylint: disable="unused-import"
from probnum.typing import ScalarLike

from ._linear_solver_stopping_criterion import LinearSolverStoppingCriterion


class PosteriorContractionStoppingCriterion(LinearSolverStoppingCriterion):
    r"""Posterior contraction stopping criterion.

    Terminate when the uncertainty about the quantity of interest :math:`q` is
    sufficiently small, i.e. if :math:`\sqrt{\operatorname{tr}(\mathbb{Cov}(q))}
    \leq \max(\text{atol}, \text{rtol} \lVert b \rVert_2)`, where :math:`q` is either
    the solution :math:`x`, the system matrix :math:`A` or its inverse :math:`A^{-1}`.

    Parameters
    ----------
    qoi :
        Quantity of interest. One of ``{"x", "A", "Ainv"}``. It is looked up by
        name as an attribute of the solver state's belief each time the criterion
        is evaluated.
    atol :
        Absolute tolerance.
    rtol :
        Relative tolerance.
    """

    def __init__(
        self,
        qoi: str = "x",
        atol: ScalarLike = 10**-5,
        rtol: ScalarLike = 10**-5,
    ):
        self.qoi = qoi
        self.atol = probnum.utils.as_numpy_scalar(atol)
        self.rtol = probnum.utils.as_numpy_scalar(rtol)

    def __call__(
        self, solver_state: "probnum.linalg.solvers.LinearSolverState"
    ) -> bool:
        """Check whether the uncertainty about the quantity of interest is smaller
        than the specified tolerance.

        Parameters
        ----------
        solver_state :
            Current state of the linear solver.

        Returns
        -------
        bool
            ``True`` if the absolute trace of the covariance of the quantity of
            interest is at most ``atol**2`` or at most ``(rtol * ||b||_2)**2``,
            otherwise ``False``.
        """
        trace_cov_qoi = getattr(solver_state.belief, self.qoi).cov.trace()
        b_norm = np.linalg.norm(solver_state.problem.b, ord=2)

        # Compare squared quantities instead of taking a square root of the trace:
        # for non-negative tolerances,
        #   sqrt(|tr|) <= max(atol, rtol * ||b||)
        #   <=> |tr| <= atol**2  or  |tr| <= (rtol * ||b||)**2.
        abs_trace_cov_qoi = np.abs(trace_cov_qoi)
        return bool(
            abs_trace_cov_qoi <= self.atol**2
            or abs_trace_cov_qoi <= (self.rtol * b_norm) ** 2
        )