"""Belief update for conjugate Bayesian quadrature."""

from __future__ import annotations

from typing import Optional, Tuple

import numpy as np

from probnum.quad.kernel_embeddings import KernelEmbedding
from probnum.quad.solvers._bq_state import BQState
from probnum.randprocs.kernels import Kernel
from probnum.randvars import Normal
from probnum.typing import FloatLike

from ._belief_update import BQBeliefUpdate

# pylint: disable=too-few-public-methods
class BQStandardBeliefUpdate(BQBeliefUpdate):
    """Updates integral belief and state using standard Bayesian quadrature based on
    standard Gaussian process inference.

    Parameters
    ----------
    jitter
        Non-negative jitter to numerically stabilise kernel matrix inversion.
    scale_estimation
        Estimation method to use to compute the scale parameter. ``None`` keeps the
        scale stored in the previous state, ``"mle"`` re-estimates it from the
        function evaluations; any other value raises a ``ValueError`` when the
        belief is updated.
    """

    def __init__(self, jitter: FloatLike, scale_estimation: Optional[str]) -> None:
        # Forward the jitter instead of silently dropping it.
        # NOTE(review): assumes the base-class constructor accepts ``jitter`` (as
        # probnum's ``BQBeliefUpdate`` does) -- confirm against ``_belief_update``.
        super().__init__(jitter=jitter)
        self.scale_estimation = scale_estimation

    # pylint: disable=too-many-locals
    def __call__(
        self,
        bq_state: BQState,
        new_nodes: np.ndarray,
        new_fun_evals: np.ndarray,
        *args,
        **kwargs,
    ) -> Tuple[Normal, BQState]:
        """Update the integral belief and the BQ state with new observations.

        Parameters
        ----------
        bq_state
            Current state. Its nodes, function evaluations, kernel, measure, Gram
            matrix, kernel means and scale are read.
        new_nodes
            Newly collected nodes; concatenated to ``bq_state.nodes`` along axis 0.
        new_fun_evals
            Function evaluations at ``new_nodes``; appended to
            ``bq_state.fun_evals`` with ``np.append`` (which flattens its inputs).
        *args
            Unused.
        **kwargs
            Unused.

        Returns
        -------
        new_belief :
            Normal belief over the integral given all nodes and evaluations.
        new_state :
            State created by ``BQState.from_new_data`` from ``bq_state`` and the
            updated quantities.

        Raises
        ------
        ValueError
            If ``scale_estimation`` is neither ``None`` nor ``"mle"``.
        """
        # Update nodes and function evaluations
        nodes = np.concatenate((bq_state.nodes, new_nodes), axis=0)
        fun_evals = np.append(bq_state.fun_evals, new_fun_evals)

        # Estimate intrinsic kernel parameters
        new_kernel, kernel_was_updated = self._estimate_kernel(bq_state.kernel)
        new_kernel_embedding = KernelEmbedding(new_kernel, bq_state.measure)

        # Update gram matrix and kernel mean vector. Recompute everything from
        # scratch if the kernel was updated or if these are the first nodes.
        if kernel_was_updated or bq_state.nodes.size == 0:
            gram = new_kernel.matrix(nodes)
            kernel_means = new_kernel_embedding.kernel_mean(nodes)
        else:
            # Only the blocks involving the new nodes are evaluated; the old
            # block is reused from the previous state.
            gram_new_new = new_kernel.matrix(new_nodes)
            gram_old_new = new_kernel.matrix(new_nodes, bq_state.nodes)
            gram = np.hstack(
                (
                    np.vstack((bq_state.gram, gram_old_new)),
                    np.vstack((gram_old_new.T, gram_new_new)),
                )
            )
            kernel_means = np.concatenate(
                (
                    bq_state.kernel_means,
                    new_kernel_embedding.kernel_mean(new_nodes),
                )
            )

        # Cholesky factorisation of the Gram matrix
        gram_cho_factor = self.compute_gram_cho_factor(gram)

        # Estimate scaling parameter
        new_scale_sq = self._estimate_scale(fun_evals, gram_cho_factor, bq_state)

        # Integral mean and variance
        weights = self.gram_cho_solve(gram_cho_factor, kernel_means)
        integral_mean = weights @ fun_evals
        initial_integral_variance = new_kernel_embedding.kernel_variance()
        integral_variance = new_scale_sq * (
            initial_integral_variance - weights @ kernel_means
        )

        new_belief = Normal(integral_mean, integral_variance)
        new_state = BQState.from_new_data(
            kernel=new_kernel,
            scale_sq=new_scale_sq,
            nodes=nodes,
            fun_evals=fun_evals,
            integral_belief=new_belief,
            prev_state=bq_state,
            gram=gram,
            gram_cho_factor=gram_cho_factor,
            kernel_means=kernel_means,
        )

        return new_belief, new_state

    # pylint: disable=no-self-use
    def _estimate_kernel(self, kernel: Kernel) -> Tuple[Kernel, bool]:
        """Estimate the intrinsic kernel parameters. That is, all parameters except the
        scale.

        Currently a no-op: the kernel is returned unchanged together with ``False``.
        """
        new_kernel = kernel
        kernel_was_updated = False
        return new_kernel, kernel_was_updated

    def _estimate_scale(
        self,
        fun_evals: np.ndarray,
        gram_cho_factor: Tuple[np.ndarray, bool],
        bq_state: BQState,
    ) -> FloatLike:
        """Estimate the scale parameter.

        Parameters
        ----------
        fun_evals
            All function evaluations collected so far.
        gram_cho_factor
            Factorisation of the Gram matrix as returned by
            ``compute_gram_cho_factor``.
        bq_state
            Previous state; its ``scale_sq`` is returned when no estimation is
            requested.

        Returns
        -------
        new_scale_sq :
            ``bq_state.scale_sq`` if ``scale_estimation`` is ``None``; for
            ``"mle"`` the value ``fun_evals @ solve(gram, fun_evals) / N`` with
            ``N = fun_evals.shape[0]``.

        Raises
        ------
        ValueError
            If ``scale_estimation`` is neither ``None`` nor ``"mle"``.
        """
        if self.scale_estimation is None:
            new_scale_sq = bq_state.scale_sq
        elif self.scale_estimation == "mle":
            new_scale_sq = (
                fun_evals
                @ self.gram_cho_solve(gram_cho_factor, fun_evals)
                / fun_evals.shape[0]
            )
        else:
            raise ValueError(f"Scale estimation ({self.scale_estimation}) is unknown.")
        return new_scale_sq

    @staticmethod
    def predict_integrand(
        x: np.ndarray, bq_state: BQState
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Predictive mean and variance of the integrand at the locations ``x``.

        A zero prior mean is used. If the state holds no function evaluations, the
        prior mean and the scaled prior kernel values ``bq_state.kernel(x, x)`` are
        returned.

        Parameters
        ----------
        x
            Locations at which to predict; ``x.shape[0]`` predictions are made.
        bq_state
            State providing kernel, nodes, function evaluations, Gram
            factorisation and scale.

        Returns
        -------
        predictive_mean :
            Predictive mean, one entry per row of ``x``.
        predictive_var :
            Predictive variance multiplied by ``bq_state.scale_sq``.
        """
        predictive_mean = np.zeros(x.shape[0])  # zero mean prior
        predictive_var = bq_state.kernel(x, x)

        nevals = bq_state.fun_evals.shape[0]
        if nevals != 0:
            kXx = bq_state.kernel.matrix(bq_state.nodes, x)
            weights = BQStandardBeliefUpdate.gram_cho_solve(
                bq_state.gram_cho_factor, kXx
            )

            # values (with zero mean prior at evals)
            predictive_mean += weights.T @ (bq_state.fun_evals - np.zeros(nevals))

            # variances
            predictive_var -= np.sum(weights * kXx, axis=0)

        return predictive_mean, bq_state.scale_sq * predictive_var