# Source code for probnum.randvars._categorical

"""Categorical random variables."""
from typing import Optional

import numpy as np

from probnum.type import RandomStateArgType

from ._random_variable import DiscreteRandomVariable

class Categorical(DiscreteRandomVariable):
    """Categorical random variable.

    A discrete random variable that takes one of :math:`K` values (the
    *support*), where the :math:`k`-th support element has probability
    ``probabilities[k]``.

    Parameters
    ----------
    probabilities :
        Probabilities of the events. Its length :math:`K` is the number of
        categories.
    support :
        Support of the categorical distribution. Optional. Default is None,
        in which case the support is chosen as :math:`(0, ..., K-1)` where
        :math:`K` is the number of elements in `probabilities`. The first
        axis indexes the categories; the remaining axes (if any) give the
        shape of a single realization.
    random_state :
        Random state of the random variable. It is forwarded to
        :class:`DiscreteRandomVariable`, and samples are drawn from it.

    Raises
    ------
    ValueError
        If `support` is given and its length differs from the number of
        probabilities.
    """

    def __init__(
        self,
        probabilities: np.ndarray,
        support: Optional[np.ndarray] = None,
        random_state: Optional[RandomStateArgType] = None,
    ):
        # The set of events is named "support" to be aligned with the method
        # DiscreteRandomVariable.in_support().
        num_categories = len(probabilities)
        self._probabilities = np.asarray(probabilities)
        self._support = (
            np.asarray(support) if support is not None else np.arange(num_categories)
        )

        # Fail early: a mismatch would otherwise surface later as a confusing
        # error (or a silently wrong index) inside sampling.
        if len(self._support) != num_categories:
            raise ValueError(
                f"The support has {len(self._support)} elements, but "
                f"{num_categories} probabilities were given."
            )

        parameters = {
            "support": self._support,
            "probabilities": self._probabilities,
            "num_categories": num_categories,
        }

        def _sample_categorical(size=()):
            """Sample from a categorical distribution.

            While on first sight, one might think that this
            implementation can be replaced by
            `np.random.choice(self.support, size, self.probabilities)`,
            this is not true, because `np.random.choice` cannot handle
            arrays with `ndim > 1`, but `self.support` can be just that.
            Sampling category *indices* and then indexing the support
            avoids this problem.
            """
            # Use the random variable's own random state so that passing
            # ``random_state`` actually makes sampling reproducible. Fall back
            # to the global NumPy state only if none is available.
            rng = self.random_state if self.random_state is not None else np.random
            indices = rng.choice(num_categories, size=size, p=self.probabilities)
            return self.support[indices]

        def _pmf_categorical(x):
            """PMF of a categorical distribution.

            Returns the probability of the first support element equal to
            ``x``, or 0.0 if ``x`` is not in the support.

            The dtype of ``x`` is checked explicitly because comparing arrays
            of mismatched dtypes with ``==`` can yield cryptic NumPy warnings
            instead of an element-wise result.
            """
            x = np.asarray(x)
            if x.dtype != self.dtype:
                raise ValueError(
                    "The data type of x does not match with the data type of the support."
                )

            # A support element matches only if *all* of its entries equal the
            # corresponding entries of x. Flattening the trailing axes makes
            # this work for scalar supports (shape (K,)) and array-valued
            # supports (shape (K, ...)) alike.
            matches = np.all((self.support == x).reshape(num_categories, -1), axis=1)
            indices = np.flatnonzero(matches)
            return self.probabilities[indices[0]] if indices.size > 0 else 0.0

        def _mode_categorical():
            # Support element with the largest probability (np.argmax picks
            # the first one on ties).
            mask = np.argmax(self.probabilities)
            return self.support[mask]

        super().__init__(
            shape=self._support[0].shape,
            dtype=self._support.dtype,
            random_state=random_state,
            parameters=parameters,
            sample=_sample_categorical,
            pmf=_pmf_categorical,
            mode=_mode_categorical,
        )

    @property
    def probabilities(self) -> np.ndarray:
        """Event probabilities of the categorical distribution."""
        return self._probabilities

    @property
    def support(self) -> np.ndarray:
        """Support of the categorical distribution."""
        return self._support

    def resample(self) -> "Categorical":
        """Resample the support of the categorical random variable.

        Return a new categorical random variable (RV) whose support is
        randomly chosen from the elements of the current support, with
        probabilities given by the current event probabilities. The
        probabilities of the resulting categorical RV are all equal.
        """
        num_events = len(self.support)
        new_support = self.sample(size=num_events)
        new_probabilities = np.ones(self.probabilities.shape) / num_events
        return Categorical(
            support=new_support,
            probabilities=new_probabilities,
            random_state=self.random_state,
        )