Source code for probnum.quad.solvers.policies._random_policy

"""Random policy for Bayesian Monte Carlo."""

from typing import Callable, Optional

import numpy as np

from probnum.quad.solvers.bq_state import BQState

from ._policy import Policy

# pylint: disable=too-few-public-methods, fixme


class RandomPolicy(Policy):
    """Random sampling from an objective.

    Parameters
    ----------
    sample_func
        The sample function. Needs to have the following interface:
        `sample_func(batch_size: int, rng: np.random.Generator)` and return an array
        of shape (batch_size, n_dim).
    batch_size
        Size of batch of nodes when calling the policy once.
    rng
        Random number generator that is passed to ``sample_func`` on every call.
        If ``None`` (the default), a new generator is created with
        ``np.random.default_rng()`` for this instance.
    """

    def __init__(
        self,
        sample_func: Callable,
        batch_size: int,
        rng: Optional[np.random.Generator] = None,
    ) -> None:
        super().__init__(batch_size=batch_size)
        self.sample_func = sample_func
        # Create the generator here rather than in the signature: a default of
        # ``np.random.default_rng()`` is evaluated only once, when the function
        # is defined, so all instances relying on the default would share (and
        # advance) the very same generator object.
        self.rng = np.random.default_rng() if rng is None else rng

    def __call__(self, bq_state: BQState) -> np.ndarray:
        """Draw a new batch of nodes from the sample function.

        Parameters
        ----------
        bq_state
            Current state of the BQ loop. It is not used by this policy.

        Returns
        -------
        nodes :
            Whatever ``sample_func`` returns when called with the batch size and
            this policy's random number generator.
        """
        # NOTE(review): ``self.batch_size`` is presumably set by
        # ``Policy.__init__`` -- confirm against the base class.
        return self.sample_func(self.batch_size, rng=self.rng)