Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tackle Typing and Linting Errors #379

Merged
merged 29 commits into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
3f89961
Add ruff ignore F722
gileshd Sep 12, 2024
3ad1fa7
Remove unused imports
gileshd Sep 12, 2024
dc7f675
Convert strings with latex to raw-strings
gileshd Sep 13, 2024
20a1df7
Prepend space in uni-dim jaxtyping hints
gileshd Sep 12, 2024
ca9a75c
Fix jr.PRNGKey type hints
gileshd Sep 12, 2024
75205f8
Rename and change PRNGKey Type
gileshd Sep 12, 2024
42c5c22
Add IntScalar type
gileshd Sep 23, 2024
7da57ef
Minor arg and type changes in utils/utils.py
gileshd Sep 17, 2024
62df3ba
Update HMM[Parameter|Property]Set protocols
gileshd Sep 12, 2024
45134e8
Update type annotations in hmm base classes
gileshd Sep 23, 2024
0d4ca74
Update type annotations in hmm inference code.
gileshd Sep 18, 2024
cfd3bfc
Fix type annotations in hmm parallel inference
gileshd Sep 12, 2024
d0a0e0f
Add further type annotations to hmm transitions class
gileshd Oct 16, 2024
4b2c40b
Add further type annotations to hmm initial base class
gileshd Oct 16, 2024
4c84ddc
Add further type annotations to categorical hmm
gileshd Sep 18, 2024
7e04c9e
Add further type annotations to arhmm
gileshd Sep 20, 2024
b2fffa4
Add further type annotations to linreghmm
gileshd Sep 20, 2024
a45e61b
Add further type annotations to Bernoulli HMM
gileshd Sep 20, 2024
32db62a
Add further type annotations to Gamma HMM
gileshd Sep 20, 2024
81138f8
Add further type annotations to Gaussian HMMs
gileshd Sep 20, 2024
4e36520
Add further type annotations to gmhmms
gileshd Sep 20, 2024
4b7ee53
Add further type annotations to logreg hmm
gileshd Sep 21, 2024
e246606
Add further type annotations to multinomialhmm
gileshd Sep 21, 2024
e19748d
Add further type annotations to poisson hmm
gileshd Sep 23, 2024
c2098cf
Add further type annotations to categorical glm hmm
gileshd Oct 6, 2024
547c610
Fix LinearGaussianSSM.sample type hint
gileshd Sep 12, 2024
58c9127
Change type hints to jaxtyping in slds code
gileshd Sep 13, 2024
2b251bc
Merge branch 'main' into ghd/typing
slinderman Jan 27, 2025
1d8f831
scaling down weights for poisson. apparently an upstream change in jax
slinderman Jan 27, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions dynamax/generalized_gaussian_ssm/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from jax import jacfwd, vmap, lax
import jax.numpy as jnp
from jax import lax
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
from jaxtyping import Array, Float
from typing import NamedTuple, Optional, Union, Callable

Expand Down Expand Up @@ -83,7 +82,7 @@ def compute_weights_and_sigmas(self, m, P):


def _predict(m, P, f, Q, u, g_ev, g_cov):
"""Predict next mean and covariance under an additive-noise Gaussian filter
r"""Predict next mean and covariance under an additive-noise Gaussian filter

p(x_{t+1}) = N(x_{t+1} | mu_pred, Sigma_pred)
where
Expand Down Expand Up @@ -117,7 +116,7 @@ def _predict(m, P, f, Q, u, g_ev, g_cov):


def _condition_on(m, P, y_cond_mean, y_cond_cov, u, y, g_ev, g_cov, num_iter, emission_dist):
"""Condition a Gaussian potential on a new observation with arbitrary
r"""Condition a Gaussian potential on a new observation with arbitrary
likelihood with given functions for conditional moments and make a
Gaussian approximation.
p(x_t | y_t, u_t, y_{1:t-1}, u_{1:t-1})
Expand Down Expand Up @@ -172,7 +171,7 @@ def _step(carry, _):


def _statistical_linear_regression(mu, Sigma, m, S, C):
"""Return moment-matching affine coefficients and approximation noise variance
r"""Return moment-matching affine coefficients and approximation noise variance
given joint moments.

g(x) \approx Ax + b + e where e ~ N(0, Omega)
Expand Down
20 changes: 10 additions & 10 deletions dynamax/generalized_gaussian_ssm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
from dynamax.nonlinear_gaussian_ssm.models import FnStateToState, FnStateAndInputToState
from dynamax.nonlinear_gaussian_ssm.models import FnStateToEmission, FnStateAndInputToEmission

