# Source code for probnum.linalg.solvers.belief_updates.matrix_based._matrix_based_linear_belief_update

"""Belief update in a matrix-based inference view where the information is given by
matrix-vector multiplication."""
import numpy as np

import probnum  # pylint: disable="unused-import"
from probnum import linops, randvars
from probnum.linalg.solvers.beliefs import LinearSystemBelief

from .._linear_solver_belief_update import LinearSolverBeliefUpdate


class MatrixBasedLinearBeliefUpdate(LinearSolverBeliefUpdate):
    r"""Gaussian belief update in a matrix-based inference framework assuming linear information.

    Updates the belief over the quantities of interest of a linear system
    :math:`Ax=b` given matrix-variate Gaussian beliefs with Kronecker covariance
    structure and linear observations :math:`y=As`. The belief update computes
    :math:`p(M \mid y) = \mathcal{N}(M; M_{i+1}, V \otimes W_{i+1})`, [1]_ [2]_
    such that

    .. math ::
        \begin{align}
            M_{i+1} &= M_i + (y - M_i s) (s^\top W_i s)^\dagger s^\top W_i,\\
            W_{i+1} &= W_i - W_i s (s^\top W_i s)^\dagger s^\top W_i.
        \end{align}

    Only the second Kronecker factor :math:`W` of the covariance is updated; the
    first factor :math:`V` is carried over unchanged. The same update is applied
    to the belief about :math:`A` (action :math:`s`, observation :math:`y`) and
    to the belief about :math:`A^{-1}` (with the roles of :math:`s` and
    :math:`y` swapped).

    References
    ----------
    .. [1] Hennig, P., Probabilistic Interpretation of Linear Solvers, *SIAM Journal on
       Optimization*, 2015, 25, 234-260
    .. [2] Wenger, J. and Hennig, P., Probabilistic Linear Solvers for Machine Learning,
       *Advances in Neural Information Processing Systems (NeurIPS)*, 2020
    """

    def __call__(
        self, solver_state: "probnum.linalg.solvers.LinearSolverState"
    ) -> LinearSystemBelief:
        """Update the linear system belief given the latest action and observation.

        Parameters
        ----------
        solver_state
            Current state of the linear solver. ``belief.A`` and ``belief.Ainv``
            must be Gaussian beliefs with Kronecker-structured covariance;
            ``action`` and ``observation`` are the latest action :math:`s` and
            observation :math:`y`.

        Returns
        -------
        LinearSystemBelief
            Updated beliefs over :math:`A` and :math:`A^{-1}`. The belief over
            :math:`b` is carried over unchanged, or, if there is none, replaced
            by a constant equal to the problem's right-hand side. ``x`` is passed
            as ``None``.

        Raises
        ------
        ValueError
            If the covariance of the belief over :math:`A` or :math:`A^{-1}` is
            not a Kronecker operator.
        """
        # Inference for A.
        A = self._matrix_based_update(
            matrix=solver_state.belief.A,
            action=solver_state.action,
            observ=solver_state.observation,
        )

        # Inference for Ainv: y = A s implies Ainv y = s, so the same update
        # applies with action and observation interchanged.
        Ainv = self._matrix_based_update(
            matrix=solver_state.belief.Ainv,
            action=solver_state.observation,
            observ=solver_state.action,
        )

        # Keep an existing belief about b; otherwise treat the problem's
        # right-hand side as known exactly.
        if solver_state.belief.b is None:
            b = randvars.Constant(solver_state.problem.b)
        else:
            b = solver_state.belief.b

        return LinearSystemBelief(A=A, Ainv=Ainv, x=None, b=b)

    def _matrix_based_update(
        self, matrix: randvars.Normal, action: np.ndarray, observ: np.ndarray
    ) -> randvars.Normal:
        r"""Matrix-based inference update for linear information.

        Conditions the matrix-variate Gaussian belief
        :math:`\mathcal{N}(M; M_i, V \otimes W_i)` on the observation ``observ``
        of ``matrix @ action``, using the update rule from the class docstring.

        Parameters
        ----------
        matrix
            Current belief. Its covariance must be a ``linops.Kronecker`` operator
            whose factors ``A`` and ``B`` play the roles of :math:`V` and
            :math:`W_i`.
        action
            Action :math:`s`. Presumably a 1-D vector, since the outer products
            below are built with ``[:, None]`` / ``[None, :]`` indexing.
            TODO confirm.
        observ
            Observation :math:`y` corresponding to ``action``.

        Returns
        -------
        randvars.Normal
            Updated belief with mean :math:`M_{i+1}` and covariance
            :math:`V \otimes W_{i+1}`.

        Raises
        ------
        ValueError
            If ``matrix.cov`` is not a ``linops.Kronecker`` operator.
        """
        if not isinstance(matrix.cov, linops.Kronecker):
            raise ValueError(
                f"Covariance must have Kronecker structure, but is '{type(matrix.cov).__name__}'."
            )

        # Residual y - M_i s of the prediction made by the current mean.
        pred = matrix.mean @ action
        resid = observ - pred

        # W_i s and the scalar Gram term s^T W_i s.
        covfactor_Ms = matrix.cov.B @ action
        gram = action.T @ covfactor_Ms

        # Scalar pseudo-inverse of the Gram term. W_i is expected to be positive
        # semi-definite, so s^T W_i s should be non-negative. Non-positive values
        # (zero, or negative from round-off) map to 0, which makes both update
        # terms below zero instead of dividing by a non-positive number.
        gram_pinv = 1.0 / gram if gram > 0.0 else 0.0
        gain = covfactor_Ms * gram_pinv

        # Rank-1 downdate W_i s (s^T W_i s)^+ s^T W_i of the covariance factor W.
        covfactor_update = linops.aslinop(gain[:, None]) @ linops.aslinop(
            covfactor_Ms[None, :]
        )
        # Rank-1 mean update (y - M_i s) (W_i s)^T (s^T W_i s)^+. Residual and
        # gain are flipped (residual first) due to matrix vectorization.
        resid_gain = linops.aslinop(resid[:, None]) @ linops.aslinop(
            gain[None, :]
        )

        return randvars.Normal(
            mean=matrix.mean + resid_gain,
            cov=linops.Kronecker(A=matrix.cov.A, B=matrix.cov.B - covfactor_update),
        )