# Source code for probnum.linalg.solvers.policies._conjugate_gradient

"""Policy returning :math:`A`-conjugate actions."""

from typing import Callable, Iterable, Optional, Tuple

import numpy as np

import probnum  # pylint: disable="unused-import"
from probnum import linops, randvars

from . import _linear_solver_policy


class ConjugateGradientPolicy(_linear_solver_policy.LinearSolverPolicy):
    r"""Policy returning :math:`A`-conjugate actions.

    Selects the negative gradient / residual as an initial action
    :math:`s_0 = b - A x_0` and then successively generates :math:`A`-conjugate
    actions, i.e. the actions satisfy :math:`s_i^\top A s_j = 0` iff
    :math:`i \neq j`.

    Parameters
    ----------
    reorthogonalization_fn_residual
        Reorthogonalization function, which takes a vector, an orthogonal basis
        and optionally an inner product and returns a reorthogonalized vector. If
        not `None` the residuals are reorthogonalized before the action is computed.
        It is called with the keyword arguments ``v`` (the current residual),
        ``orthogonal_basis`` (the previously reorthogonalized residuals stacked
        into an array) and ``inner_product=None``.
    reorthogonalization_fn_action
        Reorthogonalization function, which takes a vector, an orthogonal basis
        and optionally an inner product and returns a reorthogonalized vector.
        If not `None` the computed action is reorthogonalized. It is called with
        the keyword arguments ``v`` (the computed action), ``orthogonal_basis``
        (all previous actions stacked into an array) and ``inner_product``, which
        is :math:`A` for a :class:`~probnum.randvars.Constant` prior on the
        solution and :math:`A \Sigma_0 A^\top` for a
        :class:`~probnum.randvars.Normal` prior with covariance :math:`\Sigma_0`.
    """

    def __init__(
        self,
        reorthogonalization_fn_residual: Optional[
            Callable[
                [np.ndarray, Iterable[np.ndarray], linops.LinearOperator], np.ndarray
            ]
        ] = None,
        reorthogonalization_fn_action: Optional[
            Callable[
                [np.ndarray, Iterable[np.ndarray], linops.LinearOperator], np.ndarray
            ]
        ] = None,
    ) -> None:
        self._reorthogonalization_fn_residual = reorthogonalization_fn_residual
        self._reorthogonalization_fn_action = reorthogonalization_fn_action

    def __call__(
        self,
        solver_state: "probnum.linalg.solvers.LinearSolverState",
        rng: Optional[np.random.Generator] = None,
    ) -> np.ndarray:
        """Return the next action of the conjugate gradient policy.

        Parameters
        ----------
        solver_state
            Current state of the linear solver.
        rng
            Random number generator. Not used by this policy.

        Returns
        -------
        action
            At step 0 the current residual. Afterwards the (optionally
            reorthogonalized) residual plus a multiple of the previous action,
            optionally reorthogonalized itself.

        Raises
        ------
        TypeError
            If action reorthogonalization is enabled and the prior over the
            solution is neither a :class:`~probnum.randvars.Normal` nor a
            :class:`~probnum.randvars.Constant` random variable.
        """
        residual = solver_state.residual

        if solver_state.step == 0:
            # Seed the basis that later residuals are reorthogonalized against.
            if self._reorthogonalization_fn_residual is not None:
                solver_state.cache["reorthogonalized_residuals"].append(
                    solver_state.residual
                )
            return residual

        # Reorthogonalization of the residual
        if self._reorthogonalization_fn_residual is not None:
            residual, prev_residual = self._reorthogonalized_residual(
                solver_state=solver_state
            )
        else:
            prev_residual = solver_state.residuals[solver_state.step - 1]

        # A-conjugacy correction (in exact arithmetic): ratio of squared residual
        # norms, as in the standard conjugate gradient recurrence.
        beta = (np.linalg.norm(residual) / np.linalg.norm(prev_residual)) ** 2
        action = residual + beta * solver_state.actions[solver_state.step - 1]

        # Reorthogonalization of the resulting action
        if self._reorthogonalization_fn_action is not None:
            action = self._reorthogonalized_action(
                action=action, solver_state=solver_state
            )

        return action

    def _reorthogonalized_residual(
        self,
        solver_state: "probnum.linalg.solvers.LinearSolverState",
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Compute the reorthogonalized residual and its predecessor.

        The new reorthogonalized residual is appended to
        ``solver_state.cache["reorthogonalized_residuals"]``.

        Returns
        -------
        residual
            Current residual reorthogonalized against all stored residuals.
        prev_residual
            Reorthogonalized residual of the previous step.
        """
        residual = self._reorthogonalization_fn_residual(
            v=solver_state.residual,
            orthogonal_basis=np.asarray(
                solver_state.cache["reorthogonalized_residuals"]
            ),
            inner_product=None,
        )
        solver_state.cache["reorthogonalized_residuals"].append(residual)
        # Assumes exactly one entry is appended per step (starting at step 0), so
        # after the append above index ``step - 1`` holds the previous residual.
        prev_residual = solver_state.cache["reorthogonalized_residuals"][
            solver_state.step - 1
        ]
        return residual, prev_residual

    def _reorthogonalized_action(
        self,
        action: np.ndarray,
        solver_state: "probnum.linalg.solvers.LinearSolverState",
    ) -> np.ndarray:
        """Reorthogonalize the computed action against all previous actions.

        The inner product is induced by :math:`A` for a constant prior on the
        solution and by :math:`A \\Sigma_0 A^\\top` for a Gaussian prior with
        covariance :math:`\\Sigma_0`.

        Raises
        ------
        TypeError
            If the prior over the solution is neither a
            :class:`~probnum.randvars.Normal` nor a
            :class:`~probnum.randvars.Constant` random variable.
        """
        prior_x = solver_state.prior.x
        if isinstance(prior_x, randvars.Normal):
            inprod_matrix = (
                solver_state.problem.A @ prior_x.cov @ solver_state.problem.A.T
            )
        elif isinstance(prior_x, randvars.Constant):
            inprod_matrix = solver_state.problem.A
        else:
            # No inner product is defined for other prior types; fail loudly
            # instead of hitting an unbound ``inprod_matrix`` below.
            raise TypeError(
                "Reorthogonalization of actions requires a prior over the "
                "solution of type `randvars.Normal` or `randvars.Constant`, "
                f"but got `{type(prior_x).__name__}`."
            )

        orthogonal_basis = np.asarray(solver_state.actions[0 : solver_state.step])

        return self._reorthogonalization_fn_action(
            v=action,
            orthogonal_basis=orthogonal_basis,
            inner_product=inprod_matrix,
        )