Skip to content
This repository has been archived by the owner on Jul 14, 2024. It is now read-only.

Commit

Permalink
Merge pull request #56 from probml/kalman-smoother-fix
Browse files Browse the repository at this point in the history
Fix to Kalman smoother and add smoothing test.
  • Loading branch information
murphyk authored May 23, 2022
2 parents ddbb938 + cc6dc87 commit e4ad565
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 32 deletions.
6 changes: 4 additions & 2 deletions jsl/lds/kalman_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def sample(self,
"""
Simulate a run of n_sample independent stochastic
linear dynamical systems
Parameters
----------
key: jax.random.PRNGKey
Expand Down Expand Up @@ -396,6 +397,7 @@ def kalman_smoother(params: LDS,
Compute the offline version of the Kalman-Filter, i.e,
the kalman smoother for the hidden state.
Note that we require to independently run the kalman_filter function first
Parameters
----------
params: LDS
Expand All @@ -422,8 +424,8 @@ def kalman_smoother(params: LDS,
smoother_step_run = partial(smoother_step, params=params)
elements = (mu_hist[-2::-1],
Sigma_hist[-2::-1, ...],
mu_cond_hist[1:][::-1, ...],
Sigma_cond_hist[1:][::-1, ...])
mu_cond_hist[::-1, ...],
Sigma_cond_hist[::-1, ...])
initial_state = (mut_giv_T, Sigmat_giv_T, 0)

_, (mu_hist_smooth, Sigma_hist_smooth) = lax.scan(smoother_step_run, initial_state, elements)
Expand Down
63 changes: 33 additions & 30 deletions jsl/lds/kalman_filter_test.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,35 @@
from jax import random
from jax import numpy as jnp
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
from jsl.lds.kalman_filter import LDS, kalman_filter

def tfp_filter(timesteps, A, transition_noise_scale, C, observation_noise_scale, mu0, x_hist):
""" Perform filtering using tensorflow probability """
state_size, _ = A.shape
observation_size, _ = C.shape
transition_noise = tfd.MultivariateNormalDiag(
scale_diag=jnp.ones(state_size) * transition_noise_scale
)
obs_noise = tfd.MultivariateNormalDiag(
scale_diag=jnp.ones(observation_size) * observation_noise_scale
)
prior = tfd.MultivariateNormalDiag(mu0, tf.ones([state_size]))
import tensorflow_probability.substrates.jax.distributions as tfd

from jsl.lds.kalman_filter import LDS, kalman_filter, kalman_smoother


LGSSM = tfd.LinearGaussianStateSpaceModel(
timesteps, A, transition_noise, C, obs_noise, prior
def lds_jsl_to_tfp(num_timesteps, lds):
"""Convert a JSL `LDS` object into a tfp `LinearGaussianStateSpaceModel`.
Args:
num_timesteps: int, number of timesteps.
lds: LDS object.
"""
dynamics_noise_dist = tfd.MultivariateNormalFullCovariance(covariance_matrix=lds.Q)
emission_noise_dist = tfd.MultivariateNormalFullCovariance(covariance_matrix=lds.R)
initial_dist = tfd.MultivariateNormalFullCovariance(lds.mu, lds.Sigma)

tfp_lgssm = tfd.LinearGaussianStateSpaceModel(
num_timesteps,
lds.A, dynamics_noise_dist,
lds.C, emission_noise_dist,
initial_dist,
)

_, filtered_means, filtered_covs, _, _, _, _ = LGSSM.forward_filter(x_hist)
return filtered_means.numpy(), filtered_covs.numpy()
return tfp_lgssm


def test_kalman_filter():
key = random.PRNGKey(314)
timesteps = 15
num_timesteps = 15
delta = 1.0

### LDS Parameters ###
Expand All @@ -43,20 +43,23 @@ def test_kalman_filter():
Q = jnp.eye(state_size) * transition_noise_scale
R = jnp.eye(observation_size) * observation_noise_scale


### Prior distribution params ###
mu0 = jnp.array([8, 10]).astype(float)
Sigma0 = jnp.eye(state_size) * 1.0

### Sample data ###
lds_instance = LDS(A, C, Q, R, mu0, Sigma0)
z_hist, x_hist = lds_instance.sample(key, timesteps)
z_hist, x_hist = lds_instance.sample(key, num_timesteps)

JSL_z_filt, JSL_Sigma_filt, _, _ = kalman_filter(lds_instance, x_hist)
tfp_z_filt, tfp_Sigma_filt = tfp_filter(
timesteps, A, transition_noise_scale, C, observation_noise_scale, mu0, x_hist
)
filter_output = kalman_filter(lds_instance, x_hist)
JSL_filtered_means, JSL_filtered_covs, *_ = filter_output
JSL_smoothed_means, JSL_smoothed_covs = kalman_smoother(lds_instance, *filter_output)

assert np.allclose(JSL_z_filt, tfp_z_filt, rtol=1e-2)
assert np.allclose(JSL_Sigma_filt, tfp_Sigma_filt, rtol=1e-2)
tfp_lgssm = lds_jsl_to_tfp(num_timesteps, lds_instance)
_, tfp_filtered_means, tfp_filtered_covs, *_ = tfp_lgssm.forward_filter(x_hist)
tfp_smoothed_means, tfp_smoothed_covs = tfp_lgssm.posterior_marginals(x_hist)

assert np.allclose(JSL_filtered_means, tfp_filtered_means, rtol=1e-2)
assert np.allclose(JSL_filtered_covs, tfp_filtered_covs, rtol=1e-2)
assert np.allclose(JSL_smoothed_means, tfp_smoothed_means, rtol=1e-2)
assert np.allclose(JSL_smoothed_covs, tfp_smoothed_covs, rtol=1e-2)

0 comments on commit e4ad565

Please sign in to comment.