"""Posterior over states after applying (Extended/Unscented) Kalman filtering/smoothing.
Contains the discrete time and function outputs. Provides dense output by being
callable. Function values can also be accessed by indexing.
"""
from __future__ import annotations
import abc
from typing import Iterable, Optional, Union
import numpy as np
from scipy import stats
from probnum import randprocs, randvars, utils
from probnum.filtsmooth import _timeseriesposterior
from probnum.filtsmooth.gaussian import approx
from probnum.typing import ArrayLike, FloatLike, IntLike, ShapeLike
# Type alias used in this module to annotate the ``transition`` argument of
# the posterior constructors. It is the union of the discrete-time and
# continuous-time transition classes listed below.
GaussMarkovPriorTransitionArgType = Union[
randprocs.markov.discrete.LinearGaussian,
approx.DiscreteEKFComponent,
approx.DiscreteUKFComponent,
randprocs.markov.continuous.LinearSDE,
approx.ContinuousEKFComponent,
]
"""Any linear(ized) transition can define an (approximate) Gauss-Markov prior."""
class KalmanPosterior(_timeseriesposterior.TimeSeriesPosterior, abc.ABC):
    """Base class for posteriors obtained by approximate Gaussian filtering/smoothing.

    Parameters
    ----------
    transition :
        Dynamics model used as a prior for the filter.
    locations :
        Locations / Times of the discrete-time estimates.
    states :
        Estimated states (in the state-space model view) of the discrete-time estimates.
    diffusion_model :
        Optional diffusion model. Whether one was passed is recorded in
        ``diffusion_model_has_been_provided``.
    """

    def __init__(
        self,
        transition: GaussMarkovPriorTransitionArgType,
        locations: Optional[Iterable[FloatLike]] = None,
        states: Optional[Iterable[randvars.RandomVariable]] = None,
        diffusion_model=None,
    ) -> None:
        super().__init__(locations=locations, states=states)
        self.transition = transition
        self.diffusion_model = diffusion_model
        # Remember explicitly whether the caller supplied a diffusion model,
        # so subclasses can fall back to a unit diffusion otherwise.
        self.diffusion_model_has_been_provided = diffusion_model is not None

    @abc.abstractmethod
    def interpolate(
        self,
        t: FloatLike,
        previous_index: Optional[IntLike] = None,
        next_index: Optional[IntLike] = None,
    ) -> randvars.RandomVariable:
        """Evaluate the posterior at a single time ``t``.

        ``previous_index`` and ``next_index`` are indices into the stored
        grid. Subclasses define how they are used.
        """
        raise NotImplementedError

    def sample(
        self,
        rng: np.random.Generator,
        t: Optional[ArrayLike] = None,
        size: Optional[ShapeLike] = (),
    ) -> np.ndarray:
        """Draw samples from the posterior.

        Parameters
        ----------
        rng :
            Random number generator used to draw the standard-Normal base
            measure realizations.
        t :
            Query locations. If ``None``, samples are drawn on ``self.locations``.
            Otherwise samples are drawn jointly on the union of ``t`` and
            ``self.locations`` and only the entries belonging to ``t`` are
            returned (ordered like the sorted, de-duplicated union).
        size :
            Leading shape of the returned batch of samples.

        Returns
        -------
        np.ndarray
            Samples from the posterior at the requested locations.
        """
        size = utils.as_shape(size)
        first_state = self.states[0]
        rv_shape, rv_ndim = first_state.shape, first_state.ndim

        # No dense output requested: the grid itself is the sampling domain
        # and nothing needs to be filtered out afterwards.
        if t is None:
            grid = self.locations
            noise = stats.norm.rvs(
                size=(size + grid.shape + rv_shape),
                random_state=rng,
            )
            return self.transform_base_measure_realizations(
                base_measure_realizations=noise, t=grid
            )

        # Sample on the set-union of t and the grid, so that samples
        # "always pass" the grid points, then keep only the entries at t.
        merged_locations = np.union1d(t, self.locations)
        (requested_positions,) = np.nonzero(np.isin(merged_locations, t))

        noise = stats.norm.rvs(
            size=(size + merged_locations.shape + rv_shape),
            random_state=rng,
        )
        joint_samples = self.transform_base_measure_realizations(
            base_measure_realizations=noise, t=merged_locations
        )

        # The location axis sits directly in front of the random variable's own axes.
        location_axis = -(rv_ndim + 1)
        return np.take(joint_samples, indices=requested_positions, axis=location_axis)

    @abc.abstractmethod
    def transform_base_measure_realizations(
        self,
        base_measure_realizations: np.ndarray,
        t: ArrayLike,
    ) -> np.ndarray:
        """Transform samples from a base measure to samples from the KalmanPosterior.

        Here, the base measure is a multivariate standard Normal distribution.

        Parameters
        ----------
        base_measure_realizations :
            **Shape (*size, N, d).**
            Samples from a multivariate standard Normal distribution.
            `N` is the number of time points in `t`.
        t :
            **Shape (N,).**
            Time points. Must include `self.locations`.

        Returns
        -------
        np.ndarray
            **Shape (*size, N, d)**
            Transformed base measure realizations. If the inputs are samples
            from a multivariate standard Normal distribution, the results are
            `size` samples from the Kalman posterior at prescribed locations.
        """
        raise NotImplementedError
