# Source code for probnum.linalg.solvers.stopping_criteria._posterior_contraction

"""Stopping criterion based on the uncertainty about a quantity of interest."""

import numpy as np

import probnum  # pylint: disable="unused-import"
from probnum.typing import ScalarLike

from ._linear_solver_stopping_criterion import LinearSolverStoppingCriterion

class PosteriorContractionStoppingCriterion(LinearSolverStoppingCriterion):
    r"""Posterior contraction stopping criterion.

    Terminate when the uncertainty about the quantity of interest :math:`q` is
    sufficiently small, i.e. if :math:`\sqrt{\operatorname{tr}(\mathbb{Cov}(q))}
    \leq \max(\text{atol}, \text{rtol} \lVert b \rVert_2)`, where :math:`q` is
    either the solution :math:`x`, the system matrix :math:`A` or its inverse
    :math:`A^{-1}`.

    Parameters
    ----------
    qoi :
        Quantity of interest. One of ``{"x", "A", "Ainv"}``. The name is looked
        up as an attribute of the solver state's belief each time the criterion
        is evaluated; it is not validated at construction time.
    atol :
        Absolute tolerance.
    rtol :
        Relative tolerance.
    """

    def __init__(
        self,
        qoi: str = "x",
        atol: ScalarLike = 10**-5,
        rtol: ScalarLike = 10**-5,
    ):
        self.qoi = qoi
        # Tolerances are normalized to NumPy scalars once, up front.
        self.atol = probnum.utils.as_numpy_scalar(atol)
        self.rtol = probnum.utils.as_numpy_scalar(rtol)

    def __call__(
        self, solver_state: "probnum.linalg.solvers.LinearSolverState"
    ) -> bool:
        """Check whether the uncertainty about the quantity of interest is smaller
        than the specified tolerance.

        Parameters
        ----------
        solver_state :
            Current state of the linear solver. Its ``belief`` must expose the
            attribute named by ``qoi`` (with a ``cov.trace()``), and its
            ``problem.b`` is used to scale the relative tolerance.

        Returns
        -------
        bool
            ``True`` if the absolute value of the trace of the covariance of the
            quantity of interest is at most ``atol**2`` or at most
            ``(rtol * ||b||_2)**2``, ``False`` otherwise.
        """
        # NOTE(review): the absolute value presumably guards against a slightly
        # negative trace caused by round-off -- confirm.
        trace_cov_qoi = np.abs(getattr(solver_state.belief, self.qoi).cov.trace())
        b_norm = np.linalg.norm(solver_state.problem.b, ord=2)

        # The documented criterion bounds sqrt(trace); comparing the trace with
        # the squared tolerances is equivalent and avoids taking a square root.
        # ``bool(...)`` converts the NumPy boolean to the annotated builtin type.
        return bool(
            trace_cov_qoi <= self.atol**2
            or trace_cov_qoi <= (self.rtol * b_norm) ** 2
        )