
Commit

Merge branch 'main' of github.com:probml/ssm-jax
slinderman committed Nov 14, 2022
2 parents f9f8b35 + 9e39d1d commit ff4a23e
Showing 10 changed files with 114 additions and 57 deletions.
45 changes: 45 additions & 0 deletions CONTRIBUTING.md
@@ -0,0 +1,45 @@
# Contributing

Contributions (pull requests) are very welcome!

## How to contribute

First fork the library on GitHub.

Then clone and install the library in development mode:

```bash
git clone https://github.com/your-username-here/dynamax.git
cd dynamax
pip install -e '.[dev]'
```

Now make your changes. Make sure to include additional tests if necessary.

Next verify the tests all pass:

```bash
pip install pytest
pytest
```
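
Before pushing, commit your work. A minimal sketch (the branch name and commit message are just examples):

```bash
git checkout -b my-feature           # work on a feature branch rather than main
git add -A                           # stage your changes
git commit -m "Describe the change"  # commit with a descriptive message
```

(For a brand-new branch, the first push needs `git push --set-upstream origin my-feature`.)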

Then push your changes back to your fork of the repository:

```bash
git push
```

Finally, open a pull request on GitHub!
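
Alternatively, if you have the GitHub CLI installed and authenticated (optional, not required by this repository), the pull request can be opened from the terminal:

```bash
gh pr create --base main --fill   # --fill reuses your commit messages for the PR title and body
```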

## What to contribute

Please see this [list of open issues](https://github.com/probml/dynamax/issues),
especially ones tagged "help wanted".



## Contributor License Agreement

By contributing to this project, you agree to release your contributions under the MIT license.


8 changes: 4 additions & 4 deletions README.md
@@ -4,10 +4,6 @@

![Test Status](https://github.com/probml/dynamax/actions/workflows/run_tests.yml/badge.svg?branch=main)


***Note: the code is currently under active development, and the API will soon change. Please wait for the official release
on 11/14/22 before using.***

Dynamax is a library for probabilistic state space models (SSMs) written
in [JAX](https://github.com/google/jax). It has code for inference
(state estimation) and learning (parameter estimation) in a variety of
@@ -169,6 +165,10 @@ params, lls = hmm.fit_em(params, props, batch_emissions, num_iters=20)
These examples demonstrate the dynamax models, but we can also call the low-level
inference code directly.

## Contributing

Please see [this page](https://github.com/probml/dynamax/blob/main/CONTRIBUTING.md) for details
on how to contribute.

## About
Core team: Peter Chang, Giles Harper-Donnelly, Aleyna Kara, Xinglong Li, Scott Linderman, Kevin Murphy.
14 changes: 7 additions & 7 deletions dynamax/generalized_gaussian_ssm/models.py
@@ -24,9 +24,9 @@ class ParamsGGSSM(NamedTuple):
The tuple doubles as a container for the ParameterProperties.
$$p(x_t | x_{t-1}, u_t) = N(x_t | f(x_{t-1}, u_t), Q_t)$$
$$p(y_t | x_t) = q(y_t | h(x_t, u_t), R(x_t, u_t))$$
$$p(z_1) = N(x_1 | m, S)$$
$$p(z_t | z_{t-1}, u_t) = N(z_t | f(z_{t-1}, u_t), Q_t)$$
$$p(y_t | z_t) = q(y_t | h(z_t, u_t), R(z_t, u_t))$$
$$p(z_1) = N(z_1 | m, S)$$
:param initial_mean: $m$
:param initial_covariance: $S$
@@ -56,13 +56,13 @@ class GeneralizedGaussianSSM(SSM):
The model is defined as follows
$$p(x_t | x_{t-1}, u_t) = N(x_t | f(x_{t-1}, u_t), Q_t)$$
$$p(y_t | x_t) = q(y_t | h(x_t, u_t), R(x_t, u_t))$$
$$p(z_1) = N(x_1 | m, S)$$
$$p(z_t | z_{t-1}, u_t) = N(z_t | f(z_{t-1}, u_t), Q_t)$$
$$p(y_t | z_t) = q(y_t | h(z_t, u_t), R(z_t, u_t))$$
$$p(z_1) = N(z_1 | m, S)$$
where
* $x_t$ = hidden variables of size `state_dim`,
* $z_t$ = hidden variables of size `state_dim`,
* $y_t$ = observed variables of size `emission_dim`
* $u_t$ = input covariates of size `input_dim` (defaults to 0).
* $f$ = dynamics (transition) function
5 changes: 3 additions & 2 deletions dynamax/hidden_markov_model/inference.py
@@ -8,6 +8,7 @@
from typing import Callable, Optional, Tuple, Union, NamedTuple
from jaxtyping import Int, Float, Array

from dynamax.types import Scalar, PRNGKey

_get_params = lambda x, dim, t: x[t] if x.ndim == dim + 1 else x

@@ -28,7 +29,7 @@ class HMMPosteriorFiltered(NamedTuple):
:param predicted_probs: $p(z_t \mid y_{1:t-1}, \theta)$ for $t=1,\ldots,T$
"""
marginal_loglik: float
marginal_loglik: Scalar
filtered_probs: Float[Array, "num_timesteps num_states"]
predicted_probs: Float[Array, "num_timesteps num_states"]

@@ -45,7 +46,7 @@ class HMMPosterior(NamedTuple):
:param initial_probs: $p(z_1 \mid y_{1:T}, \theta)$ (also present in `smoothed_probs` but here for convenience)
:param trans_probs: $p(z_t, z_{t+1} \mid y_{1:T}, \theta)$ for $t=1,\ldots,T-1$. (If the transition matrix is fixed, these probabilities may be summed over $t$. See note above.)
"""
marginal_loglik: float
marginal_loglik: Scalar
filtered_probs: Float[Array, "num_timesteps num_states"]
predicted_probs: Float[Array, "num_timesteps num_states"]
smoothed_probs: Float[Array, "num_timesteps num_states"]
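
The change from `float` to `Scalar` lets the log likelihood be either a Python float or the 0-d JAX array that jit-compiled inference actually returns. A plausible definition of the alias (the exact contents of `dynamax/types.py` are assumed here, not shown in this diff):

```python
from typing import Union
from jaxtyping import Array, Float

# Either a plain Python float or a 0-d JAX float array.
Scalar = Union[float, Float[Array, ""]]
```
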
22 changes: 12 additions & 10 deletions dynamax/linear_gaussian_ssm/inference.py
@@ -23,7 +23,8 @@ class ParamsLGSSMInitial(NamedTuple):
"""
mean: Union[Float[Array, "state_dim"], ParameterProperties]
cov: Union[Float[Array, "state_dim state_dim"], ParameterProperties]
# unconstrained parameters are stored as a vector.
cov: Union[Float[Array, "state_dim state_dim"], Float[Array, "state_dim_triu"], ParameterProperties]


class ParamsLGSSMDynamics(NamedTuple):
@@ -39,10 +40,10 @@ class ParamsLGSSMDynamics(NamedTuple):
:param cov: dynamics covariance $Q$
"""
weights: Union[Float[Array, "state_dim state_dim"], ParameterProperties]
bias: Union[Float[Array, "state_dim"], ParameterProperties]
input_weights: Union[Float[Array, "state_dim input_dim"], ParameterProperties]
cov: Union[Float[Array, "state_dim state_dim"], ParameterProperties]
weights: Union[Float[Array, "state_dim state_dim"], Float[Array, "ntime state_dim state_dim"], ParameterProperties]
bias: Union[Float[Array, "state_dim"], Float[Array, "ntime state_dim"], ParameterProperties]
input_weights: Union[Float[Array, "state_dim input_dim"], Float[Array, "ntime state_dim input_dim"], ParameterProperties]
cov: Union[Float[Array, "state_dim state_dim"], Float[Array, "ntime state_dim state_dim"], Float[Array, "state_dim_triu"], ParameterProperties]


class ParamsLGSSMEmissions(NamedTuple):
@@ -58,10 +59,11 @@ class ParamsLGSSMEmissions(NamedTuple):
:param cov: emission covariance $R$
"""
weights: Union[Float[Array, "emission_dim state_dim"], ParameterProperties]
bias: Union[Float[Array, "emission_dim"], ParameterProperties]
input_weights: Union[Float[Array, "emission_dim input_dim"], ParameterProperties]
cov: Union[Float[Array, "emission_dim emission_dim"], ParameterProperties]
weights: Union[Float[Array, "emission_dim state_dim"], Float[Array, "ntime emission_dim state_dim"], ParameterProperties]
bias: Union[Float[Array, "emission_dim"], Float[Array, "ntime emission_dim"], ParameterProperties]
input_weights: Union[Float[Array, "emission_dim input_dim"], Float[Array, "ntime emission_dim input_dim"], ParameterProperties]
cov: Union[Float[Array, "emission_dim emission_dim"], Float[Array, "ntime emission_dim emission_dim"], Float[Array, "emission_dim_triu"], ParameterProperties]



class ParamsLGSSM(NamedTuple):
@@ -106,7 +108,7 @@ class PosteriorGSSMSmoothed(NamedTuple):
filtered_covariances: Float[Array, "ntime state_dim state_dim"]
smoothed_means: Float[Array, "ntime state_dim"]
smoothed_covariances: Float[Array, "ntime state_dim state_dim"]
smoothed_cross_covariances: Optional[Float[Array, "ntime state_dim state_dim"]] = None
smoothed_cross_covariances: Optional[Float[Array, "ntime_minus1 state_dim state_dim"]] = None


# Helper functions
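
The widened annotations in this file give each parameter an optional leading time dimension (`ntime`). A short sketch of how inference code can consume either form, reusing the `_get_params` helper that appears verbatim in `dynamax/hidden_markov_model/inference.py`:

```python
import jax.numpy as jnp

# Select the parameter for time step t: if x has a leading time axis
# (x.ndim == dim + 1), index into it; otherwise the parameter is static.
_get_params = lambda x, dim, t: x[t] if x.ndim == dim + 1 else x

state_dim, num_timesteps = 2, 100
F_static = jnp.eye(state_dim)                                    # (state_dim, state_dim)
F_varying = jnp.tile(jnp.eye(state_dim), (num_timesteps, 1, 1))  # (ntime, state_dim, state_dim)

assert _get_params(F_static, 2, 3).shape == (state_dim, state_dim)   # static: returned as-is
assert _get_params(F_varying, 2, 3).shape == (state_dim, state_dim)  # time-varying: slice at t=3
```
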
9 changes: 6 additions & 3 deletions dynamax/linear_gaussian_ssm/info_inference_test.py
@@ -118,7 +118,9 @@ class TestInfoKFLinReg:

n_obs = 21
x = jnp.linspace(0, 20, n_obs)
X = jnp.column_stack((jnp.ones_like(x), x)) # Design matrix.
X = jnp.column_stack((jnp.ones_like(x), x)) # Design matrix. (N,2)
state_dim = X.shape[1] # 2
emission_dim = 1
F = jnp.eye(2)
Q = jnp.zeros((2, 2)) # No parameter drift.
Q_prec = jnp.diag(jnp.repeat(1e32, 2)) # Can't use infinite precision.
@@ -135,11 +137,12 @@ class TestInfoKFLinReg:
-2.264 , -0.4508, 1.1672, 6.6524, 4.1452, 5.2677,
6.3403, 9.6264, 14.7842])
inputs = jnp.zeros((len(y), 1))
input_dim = inputs.shape[1]

lgssm_moment = ParamsLGSSM(
initial=ParamsLGSSMInitial(mean=mu0,cov=Sigma0),
dynamics=ParamsLGSSMDynamics(weights=F, bias=jnp.zeros(1), input_weights=jnp.zeros((mu0.shape[0], 1)), cov=Q),
emissions=ParamsLGSSMEmissions(weights=X[:, None, :], bias=jnp.zeros(1), input_weights=jnp.zeros(1), cov=R)
dynamics=ParamsLGSSMDynamics(weights=F, bias=jnp.zeros(state_dim), input_weights=jnp.zeros((state_dim, input_dim)), cov=Q),
emissions=ParamsLGSSMEmissions(weights=X[:, None, :], bias=jnp.zeros(emission_dim), input_weights=jnp.zeros((emission_dim, input_dim)), cov=R)
)

lgssm_info = ParamsLGSSMInfo(
23 changes: 13 additions & 10 deletions dynamax/linear_gaussian_ssm/models.py
@@ -7,7 +7,7 @@
from jaxtyping import Array, Float, PyTree
import tensorflow_probability.substrates.jax.distributions as tfd
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
from typing import Any, Optional, Tuple
from typing import Any, Optional, Tuple, Union
from typing_extensions import Protocol

from dynamax.ssm import SSM
@@ -33,13 +33,13 @@ class LinearGaussianSSM(SSM):
The model is defined as follows
$$p(x_1) = \mathcal{N}(x_1 \mid m, S)$$
$$p(x_t \mid x_{t-1}, u_t) = \mathcal{N}(x_t \mid F_t x_{t-1} + B_t u_t + b_t, Q_t)$$
$$p(y_t \mid x_t) = \mathcal{N}(y_t \mid H_t x_t + D_t u_t + d_t, R_t)$$
$$p(z_1) = \mathcal{N}(z_1 \mid m, S)$$
$$p(z_t \mid z_{t-1}, u_t) = \mathcal{N}(z_t \mid F_t z_{t-1} + B_t u_t + b_t, Q_t)$$
$$p(y_t \mid z_t) = \mathcal{N}(y_t \mid H_t z_t + D_t u_t + d_t, R_t)$$
where
* $x_t$ is a latent state of size `state_dim`,
* $z_t$ is a latent state of size `state_dim`,
* $y_t$ is an emission of size `emission_dim`
* $u_t$ is an input of size `input_dim` (defaults to 0)
* $F$ = dynamics (transition) matrix
@@ -265,9 +265,11 @@ def posterior_predictive(
def e_step(
self,
params: ParamsLGSSM,
emissions: Float[Array, "nseq ntime emission_dim"],
inputs: Optional[Float[Array, "nseq ntime input_dim"]]=None
) -> Tuple[SuffStatsLGSSM, float]:
emissions: Union[Float[Array, "num_timesteps emission_dim"],
Float[Array, "num_batches num_timesteps emission_dim"]],
inputs: Optional[Union[Float[Array, "num_timesteps input_dim"],
Float[Array, "num_batches num_timesteps input_dim"]]]=None,
) -> Tuple[SuffStatsLGSSM, Scalar]:
num_timesteps = emissions.shape[0]
if inputs is None:
inputs = jnp.zeros((num_timesteps, 0))
@@ -333,7 +335,7 @@ def m_step(
props: ParamsLGSSM,
batch_stats: SuffStatsLGSSM,
m_step_state: Any
) -> ParamsLGSSM:
) -> Tuple[ParamsLGSSM, Any]:

def fit_linear_regression(ExxT, ExyT, EyyT, N):
# Solve a linear regression given sufficient statistics
@@ -392,7 +394,8 @@ def __init__(self,
has_dynamics_bias=True,
has_emissions_bias=True,
**kw_priors):
super().__init__(state_dim, emission_dim, input_dim, has_dynamics_bias, has_emissions_bias)
super().__init__(state_dim=state_dim, emission_dim=emission_dim, input_dim=input_dim,
has_dynamics_bias=has_dynamics_bias, has_emissions_bias=has_emissions_bias)

# Initialize prior distributions
def default_prior(arg, default):
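
The widened `e_step` annotations accept either a single sequence or a batch of sequences. A hedged usage sketch following the README's `fit_em` pattern (shapes and hyperparameters are illustrative):

```python
import jax.random as jr
from dynamax.linear_gaussian_ssm import LinearGaussianSSM

lgssm = LinearGaussianSSM(state_dim=2, emission_dim=3)
params, props = lgssm.initialize(jr.PRNGKey(0))

# A single sequence of shape (num_timesteps, emission_dim) ...
_, emissions = lgssm.sample(params, jr.PRNGKey(1), num_timesteps=100)
params, lls = lgssm.fit_em(params, props, emissions, num_iters=10)

# ... or a batch of shape (num_batches, num_timesteps, emission_dim).
batch_emissions = emissions[None, ...]
params, lls = lgssm.fit_em(params, props, batch_emissions, num_iters=10)
```
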
12 changes: 6 additions & 6 deletions dynamax/nonlinear_gaussian_ssm/inference_ekf.py
@@ -17,8 +17,8 @@
def _predict(m, P, f, F, Q, u):
r"""Predict next mean and covariance using first-order additive EKF
p(x_{t+1}) = \int N(x_t | m, S) N(x_{t+1} | f(x_t, u), Q)
= N(x_{t+1} | f(m, u), F(m, u) S F(m, u)^T + Q)
p(z_{t+1}) = \int N(z_t | m, S) N(z_{t+1} | f(z_t, u), Q)
= N(z_{t+1} | f(m, u), F(m, u) S F(m, u)^T + Q)
Args:
m (D_hid,): prior mean.
@@ -41,10 +41,10 @@ def _condition_on(m, P, h, H, R, u, y, num_iter):
def _condition_on(m, P, h, H, R, u, y, num_iter):
r"""Condition a Gaussian potential on a new observation.
p(x_t | y_t, u_t, y_{1:t-1}, u_{1:t-1})
propto p(x_t | y_{1:t-1}, u_{1:t-1}) p(y_t | x_t, u_t)
= N(x_t | m, S) N(y_t | h_t(x_t, u_t), R_t)
= N(x_t | mm, SS)
p(z_t | y_t, u_t, y_{1:t-1}, u_{1:t-1})
propto p(z_t | y_{1:t-1}, u_{1:t-1}) p(y_t | z_t, u_t)
= N(z_t | m, S) N(y_t | h_t(z_t, u_t), R_t)
= N(z_t | mm, SS)
where
mm = m + K*(y - yhat) = mu_cond
yhat = h(m, u)
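
As a concrete reading of the predict equation above, here is a minimal standalone JAX sketch (not the library's internal function, which takes the Jacobian `F` as an argument rather than computing it):

```python
import jax.numpy as jnp
from jax import jacfwd

def ekf_predict(m, P, f, Q, u):
    """First-order EKF prediction: linearize the dynamics f around the mean m."""
    F = jacfwd(f)(m, u)        # Jacobian of f w.r.t. the state, evaluated at (m, u)
    m_pred = f(m, u)           # predicted mean  f(m, u)
    P_pred = F @ P @ F.T + Q   # predicted covariance  F(m,u) P F(m,u)^T + Q
    return m_pred, P_pred
```
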
14 changes: 7 additions & 7 deletions dynamax/nonlinear_gaussian_ssm/models.py
@@ -19,9 +19,9 @@
class ParamsNLGSSM(NamedTuple):
"""Parameters for a NLGSSM model.
$$p(x_t | x_{t-1}, u_t) = N(x_t | f(x_{t-1}, u_t), Q_t)$$
$$p(y_t | x_t) = N(y_t | h(x_t, u_t), R_t)$$
$$p(x_1) = N(x_1 | m, S)$$
$$p(z_t | z_{t-1}, u_t) = N(z_t | f(z_{t-1}, u_t), Q_t)$$
$$p(y_t | z_t) = N(y_t | h(z_t, u_t), R_t)$$
$$p(z_1) = N(z_1 | m, S)$$
The tuple doubles as a container for the ParameterProperties.
@@ -50,13 +50,13 @@ class NonlinearGaussianSSM(SSM):
The model is defined as follows
$$p(x_t | x_{t-1}, u_t) = N(x_t | f(x_{t-1}, u_t), Q_t)$$
$$p(y_t | x_t) = N(y_t | h(x_t, u_t), R_t)$$
$$p(x_1) = N(x_1 | m, S)$$
$$p(z_t | z_{t-1}, u_t) = N(z_t | f(z_{t-1}, u_t), Q_t)$$
$$p(y_t | z_t) = N(y_t | h(z_t, u_t), R_t)$$
$$p(z_1) = N(z_1 | m, S)$$
where
* $x_t$ = hidden variables of size `state_dim`,
* $z_t$ = hidden variables of size `state_dim`,
* $y_t$ = observed variables of size `emission_dim`
* $u_t$ = input covariates of size `input_dim` (defaults to 0).
* $f$ = dynamics (transition) function
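
For context, constructing this model's parameter tuple might look like the sketch below (field names are inferred from the docstring above, and the toy functions are assumptions, not library defaults):

```python
import jax.numpy as jnp
from dynamax.nonlinear_gaussian_ssm import ParamsNLGSSM

state_dim, emission_dim = 2, 1
params = ParamsNLGSSM(
    initial_mean=jnp.zeros(state_dim),                 # m
    initial_covariance=jnp.eye(state_dim),             # S
    dynamics_function=lambda z: z + 0.1 * jnp.sin(z),  # f(z)
    dynamics_covariance=0.1 * jnp.eye(state_dim),      # Q
    emission_function=lambda z: z[:1] ** 2,            # h(z)
    emission_covariance=0.5 * jnp.eye(emission_dim),   # R
)
```
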
19 changes: 11 additions & 8 deletions dynamax/ssm.py
@@ -23,6 +23,9 @@ class Posterior(Protocol):
"""A :class:`NamedTuple` with parameters stored as :class:`jax.DeviceArray` in the leaf nodes."""
pass

class SuffStatsSSM(Protocol):
"""A :class:`NamedTuple` with sufficient statics stored as :class:`jax.DeviceArray` in the leaf nodes."""
pass

class SSM(ABC):
r"""A base class for state space models. Such models consist of parameters, which
@@ -91,7 +94,7 @@ def initial_distribution(
inputs: optional inputs $u_t$
Returns:
distribution over initial latent state, $p(x_1 \mid \theta)$.
distribution over initial latent state, $p(z_1 \mid \theta)$.
"""
raise NotImplementedError
@@ -107,11 +110,11 @@ def transition_distribution(
Args:
params: model parameters $\theta$
state: current latent state $x_t$
state: current latent state $z_t$
inputs: current inputs $u_t$
Returns:
conditional distribution of next latent state $p(x_{t+1} \mid x_t, u_t, \theta)$.
conditional distribution of next latent state $p(z_{t+1} \mid z_t, u_t, \theta)$.
"""
raise NotImplementedError
@@ -127,11 +130,11 @@ def emission_distribution(
Args:
params: model parameters $\theta$
state: current latent state $x_t$
state: current latent state $z_t$
inputs: current inputs $u_t$
Returns:
conditional distribution of current emission $p(y_t \mid x_t, u_t, \theta)$
conditional distribution of current emission $p(y_t \mid z_t, u_t, \theta)$
"""
raise NotImplementedError
@@ -173,7 +176,7 @@ def sample(
inputs: Optional[Float[Array, "num_timesteps input_dim"]]=None
) -> Tuple[Float[Array, "num_timesteps state_dim"],
Float[Array, "num_timesteps emission_dim"]]:
r"""Sample states $x_{1:T}$ and emissions $y_{1:T}$ given parameters $\theta$ and (optionally) inputs $u_{1:T}$.
r"""Sample states $z_{1:T}$ and emissions $y_{1:T}$ given parameters $\theta$ and (optionally) inputs $u_{1:T}$.
Args:
params: model parameters $\theta$
@@ -303,7 +306,7 @@ def e_step(
params: ParameterSet,
emissions: Float[Array, "num_timesteps emission_dim"],
inputs: Optional[Float[Array, "num_timesteps input_dim"]]=None
) -> PyTree:
) -> Tuple[SuffStatsSSM, Scalar]:
r"""Perform an E-step to compute expected sufficient statistics under the posterior, $p(z_{1:T} \mid y_{1:T}, u_{1:T}, \theta)$.
Args:
@@ -321,7 +324,7 @@ def m_step(
self,
params: ParameterSet,
props: PropertySet,
batch_stats: PyTree,
batch_stats: SuffStatsSSM,
m_step_state: Any
) -> ParameterSet:
r"""Perform an M-step to find parameters that maximize the expected log joint probability.
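
Together, `e_step` and `m_step` support the generic EM loop that `fit_em` runs. A minimal sketch of that loop (the `initialize_m_step_state` call and the tuple returned by `m_step` are assumed from context):

```python
def em_loop(model, params, props, emissions, num_iters=20):
    """Alternate E-steps (expected sufficient statistics) and M-steps (parameter updates)."""
    m_step_state = model.initialize_m_step_state(params, props)
    log_probs = []
    for _ in range(num_iters):
        stats, marginal_ll = model.e_step(params, emissions)                      # E-step
        params, m_step_state = model.m_step(params, props, stats, m_step_state)  # M-step
        log_probs.append(marginal_ll)
    return params, log_probs
```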