class SmoothingPosterior(KalmanPosterior):
    """Smoothing posterior.

    Parameters
    ----------
    filtering_posterior :
        Filtering posterior.
    transition :
        Dynamics model used as a prior for the filter.
    locations :
        Locations / Times of the discrete-time estimates.
    states :
        Estimated states (in the state-space model view) of the discrete-time estimates.
    diffusion_model :
        Optional diffusion model; forwarded to the parent constructor.
    """

    def __init__(
        self,
        filtering_posterior: _timeseriesposterior.TimeSeriesPosterior,
        transition: GaussMarkovPriorTransitionArgType,
        locations: Iterable[FloatLike],
        states: Iterable[randvars.RandomVariable],
        diffusion_model=None,
    ):
        self.filtering_posterior = filtering_posterior
        super().__init__(
            transition=transition,
            locations=locations,
            states=states,
            diffusion_model=diffusion_model,
        )

    def interpolate(
        self,
        t: FloatLike,
        previous_index: Optional[IntLike] = None,
        next_index: Optional[IntLike] = None,
    ) -> randvars.RandomVariable:
        """Evaluate the smoothing posterior at a single time ``t``.

        Parameters
        ----------
        t :
            Time at which to evaluate.
        previous_index :
            Index of the grid point used as the left reference, or ``None``.
        next_index :
            Index of the grid point used as the right reference, or ``None``.

        Returns
        -------
        randvars.RandomVariable
            The stored state if ``t`` coincides with one of the reference
            grid points; a forward prediction from the left reference if only
            ``previous_index`` is given; otherwise a forward prediction followed
            by a backward (smoothing) update against the right reference.

        Raises
        ------
        ValueError
            If both indices are ``None``, or if the two indices are more than
            one apart.
        NotImplementedError
            If only ``next_index`` is given and ``t`` is not that grid point
            (extrapolation to the left).
        """
        # Without any reference point there is nothing to interpolate from.
        if previous_index is None and next_index is None:
            raise ValueError(
                "At least one of previous_index and next_index must be provided."
            )

        previous_location = (
            self.locations[previous_index] if previous_index is not None else None
        )
        next_location = self.locations[next_index] if next_index is not None else None
        previous_state = (
            self.states[previous_index] if previous_index is not None else None
        )
        next_state = self.states[next_index] if next_index is not None else None

        # Corner case 1: point is on grid. In this case, don't compute anything.
        if t == previous_location:
            return previous_state
        if t == next_location:
            return next_state

        # This block avoids calling self.diffusion_model, because we do not want
        # to search the full index set -- we already know the index!
        # This is the reason that `Diffusion` objects implement a __getitem__.
        # The usual diffusion-index is the next index
        # ('Diffusion's include the right-hand side gridpoint!),
        # but if we are right of the domain, the previous_index matters.
        diffusion_index = next_index if next_index is not None else previous_index
        if diffusion_index >= len(self.locations) - 1:
            diffusion_index = -1
        if self.diffusion_model_has_been_provided:
            squared_diffusion = self.diffusion_model[diffusion_index]
        else:
            squared_diffusion = 1.0

        # Corner case 2: we are extrapolating to the left.
        # Forward and backward transitions cannot handle negative time
        # increments reliably, so this is deliberately unsupported.
        if previous_location is None:
            raise NotImplementedError("Extrapolation to the left is not implemented.")

        # Corner case 3: we are extrapolating to the right
        if next_location is None:
            dt = t - previous_location
            assert dt > 0.0
            extrapolated_rv_right, _ = self.transition.forward_rv(
                previous_state, t=previous_location, dt=dt, _diffusion=squared_diffusion
            )
            return extrapolated_rv_right

        # Final case: we are interpolating. Both locations are not None.
        # In this case, filter from the left to the middle point;
        # and compute a smoothing update from the middle to the RHS point.
        if np.abs(previous_index - next_index) > 1.1:
            raise ValueError(
                "previous_index and next_index must refer to neighbouring grid points."
            )

        dt_left = t - previous_location
        dt_right = next_location - t
        assert dt_left > 0.0
        assert dt_right > 0.0
        filtered_rv, _ = self.transition.forward_rv(
            rv=previous_state,
            t=previous_location,
            dt=dt_left,
            _diffusion=squared_diffusion,
        )
        smoothed_rv, _ = self.transition.backward_rv(
            rv_obtained=next_state,
            rv=filtered_rv,
            t=t,
            dt=dt_right,
            _diffusion=squared_diffusion,
        )
        return smoothed_rv

    def transform_base_measure_realizations(
        self,
        base_measure_realizations: np.ndarray,
        t: Optional[ArrayLike] = None,
    ) -> np.ndarray:
        """Transform standard-Normal realizations into joint posterior samples.

        Parameters
        ----------
        base_measure_realizations :
            Array whose trailing shape is ``t.shape + self.states[0].shape``.
            Any additional leading axes are treated as a batch and processed
            one slice at a time.
        t :
            Sorted time points that include all of ``self.locations``.
            Any array-like is accepted; ``None`` is treated as ``self.locations``.

        Returns
        -------
        np.ndarray
            Samples with the same shape as ``base_measure_realizations``.

        Raises
        ------
        ValueError
            If ``t`` does not contain ``self.locations`` or is not sorted.
        NotImplementedError
            If ``t`` contains points to the left of the time domain.
        """
        # Normalise `t` up front: its shape is needed immediately below, so
        # lists/tuples (and None) must be converted before any attribute access.
        t = np.asarray(self.locations if t is None else t)

        # Early exit: recursively compute multiple samples
        # if the desired sample size is not equal to '()', which is the case if
        # the shape of base_measure_realization is not (len(locations), shape(RV))
        size_zero_shape = () + t.shape + self.states[0].shape
        if base_measure_realizations.shape != size_zero_shape:
            return np.array(
                [
                    self.transform_base_measure_realizations(
                        base_measure_realizations=base_real,
                        t=t,
                    )
                    for base_real in base_measure_realizations
                ]
            )

        # Now we are in the setting of jointly sampling
        # a single realization from the posterior.
        # On time points inside the domain,
        # this is essentially a sequence of smoothing steps.
        if not np.all(np.isin(self.locations, t)):
            raise ValueError(
                "Base measure realizations cannot be transformed "
                "if the locations don't include self.locations."
            )
        if not np.all(np.diff(t) >= 0.0):
            raise ValueError("Time-points have to be sorted.")

        # Find locations of the diffusions, which amounts to finding the locations
        # of the grid points in t (think: `all_locations`),
        # which is done via np.searchsorted:
        # NOTE(review): `diffusion_indices` has len(t) - 1 entries, while the
        # fallback below has len(t); confirm the transition only consumes the
        # entries it needs.
        diffusion_indices = np.searchsorted(self.locations[:-2], t[1:])
        if self.diffusion_model_has_been_provided:
            squared_diffusion_list = self.diffusion_model[diffusion_indices]
        else:
            squared_diffusion_list = np.ones_like(t)

        # Split into interpolation and extrapolation samples.
        # For extrapolation, samples are propagated forwards.
        # Due to this distinction, we need to treat both cases differently.
        # Note: t=tmax is in two arrays!
        # This is on purpose, because sample realisations need to be
        # "communicated" between interpolation and extrapolation.
        t0, tmax = np.amin(self.locations), np.amax(self.locations)
        t_extra_left = t[t < t0]
        t_extra_right = t[tmax <= t]
        t_inter = t[(t0 <= t) & (t <= tmax)]

        if len(t_extra_left) > 0:
            raise NotImplementedError(
                "Sampling on the left of the time-domain is not implemented."
            )

        # Split base measure realisations (which have, say, length N + M - 1):
        # the first N realizations belong to the interpolation samples,
        # and the final M realizations belong to the extrapolation samples.
        # Note again: the sample corresponding to tmax belongs to both groups.
        base_measure_reals_inter = base_measure_realizations[: len(t_inter)]
        base_measure_reals_extra_right = base_measure_realizations[
            -len(t_extra_right) :
        ]

        squared_diffusion_list_inter = squared_diffusion_list[: len(t_inter)]
        squared_diffusion_list_extra_right = squared_diffusion_list[
            -len(t_extra_right) :
        ]

        states = self.filtering_posterior(t)
        states_inter = states[: len(t_inter)]
        states_extra_right = states[-len(t_extra_right) :]

        samples_inter = np.array(
            self.transition.jointly_transform_base_measure_realization_list_backward(
                base_measure_realizations=base_measure_reals_inter,
                t=t_inter,
                rv_list=states_inter,
                _diffusion_list=squared_diffusion_list_inter,
            )
        )
        samples_extra = np.array(
            self.transition.jointly_transform_base_measure_realization_list_forward(
                base_measure_realizations=base_measure_reals_extra_right,
                t=t_extra_right,
                initrv=states_extra_right[0],
                _diffusion_list=squared_diffusion_list_extra_right,
            )
        )
        # Drop the duplicated tmax sample from the interpolation part before joining.
        samples = np.concatenate((samples_inter[:-1], samples_extra), axis=0)
        return samples

    @property
    def _states_left_of_location(self):
        # Delegates to the underlying filtering posterior.
        return self.filtering_posterior._states_left_of_location
