Source code for probnum.utils.argutils

import numbers

import numpy as np
import scipy._lib._util

from probnum.type import (
    DTypeArgType,
    RandomStateArgType,
    RandomStateType,
    ScalarArgType,
    ShapeArgType,
    ShapeType,
)

__all__ = ["as_shape", "as_random_state", "as_numpy_scalar"]


def as_random_state(x: RandomStateArgType) -> RandomStateType:
    """Convert a seed-like argument into a random state object.

    Parameters
    ----------
    x
        Object describing the desired random state. It is passed unchanged to
        SciPy's ``check_random_state`` helper, which defines the accepted
        values.

    Returns
    -------
    RandomStateType
        Whatever ``scipy._lib._util.check_random_state`` returns for ``x``.
    """
    # NOTE: ``scipy._lib._util`` is a private SciPy module (leading
    # underscores); this function depends on it being importable.
    random_state = scipy._lib._util.check_random_state(x)
    return random_state
def as_shape(x: ShapeArgType) -> ShapeType:
    """Convert a shape-like argument into a tuple of Python ``int`` values.

    Parameters
    ----------
    x
        Either a single integer (Python or NumPy), which is interpreted as a
        one-dimensional shape, or an iterable of integers. One-shot iterators
        such as generators are supported.

    Returns
    -------
    ShapeType
        The shape as a tuple. A tuple that already consists solely of ``int``
        instances is returned unchanged; otherwise every entry is converted
        with ``int``.

    Raises
    ------
    TypeError
        If ``x`` is neither an integer nor an iterable, or if the iterable
        contains a non-integer entry.
    """
    integer_types = (int, numbers.Integral, np.integer)

    if isinstance(x, integer_types):
        return (int(x),)

    # Fast path: nothing to convert, hand the caller's tuple back as is.
    if isinstance(x, tuple) and all(isinstance(item, int) for item in x):
        return x

    try:
        shape_iter = iter(x)
    except TypeError as e:
        raise TypeError(
            f"The given shape {x} must be an integer or an iterable of integers."
        ) from e

    # Materialize exactly once: validating and converting straight from ``x``
    # would exhaust a one-shot iterator during validation and silently yield
    # an empty shape.
    items = tuple(shape_iter)

    if not all(isinstance(item, integer_types) for item in items):
        raise TypeError(f"The given shape {x} must only contain integer values.")

    return tuple(int(item) for item in items)
def as_numpy_scalar(x: ScalarArgType, dtype: DTypeArgType = None) -> np.generic:
    """Convert a scalar or a zero-dimensional array into a NumPy scalar.

    Parameters
    ----------
    x
        The value to convert. It must either satisfy ``np.isscalar`` or be an
        ``np.ndarray`` with ``ndim == 0``.
    dtype
        Data type forwarded to ``np.asarray``. ``None`` lets NumPy choose it.

    Returns
    -------
    np.generic
        The single element of ``np.asarray(x, dtype=dtype)``.

    Raises
    ------
    ValueError
        If ``x`` is neither a scalar nor a zero-dimensional array.
    """
    if not np.isscalar(x):
        zero_dim_array = isinstance(x, np.ndarray) and x.ndim == 0
        if not zero_dim_array:
            raise ValueError("The given input is not a scalar.")

    # Indexing a 0-d array with the empty tuple extracts its single element.
    wrapped = np.asarray(x, dtype=dtype)
    return wrapped[()]