FnStateToEmission2 = Callable[[Float[Array, "state_dim"]], Float[Array, "emission_dim emission_dim"]]
FnStateAndInputToEmission2 = Callable[[Float[Array, "state_dim"], Float[Array, "input_dim"]], Float[Array, "emission_dim emission_dim"]]
FnStateToEmission2 = Callable[[Float[Array, " state_dim"]], Float[Array, "emission_dim emission_dim"]]
FnStateAndInputToEmission2 = Callable[[Float[Array, " state_dim"], Float[Array, " input_dim"]], Float[Array, "emission_dim emission_dim"]]

# emission distribution takes a mean vector and covariance matrix and returns a distribution
EmissionDistFn = Callable[ [Float[Array, "state_dim"], Float[Array, "state_dim state_dim"]], tfd.Distribution]
EmissionDistFn = Callable[ [Float[Array, " state_dim"], Float[Array, "state_dim state_dim"]], tfd.Distribution]


class ParamsGGSSM(NamedTuple):
Expand All @@ -42,7 +42,7 @@ class ParamsGGSSM(NamedTuple):

"""

initial_mean: Float[Array, "state_dim"]
initial_mean: Float[Array, " state_dim"]
initial_covariance: Float[Array, "state_dim state_dim"]
dynamics_function: Union[FnStateToState, FnStateAndInputToState]
dynamics_covariance: Float[Array, "state_dim state_dim"]
Expand Down Expand Up @@ -97,15 +97,15 @@ def covariates_shape(self):
def initial_distribution(
self,
params: ParamsGGSSM,
inputs: Optional[Float[Array, "input_dim"]]=None
inputs: Optional[Float[Array, " input_dim"]]=None
) -> tfd.Distribution:
return MVN(params.initial_mean, params.initial_covariance)

def transition_distribution(
self,
params: ParamsGGSSM,
state: Float[Array, "state_dim"],
inputs: Optional[Float[Array, "input_dim"]]=None
state: Float[Array, " state_dim"],
inputs: Optional[Float[Array, " input_dim"]]=None
) -> tfd.Distribution:
f = params.dynamics_function
if inputs is None:
Expand All @@ -117,8 +117,8 @@ def transition_distribution(
def emission_distribution(
self,
params: ParamsGGSSM,
state: Float[Array, "state_dim"],
inputs: Optional[Float[Array, "input_dim"]]=None
state: Float[Array, " state_dim"],
inputs: Optional[Float[Array, " input_dim"]]=None
) -> tfd.Distribution:
h = params.emission_mean_function
R = params.emission_cov_function
Expand All @@ -128,4 +128,4 @@ def emission_distribution(
else:
mean = h(state, inputs)
cov = R(state, inputs)
return params.emission_dist(mean, cov)
return params.emission_dist(mean, cov)
4 changes: 2 additions & 2 deletions dynamax/generalized_gaussian_ssm/models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_poisson_emission(key, kwargs):
keys = jr.split(key, 3)
state_dim = kwargs['state_dim']
emission_dim = 1 # Univariate Poisson
poisson_weights = jr.normal(keys[0], shape=(emission_dim, state_dim))
poisson_weights = jr.normal(keys[0], shape=(emission_dim, state_dim)) / jnp.sqrt(state_dim)
model = GeneralizedGaussianSSM(state_dim, emission_dim)

# Define model parameters with Poisson emission
Expand Down Expand Up @@ -51,7 +51,7 @@ def test_poisson_emission(key, kwargs):

# Fit model with Gaussian emission
gaussian_marginal_lls = conditional_moments_gaussian_filter(gaussian_params, EKFIntegrals(), emissions).marginal_loglik

# Check that the marginal log-likelihoods under Poisson emission are higher
assert pois_marginal_lls > gaussian_marginal_lls

Expand Down
4 changes: 2 additions & 2 deletions dynamax/hidden_markov_model/demos/categorical_glm_hmm_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@
plt.figure()
plt.imshow(jnp.vstack((states[None, :], most_likely_states[None, :])),
aspect="auto", interpolation='none', cmap="Greys")
plt.yticks([0.0, 1.0], ["$z$", "$\hat{z}$"])
plt.yticks([0.0, 1.0], ["$z$", r"$\hat{z}$"])
plt.xlabel("time")
plt.xlim(0, 500)


print("true log prob: ", hmm.marginal_log_prob(true_params, emissions, inputs=inputs))
print("test log prob: ", test_hmm.marginal_log_prob(params, emissions, inputs=inputs))

plt.show()
plt.show()
Loading
Loading