class FilteringPosterior(KalmanPosterior):
    """Posterior obtained from filtering only (no smoothing pass)."""

    def interpolate(
        self,
        t: FloatLike,
        previous_index: Optional[IntLike] = None,
        next_index: Optional[IntLike] = None,
    ) -> randvars.RandomVariable:
        """Evaluate the filtering posterior at a single time ``t``.

        If ``t`` coincides with one of the referenced grid points, the stored
        state is returned. Otherwise the state at ``previous_index`` is
        propagated forward to ``t``.

        Raises
        ------
        ValueError
            If neither ``previous_index`` nor ``next_index`` is given.
        NotImplementedError
            If ``previous_index`` is ``None`` and ``t`` is not the grid point
            at ``next_index`` (extrapolation to the left).
        """
        has_left = previous_index is not None
        has_right = next_index is not None

        # At least one reference point is required.
        if not (has_left or has_right):
            raise ValueError

        t_left = self.locations[previous_index] if has_left else None
        t_right = self.locations[next_index] if has_right else None
        rv_left = self.states[previous_index] if has_left else None
        rv_right = self.states[next_index] if has_right else None

        # On-grid queries are answered directly from the stored states.
        if t == t_left:
            return rv_left
        if t == t_right:
            return rv_right

        # Forward and backward transitions cannot handle negative time
        # increments reliably, so going left of the reference is unsupported.
        if t_left is None:
            raise NotImplementedError("Extrapolation to the left is not implemented.")

        # From here on the left reference exists: the filtering posterior
        # interpolates by extrapolating forward from that point.
        if self.diffusion_model_has_been_provided:
            last_index = len(self.locations) - 1
            diffusion_index = -1 if previous_index >= last_index else previous_index
            diffusion = self.diffusion_model[diffusion_index]
        else:
            diffusion = 1.0

        time_step = t - t_left
        assert time_step > 0.0
        predicted_rv, _info = self.transition.forward_rv(
            rv=rv_left, t=t_left, dt=time_step, _diffusion=diffusion
        )
        return predicted_rv

    def sample(
        self,
        rng: np.random.Generator,
        t: Optional[ArrayLike] = None,
        size: Optional[ShapeLike] = (),
    ) -> np.ndarray:
        """Not supported; always raises ``NotImplementedError``.

        Raising here gives a clearer error than letting the call fall through
        to ``transform_base_measure_realizations``, which is not implemented
        for this class either.
        """
        raise NotImplementedError(
            "Sampling from the FilteringPosterior is not implemented."
        )

    def transform_base_measure_realizations(
        self,
        base_measure_realizations: np.ndarray,
        t: Optional[ArrayLike] = None,
    ) -> np.ndarray:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "Transforming base measure realizations is not implemented."
        )