# Source code for probnum.statespace.generate_samples

"""Convenience function(s) for state space models."""

import numpy as np
import scipy.stats

from probnum import utils


def generate_samples(dynmod, measmod, initrv, times, random_state=None):
    """Sample true states and observations at pre-determined timesteps for a state space model.

    Parameters
    ----------
    dynmod : statespace.Transition
        Transition model describing the prior dynamics.
    measmod : statespace.Transition
        Transition model describing the measurement model.
    initrv : randvars.RandomVariable object
        Random variable according to the initial distribution.
    times : array_like, shape (n,)
        Timesteps on which the states are to be sampled. Converted with
        ``np.asarray``, so plain sequences are accepted as well as ``np.ndarray``.
    random_state : optional
        Anything accepted by ``probnum.utils.as_random_state`` (default ``None``).
        The same random state is used for the latent-state draws and for the
        observation draws.

    Returns
    -------
    states : np.ndarray; shape (len(times), d)
        True states according to the dynamic model. ``d`` is the latent-state
        dimension; the underlying base-measure draws have trailing dimension
        ``measmod.input_dim``.
    obs : np.ndarray; shape (len(times), measmod.output_dim)
        Observations according to the measurement model.
    """
    # Accept lists/tuples too; ``times.shape`` below requires an array.
    times = np.asarray(times)
    random_state = utils.as_random_state(random_state)

    obs = np.zeros((len(times), measmod.output_dim))

    # Standard-normal draws, one vector of length ``measmod.input_dim`` per time
    # point, which the dynamics model pushes forward into a latent trajectory.
    base_measure_realizations_latent_state = scipy.stats.norm.rvs(
        size=(times.shape + (measmod.input_dim,)), random_state=random_state
    )
    latent_states = np.array(
        dynmod.jointly_transform_base_measure_realization_list_forward(
            base_measure_realizations=base_measure_realizations_latent_state,
            t=times,
            initrv=initrv,
            # A diffusion value of 1 for each of the len(times) - 1 steps.
            _diffusion_list=np.ones_like(times[:-1]),
        )
    )

    for idx, (state, t) in enumerate(zip(latent_states, times)):
        measured_rv, _ = measmod.forward_realization(state, t=t)
        # Draw the observation from the same random state as the latent states.
        measured_rv.random_state = random_state
        obs[idx] = measured_rv.sample()

    return latent_states, obs