Linear Gaussian filtering and smoothing (discrete)
This notebook provides an example of a discrete, linear state-space model on which one can perform Bayesian filtering and smoothing to obtain a posterior distribution over a latent state trajectory given noisy observations. For a detailed treatment of the theory behind these methods, we refer to [1] and [2].
[1]:
import numpy as np
import probnum as pn
from probnum import filtsmooth, randvars, randprocs
from probnum.problems import TimeSeriesRegressionProblem
[2]:
rng = np.random.default_rng(seed=123)
[3]:
# Make inline plots vector graphics instead of raster graphics
%matplotlib inline
from matplotlib_inline.backend_inline import set_matplotlib_formats
set_matplotlib_formats("pdf", "svg")
# Plotting
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
plt.style.use("../../probnum.mplstyle")
Linear Discrete State-Space Model: Car Tracking
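We showcase arguably the simplest case, the following linear Gaussian state-space model. Consider matrices \(A \in \mathbb{R}^{d \times d}\) and \(H \in \mathbb{R}^{m \times d}\), where \(d\) is the state dimension and \(m\) is the dimension of the measurements, together with process and measurement noise covariances \(Q \in \mathbb{R}^{d \times d}\) and \(R \in \mathbb{R}^{m \times m}\). Then we define the dynamics and the measurement model as follows.
For \(k = 1, \dots, K\) and \(\boldsymbol{x}_0 \sim \mathcal{N}(\mu_0, \Sigma_0)\):
\[
\begin{aligned}
\boldsymbol{x}_k \mid \boldsymbol{x}_{k-1} &\sim \mathcal{N}(A \boldsymbol{x}_{k-1}, Q), \\
\boldsymbol{y}_k \mid \boldsymbol{x}_k &\sim \mathcal{N}(H \boldsymbol{x}_k, R).
\end{aligned}
\]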
In other words, every relationship is linear and every distribution is Gaussian. Under these simplifying assumptions, the filtering posterior over the state trajectory \((\boldsymbol{x}_k)_{k=1}^{K}\) can be computed in closed form with a Kalman filter. The example is taken from Example 3.6 in [2].
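To make the Kalman filter concrete, here is a minimal NumPy sketch of one predict/update cycle for a model of this form, where \(Q\) and \(R\) are the process and measurement noise covariances (called noise_matrix and measurement_noise_matrix in the code cells below). The function name kalman_step and its signature are hypothetical and chosen for illustration. This is the textbook recursion, not probnum's implementation.

```python
import numpy as np


def kalman_step(m, P, y, A, Q, H, R):
    """One predict/update cycle of the Kalman filter (hypothetical helper, illustrative sketch)."""
    # Predict: propagate the previous filtering mean/covariance through the linear dynamics
    m_pred = A @ m
    P_pred = A @ P @ A.T + Q
    # Update: condition the predicted Gaussian on the measurement y
    S = H @ P_pred @ H.T + R  # innovation covariance
    K = np.linalg.solve(S, H @ P_pred).T  # Kalman gain P_pred H^T S^{-1}
    m_filt = m_pred + K @ (y - H @ m_pred)
    P_filt = P_pred - K @ S @ K.T
    return m_filt, P_filt


# Scalar example: prior N(0, 1), no process noise, unit measurement noise, measurement y = 2
m, P = kalman_step(
    m=np.zeros(1), P=np.eye(1), y=np.array([2.0]),
    A=np.eye(1), Q=np.zeros((1, 1)), H=np.eye(1), R=np.eye(1),
)
print(m, P)  # m = [1.], P = [[0.5]]
```

With equal prior and measurement variance, the posterior mean lands halfway between the prior mean and the measurement, and the variance is halved.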
Define State-Space Model
I. Discrete Dynamics Model: Linear, Time-Invariant, Gaussian Transitions
[4]:
state_dim = 4
observation_dim = 2
[5]:
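delta_t = 0.2
# Define linear transition operator.
# State ordering: (x_1, x_2, x_3, x_4) = (position 1, position 2, velocity 1, velocity 2),
# so each position is advanced by delta_t times its velocity.
dynamics_transition_matrix = np.eye(state_dim) + delta_t * np.diag(np.ones(2), 2)
# Define process noise (covariance) matrix: a discretized white-noise-acceleration
# (Wiener velocity) model with unit spectral density for each coordinate
noise_matrix = (
    np.diag(np.array([delta_t ** 3 / 3, delta_t ** 3 / 3, delta_t, delta_t]))
    + np.diag(np.array([delta_t ** 2 / 2, delta_t ** 2 / 2]), 2)
    + np.diag(np.array([delta_t ** 2 / 2, delta_t ** 2 / 2]), -2)
)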
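Both matrices have a block structure. With the state ordered as (position 1, position 2, velocity 1, velocity 2), each coordinate follows its own one-dimensional constant-velocity model. As a cross-check that is our own addition and not part of the original example, the sketch below rebuilds both matrices as Kronecker products of 2×2 blocks with the identity. The names transition_1d and noise_1d are hypothetical.

```python
import numpy as np

delta_t = 0.2
state_dim = 4

# Hypothetical 1D constant-velocity blocks acting on (position, velocity) of one coordinate
transition_1d = np.array([[1.0, delta_t], [0.0, 1.0]])
noise_1d = np.array([[delta_t ** 3 / 3, delta_t ** 2 / 2], [delta_t ** 2 / 2, delta_t]])

# Kronecker product with I_2 yields the ordering (x_1, x_2, x_3, x_4) used in the notebook
transition_kron = np.kron(transition_1d, np.eye(2))
noise_kron = np.kron(noise_1d, np.eye(2))

# The matrices exactly as constructed in the notebook cell
dynamics_transition_matrix = np.eye(state_dim) + delta_t * np.diag(np.ones(2), 2)
noise_matrix = (
    np.diag(np.array([delta_t ** 3 / 3, delta_t ** 3 / 3, delta_t, delta_t]))
    + np.diag(np.array([delta_t ** 2 / 2, delta_t ** 2 / 2]), 2)
    + np.diag(np.array([delta_t ** 2 / 2, delta_t ** 2 / 2]), -2)
)

print(np.allclose(transition_kron, dynamics_transition_matrix))  # True
print(np.allclose(noise_kron, noise_matrix))  # True

# A valid covariance must be positive definite: Cholesky raises LinAlgError otherwise
np.linalg.cholesky(noise_matrix)
```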
To create a discrete, linear time-invariant (LTI) Gaussian dynamics model, probnum provides the LTIGaussian class.
[6]:
# Create discrete, Linear Time-Invariant Gaussian dynamics model
noise = randvars.Normal(mean=np.zeros(state_dim), cov=noise_matrix)
dynamics_model = randprocs.markov.discrete.LTIGaussian(
    transition_matrix=dynamics_transition_matrix,
    noise=noise,
)
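Conceptually, the LTIGaussian model above describes the transition \(\boldsymbol{x}_k = A \boldsymbol{x}_{k-1} + \boldsymbol{w}_k\), where \(A\) is dynamics_transition_matrix and \(\boldsymbol{w}_k \sim \mathcal{N}(0, Q)\) with \(Q\) equal to noise_matrix. A Gaussian belief \(\mathcal{N}(m, P)\) is therefore mapped to \(\mathcal{N}(A m, A P A^\top + Q)\). The Monte Carlo sketch below checks this identity in plain NumPy for the car-tracking matrices. It does not call probnum, and the belief \((m, P)\), sample size, and seed are arbitrary choices.

```python
import numpy as np

delta_t = 0.2
A = np.eye(4) + delta_t * np.diag(np.ones(2), 2)
Q = (
    np.diag(np.array([delta_t ** 3 / 3, delta_t ** 3 / 3, delta_t, delta_t]))
    + np.diag(np.array([delta_t ** 2 / 2, delta_t ** 2 / 2]), 2)
    + np.diag(np.array([delta_t ** 2 / 2, delta_t ** 2 / 2]), -2)
)

# Arbitrary Gaussian belief over the previous state (illustrative values)
m = np.array([0.0, 0.0, 1.0, -1.0])
P = 0.25 * np.eye(4)

rng = np.random.default_rng(seed=0)
num_samples = 200_000
x_prev = rng.multivariate_normal(m, P, size=num_samples)  # (num_samples, 4)
w = rng.multivariate_normal(np.zeros(4), Q, size=num_samples)  # process noise samples
x_next = x_prev @ A.T + w  # row-wise x_k = A x_{k-1} + w_k

# Closed-form Gaussian prediction vs. Monte Carlo estimate
predicted_mean = A @ m
predicted_cov = A @ P @ A.T + Q
empirical_mean = x_next.mean(axis=0)
empirical_cov = np.cov(x_next, rowvar=False)
print(np.abs(empirical_mean - predicted_mean).max(), np.abs(empirical_cov - predicted_cov).max())
```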
II. Discrete Measurement Model: Linear, Time-Invariant, Gaussian Measurements
[7]:
measurement_marginal_variance = 0.5
measurement_matrix = np.eye(observation_dim, state_dim)
measurement_noise_matrix = measurement_marginal_variance * np.eye(observation_dim)
[8]:
noise = randvars.Normal(mean=np.zeros(observation_dim), cov=measurement_noise_matrix)
measurement_model = randprocs.markov.discrete.LTIGaussian(
    transition_matrix=measurement_matrix,
    noise=noise,
)
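The measurement model is again an LTIGaussian, now with measurement matrix np.eye(2, 4), which has the block form \([I_2, 0]\). Each observation is therefore a noisy reading of the two positions only, and the velocities are never observed directly. The sketch below illustrates this, together with the predictive distribution of a measurement under a Gaussian state belief. The example state values are made up for illustration.

```python
import numpy as np

state_dim, observation_dim = 4, 2
measurement_matrix = np.eye(observation_dim, state_dim)  # [[1, 0, 0, 0], [0, 1, 0, 0]]
measurement_noise_matrix = 0.5 * np.eye(observation_dim)

# Hypothetical state: positions (1.0, -2.0), velocities (0.3, 0.7)
x = np.array([1.0, -2.0, 0.3, 0.7])
noise_free_measurement = measurement_matrix @ x  # picks out the two positions

# Predictive distribution of y = H x + v for a Gaussian belief N(m, P) over x
m = np.zeros(state_dim)
P = 0.25 * np.eye(state_dim)
y_mean = measurement_matrix @ m
y_cov = measurement_matrix @ P @ measurement_matrix.T + measurement_noise_matrix
print(noise_free_measurement, y_mean, y_cov)
```

The predictive covariance \(H P H^\top + R\) is the same quantity that serves as the innovation covariance in the Kalman update.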
III. Initial State Random Variable
[9]:
mu_0 = np.zeros(state_dim)
sigma_0 = 0.5 * measurement_marginal_variance * np.eye(state_dim)
initial_state_rv = randvars.Normal(mean=mu_0, cov=sigma_0)
[10]:
prior_process = randprocs.markov.MarkovSequence(
    transition=dynamics_model, initrv=initial_state_rv, initarg=0.0
)
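MarkovSequence combines the dynamics model with the initial random variable, attached to the time initarg=0.0, into a prior over the whole state sequence. Before any data are incorporated, the prior marginals follow, conceptually, from repeatedly applying the Gaussian prediction \(m \mapsto A m\), \(P \mapsto A P A^\top + Q\). The NumPy sketch below, which is not probnum code, does this for 50 steps. It shows that the prior velocity variance grows linearly by delta_t per step, because each velocity follows a random walk.

```python
import numpy as np

delta_t, num_steps, state_dim = 0.2, 50, 4
A = np.eye(state_dim) + delta_t * np.diag(np.ones(2), 2)
Q = (
    np.diag(np.array([delta_t ** 3 / 3, delta_t ** 3 / 3, delta_t, delta_t]))
    + np.diag(np.array([delta_t ** 2 / 2, delta_t ** 2 / 2]), 2)
    + np.diag(np.array([delta_t ** 2 / 2, delta_t ** 2 / 2]), -2)
)

# Initial state as in the notebook: mu_0 = 0, Sigma_0 = 0.5 * 0.5 * I
m = np.zeros(state_dim)
P = 0.25 * np.eye(state_dim)

# Propagate the prior moments without any measurements
prior_means, prior_covs = [m], [P]
for _ in range(num_steps - 1):
    m = A @ m
    P = A @ P @ A.T + Q
    prior_means.append(m)
    prior_covs.append(P)
prior_means = np.stack(prior_means)  # (num_steps, state_dim)
prior_covs = np.stack(prior_covs)  # (num_steps, state_dim, state_dim)

# Velocity variance is 0.25 + k * delta_t; position variance grows faster
velocity_var = prior_covs[:, 2, 2]
position_var = prior_covs[:, 0, 0]
print(velocity_var[:3], position_var[:3])
```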
Generate Data for the State-Space Model
Next, sample both latent states and noisy observations from the specified state-space model.
[11]:
time_grid = np.arange(0.0, 10.0, step=delta_t)
[12]:
latent_states, observations = randprocs.markov.utils.generate_artificial_measurements(
    rng=rng,
    prior_process=prior_process,
    measmod=measurement_model,
    times=time_grid,
)
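generate_artificial_measurements draws one latent trajectory from prior_process and one noisy observation per grid point from measurement_model. Conceptually, this is ancestral sampling. The hypothetical helper simulate_linear_gaussian_ssm below spells that out in NumPy. It is a sketch of the idea, not probnum's implementation, and it will not reproduce probnum's draws for the same seed.

```python
import numpy as np


def simulate_linear_gaussian_ssm(rng, A, Q, H, R, m0, P0, num_steps):
    """Ancestral sampling of latent states and measurements (hypothetical helper, sketch)."""
    state_dim, observation_dim = A.shape[0], H.shape[0]
    states = np.empty((num_steps, state_dim))
    measurements = np.empty((num_steps, observation_dim))
    x = rng.multivariate_normal(m0, P0)  # state at the first grid point
    for k in range(num_steps):
        if k > 0:
            # x_k = A x_{k-1} + w_k,  w_k ~ N(0, Q)
            x = A @ x + rng.multivariate_normal(np.zeros(state_dim), Q)
        states[k] = x
        # y_k = H x_k + v_k,  v_k ~ N(0, R)
        measurements[k] = H @ x + rng.multivariate_normal(np.zeros(observation_dim), R)
    return states, measurements


delta_t = 0.2
A = np.eye(4) + delta_t * np.diag(np.ones(2), 2)
Q = (
    np.diag(np.array([delta_t ** 3 / 3, delta_t ** 3 / 3, delta_t, delta_t]))
    + np.diag(np.array([delta_t ** 2 / 2, delta_t ** 2 / 2]), 2)
    + np.diag(np.array([delta_t ** 2 / 2, delta_t ** 2 / 2]), -2)
)
H = np.eye(2, 4)
R = 0.5 * np.eye(2)

rng = np.random.default_rng(seed=123)
states, measurements = simulate_linear_gaussian_ssm(
    rng, A, Q, H, R, m0=np.zeros(4), P0=0.25 * np.eye(4), num_steps=50
)
print(states.shape, measurements.shape)  # (50, 4) (50, 2)
```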
[13]:
regression_problem = TimeSeriesRegressionProblem(
    observations=observations,
    locations=time_grid,
    measurement_models=[measurement_model] * len(time_grid),
)
Kalman Filtering
I. Kalman Filter
[14]:
kalman_filter = filtsmooth.gaussian.Kalman(prior_process)
II. Perform Kalman Filtering + Rauch-Tung-Striebel Smoothing
[15]:
state_posterior, _ = kalman_filter.filtsmooth(regression_problem)
The method filtsmooth returns a KalmanPosterior object (its second return value is discarded above), which provides convenience functions for, e.g., sampling and interpolation. We can also extract the smoothed posterior state random variables that were just computed. They form a list of Gaussian random variables whose means and covariances we extract for visualization.
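filtsmooth runs a Kalman filter forward in time and then a Rauch-Tung-Striebel (RTS) smoother backward, so each smoothed state is conditioned on all measurements rather than only on past ones. Below is a minimal NumPy sketch of both passes. The name kalman_filter_rts_smoother is hypothetical, and letting the first measurement update the initial belief directly (with no preceding prediction) is an assumption of this sketch, not a statement about probnum's internals.

```python
import numpy as np


def kalman_filter_rts_smoother(measurements, A, Q, H, R, m0, P0):
    """Forward Kalman filter, then backward RTS pass (hypothetical helper, sketch)."""
    num_steps, state_dim = len(measurements), len(m0)
    filt_means = np.empty((num_steps, state_dim))
    filt_covs = np.empty((num_steps, state_dim, state_dim))
    m_pred, P_pred = m0, P0  # belief at the first grid point before its measurement
    for k, y in enumerate(measurements):
        if k > 0:
            m_pred = A @ filt_means[k - 1]
            P_pred = A @ filt_covs[k - 1] @ A.T + Q
        S = H @ P_pred @ H.T + R
        K = np.linalg.solve(S, H @ P_pred).T  # Kalman gain P_pred H^T S^{-1}
        filt_means[k] = m_pred + K @ (y - H @ m_pred)
        filt_covs[k] = P_pred - K @ S @ K.T

    # Backward pass: start from the last filtering estimate and move back in time
    smooth_means, smooth_covs = filt_means.copy(), filt_covs.copy()
    for k in range(num_steps - 2, -1, -1):
        m_pred = A @ filt_means[k]
        P_pred = A @ filt_covs[k] @ A.T + Q
        G = np.linalg.solve(P_pred, A @ filt_covs[k]).T  # smoother gain P_k A^T P_pred^{-1}
        smooth_means[k] = filt_means[k] + G @ (smooth_means[k + 1] - m_pred)
        smooth_covs[k] = filt_covs[k] + G @ (smooth_covs[k + 1] - P_pred) @ G.T
    return filt_means, filt_covs, smooth_means, smooth_covs


# Scalar random walk with unit process and measurement noise (illustrative data)
measurements = np.array([[0.0], [1.0], [3.0], [2.0]])
eye = np.eye(1)
filt_means, filt_covs, smooth_means, smooth_covs = kalman_filter_rts_smoother(
    measurements, A=eye, Q=eye, H=eye, R=eye, m0=np.zeros(1), P0=eye
)
print(filt_covs[:, 0, 0], smooth_covs[:, 0, 0])
```

At the final grid point the smoothed and filtered moments coincide; earlier on, the smoothed variances never exceed the filtered ones, because later measurements add information.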
[16]:
grid = state_posterior.locations
posterior_state_rvs = state_posterior.states  # List of <num_time_points> Normal random variables
posterior_state_means = posterior_state_rvs.mean  # Shape: (num_time_points, state_dim)
posterior_state_covs = posterior_state_rvs.cov  # Shape: (num_time_points, state_dim, state_dim)
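The marginal standard deviations used for plotting are the square roots of the covariance diagonals. For a stack of covariances of shape (num_time_points, state_dim, state_dim), np.diagonal extracts all of them at once. The factor 1.96 is the standard normal quantile that makes mean ± 1.96 std a 95% interval. The sketch uses random placeholder covariances rather than the notebook's posterior.

```python
import math

import numpy as np

rng = np.random.default_rng(seed=0)
num_time_points, state_dim = 5, 4

# Placeholder stack of covariance matrices (L L^T + I is symmetric positive definite)
L = rng.standard_normal((num_time_points, state_dim, state_dim))
covs = L @ np.transpose(L, (0, 2, 1)) + np.eye(state_dim)

# All marginal standard deviations at once: shape (num_time_points, state_dim)
stds = np.sqrt(np.diagonal(covs, axis1=1, axis2=2))

# Equivalent to the per-index list comprehension used in the plotting cell
stds_loop = np.stack([np.sqrt(covs[:, i, i]) for i in range(state_dim)], axis=1)

# Probability mass of a standard normal within +-1.96: erf(1.96 / sqrt(2))
coverage = math.erf(1.96 / math.sqrt(2.0))
print(stds.shape, round(coverage, 4))  # (5, 4) 0.95
```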
Visualize Results
[17]:
state_fig = plt.figure()
state_fig_gs = gridspec.GridSpec(ncols=2, nrows=2, figure=state_fig)
ax_00 = state_fig.add_subplot(state_fig_gs[0, 0])
ax_01 = state_fig.add_subplot(state_fig_gs[0, 1])
ax_10 = state_fig.add_subplot(state_fig_gs[1, 0])
ax_11 = state_fig.add_subplot(state_fig_gs[1, 1])
# Plot means
mu_x_1, mu_x_2, mu_x_3, mu_x_4 = [posterior_state_means[:, i] for i in range(state_dim)]
ax_00.plot(grid, mu_x_1, label="posterior mean")
ax_01.plot(grid, mu_x_2)
ax_10.plot(grid, mu_x_3)
ax_11.plot(grid, mu_x_4)
# Plot marginal standard deviations
std_x_1, std_x_2, std_x_3, std_x_4 = [
    np.sqrt(posterior_state_covs[:, i, i]) for i in range(state_dim)
]
ax_00.fill_between(
    grid,
    mu_x_1 - 1.96 * std_x_1,
    mu_x_1 + 1.96 * std_x_1,
    alpha=0.2,
    label="95% interval (±1.96 marginal std)",
)
ax_01.fill_between(grid, mu_x_2 - 1.96 * std_x_2, mu_x_2 + 1.96 * std_x_2, alpha=0.2)
ax_10.fill_between(grid, mu_x_3 - 1.96 * std_x_3, mu_x_3 + 1.96 * std_x_3, alpha=0.2)
ax_11.fill_between(grid, mu_x_4 - 1.96 * std_x_4, mu_x_4 + 1.96 * std_x_4, alpha=0.2)
# Plot the noisy measurements of the observed components x_1 and x_2 (not the latent ground truth)
obs_x_1, obs_x_2 = [observations[:, i] for i in range(observation_dim)]
ax_00.scatter(time_grid, obs_x_1, marker=".", label="measurements")
ax_01.scatter(time_grid, obs_x_2, marker=".")
# Add labels etc.
ax_00.set_xlabel("t")
ax_01.set_xlabel("t")
ax_10.set_xlabel("t")
ax_11.set_xlabel("t")
ax_00.set_title(r"$x_1$")
ax_01.set_title(r"$x_2$")
ax_10.set_title(r"$x_3$")
ax_11.set_title(r"$x_4$")
handles, labels = ax_00.get_legend_handles_labels()
state_fig.legend(handles, labels, loc="center left", bbox_to_anchor=(1, 0.5))
state_fig.tight_layout()