# Source code for probnum.utils.argutils

"""Utility functions for argument types."""

import numbers

import numpy as np
import scipy._lib._util

from probnum.type import (
    DTypeArgType,
    RandomStateArgType,
    RandomStateType,
    ScalarArgType,
    ShapeArgType,
    ShapeType,
)

# Public API of this module: the names exported by ``from ... import *``.
__all__ = ["as_shape", "as_random_state", "as_numpy_scalar"]


def as_random_state(seed: RandomStateArgType) -> RandomStateType:
    """Coerce ``seed`` into a :class:`numpy.random.RandomState`.

    The conversion is delegated entirely to SciPy's ``check_random_state``
    helper.

    Parameters
    ----------
    seed
        Seed specification:

        - ``None`` yields the global RandomState singleton used by
          ``np.random``.
        - An integer yields a new RandomState seeded with that integer.
        - An existing RandomState instance is returned unchanged.

    Returns
    -------
    RandomStateType
        The random state object produced by SciPy's helper.

    Raises
    ------
    ValueError
        If ``seed`` is not one of the accepted inputs (the exact set of
        accepted inputs is defined by SciPy).
    """
    # NOTE: ``scipy._lib._util`` is a private SciPy module; its location or
    # behavior may change between SciPy releases.
    check_random_state = scipy._lib._util.check_random_state
    return check_random_state(seed)
def as_shape(x: ShapeArgType) -> ShapeType:
    """Convert a shape representation into a shape defined as a tuple of ints.

    Parameters
    ----------
    x
        Shape representation: a single integer (Python ``int``, any
        ``numbers.Integral`` or NumPy integer) or an iterable of integers.
        One-shot iterables such as generators are supported; they are
        consumed exactly once.

    Returns
    -------
    ShapeType
        The shape as a tuple of integers.

    Raises
    ------
    TypeError
        If ``x`` is neither an integer nor an iterable, or if the iterable
        contains a non-integer element.
    """
    if isinstance(x, (int, numbers.Integral, np.integer)):
        return (int(x),)

    # Fast path: already a tuple of Python ints, so no conversion is needed.
    if isinstance(x, tuple) and all(isinstance(item, int) for item in x):
        return x

    try:
        iterator = iter(x)
    except TypeError as e:
        raise TypeError(
            f"The given shape {x} must be an integer or an iterable of integers."
        ) from e

    # Materialize once. Validating and converting the iterable separately
    # would exhaust one-shot iterables (e.g. generators) during validation
    # and leave nothing for the conversion.
    items = tuple(iterator)

    if not all(isinstance(item, (int, numbers.Integral, np.integer)) for item in items):
        raise TypeError(f"The given shape {x} must only contain integer values.")

    return tuple(int(item) for item in items)
def as_numpy_scalar(x: ScalarArgType, dtype: DTypeArgType = None) -> np.generic:
    """Turn a scalar value into a NumPy scalar.

    Parameters
    ----------
    x
        Scalar value. Anything for which ``np.ndim`` reports zero dimensions
        is accepted, including 0-d arrays.
    dtype
        Optional data type for the result. If omitted, NumPy infers it.

    Returns
    -------
    np.generic
        The NumPy scalar obtained by converting ``x`` to an array and taking
        its single element.

    Raises
    ------
    ValueError
        If ``x`` has one or more dimensions.
    """
    is_scalar = np.ndim(x) == 0
    if not is_scalar:
        raise ValueError("The given input is not a scalar.")

    zero_dim = np.asarray(x, dtype=dtype)
    # Indexing a 0-d array with the empty tuple extracts a NumPy scalar.
    return zero_dim[()]