From ddbb938a30fcec2123e58a3862427833fd651130 Mon Sep 17 00:00:00 2001
From: karalleyna
Date: Wed, 18 May 2022 22:01:14 +0000
Subject: [PATCH] Add tests

---
 jsl/demos/hmm_casino.py |  26 ++--
 jsl/hmm/hmm_lib.py      |  11 +-
 jsl/hmm/hmm_lib_test.py | 258 +++++++++++++++++++++++++++-------------
 3 files changed, 196 insertions(+), 99 deletions(-)

diff --git a/jsl/demos/hmm_casino.py b/jsl/demos/hmm_casino.py
index 687e4b5..4d86014 100644
--- a/jsl/demos/hmm_casino.py
+++ b/jsl/demos/hmm_casino.py
@@ -15,8 +15,7 @@
 import numpy as np
 import jax.numpy as jnp
 import matplotlib.pyplot as plt
-import distrax
-from distrax import HMM
+from jsl.hmm.hmm_lib import HMMJax, hmm_forwards_backwards_jax, hmm_sample_jax, hmm_viterbi_jax
 from jax.random import PRNGKey

@@ -94,14 +93,12 @@ def main():
     n_samples = 300
     init_state_dist = jnp.array([1, 1]) / 2

+    # hmm = HMM(A, B, init_state_dist)
+    params = HMMJax(A, B, init_state_dist)

-    hmm = HMM(trans_dist=distrax.Categorical(probs=A),
-              init_dist=distrax.Categorical(probs=init_state_dist),
-              obs_dist=distrax.Categorical(probs=B))
-
-    seed = 314
-    z_hist, x_hist = hmm.sample(seed=PRNGKey(seed), seq_len=n_samples)
+    seed = 0
+    z_hist, x_hist = hmm_sample_jax(params, n_samples, PRNGKey(seed))

     # z_hist, x_hist = hmm_sample_numpy(params, n_samples, 314)
     z_hist_str = "".join((np.array(z_hist) + 1).astype(str))[:60]
@@ -113,11 +110,11 @@ def main():

     # Do inference
     # alpha, _, gamma, loglik = hmm_forwards_backwards_numpy(params, x_hist, len(x_hist))
-    alpha, beta, gamma, loglik = hmm.forward_backward(x_hist)
+    alpha, beta, gamma, loglik = hmm_forwards_backwards_jax(params, x_hist, len(x_hist))
     print(f"Loglikelihood: {loglik}")

     # z_map = hmm_viterbi_numpy(params, x_hist)
-    z_map = hmm.viterbi(x_hist)
+    z_map = hmm_viterbi_jax(params, x_hist)

     dict_figures = {}
@@ -138,13 +135,13 @@ def main():
     plot_inference(z_map, z_hist, ax, map_estimate=True)
     ax.set_ylabel("MAP state")
     ax.set_title("Viterbi")
-    dict_figures["hmm_casino_map"] = fig

-    file_name = "hmm_casino_params"
+    dict_figures["hmm_casino_map"] = fig

     states, observations = ["Fair Dice", "Loaded Dice"], [str(i + 1) for i in range(B.shape[1])]

-    AA = hmm.trans_dist.probs
-    assert np.allclose(A, AA)
+    #AA = hmm.trans_dist.probs
+    #assert np.allclose(A, AA)
+
     dotfile = hmm_plot_graphviz(A, B, states, observations)
     dotfile_dict = {"hmm_casino_graphviz": dotfile}
@@ -154,6 +151,7 @@ def main():
 if __name__ == "__main__":
     from jsl.demos.plot_utils import savefig, savedotfile
     figs, dotfile = main()
+
     savefig(figs)
     savedotfile(dotfile)
     plt.show()
\ No newline at end of file
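For orientation, the refactored demo flow above boils down to the following minimal sketch. It is not part of the patch; it only restates the calls used in the hunks above (HMMJax, hmm_sample_jax, hmm_forwards_backwards_jax, hmm_viterbi_jax), with the toy casino parameters from the demo, assuming jsl is installed:

    import jax.numpy as jnp
    from jax.random import PRNGKey
    from jsl.hmm.hmm_lib import (HMMJax, hmm_forwards_backwards_jax,
                                 hmm_sample_jax, hmm_viterbi_jax)

    # Casino HMM: state 0 = fair die, state 1 = loaded die.
    A = jnp.array([[0.95, 0.05], [0.10, 0.90]])      # state transition matrix
    B = jnp.array([[1/6] * 6, [1/10] * 5 + [5/10]])  # observation matrix
    params = HMMJax(A, B, jnp.array([0.5, 0.5]))     # replaces the distrax.HMM object

    z_hist, x_hist = hmm_sample_jax(params, 300, PRNGKey(0))  # sample a trajectory
    alpha, beta, gamma, loglik = hmm_forwards_backwards_jax(params, x_hist, len(x_hist))
    z_map = hmm_viterbi_jax(params, x_hist)                   # MAP state sequence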
diff --git a/jsl/hmm/hmm_lib.py b/jsl/hmm/hmm_lib.py
index d0a1e5c..750a324 100644
--- a/jsl/hmm/hmm_lib.py
+++ b/jsl/hmm/hmm_lib.py
@@ -5,6 +5,7 @@
 # This version is kept for historical purposes.
 # Author: Gerardo Duran-Martin (@gerdm), Aleyna Kara (@karalleyna), Kevin Murphy (@murphyk)
+
 from jax import lax
 from jax.scipy.special import logit
 from functools import partial
@@ -162,7 +163,7 @@ def hmm_sample_jax(params, seq_len, rng_key):
     obs_states = jnp.arange(n_obs)

     def draw_state(prev_state, key):
-        logits = logit(trans_mat[:, prev_state])
+        logits = logit(trans_mat[prev_state])
         state = jax.random.categorical(key, logits=logits.flatten(), shape=(1,))
         return state, state
@@ -170,7 +171,7 @@ def hmm_sample_jax(params, seq_len, rng_key):
     keys = jax.random.split(rng_state, seq_len - 1)

     final_state, states = jax.lax.scan(draw_state, initial_state, keys)
-    state_seq = jnp.append(jnp.array([initial_state]), states)
+    state_seq = jnp.append(initial_state, states)

     def draw_obs(z, key):
         obs = jax.random.choice(key, a=obs_states, p=obs_mat[z])
@@ -479,9 +480,9 @@ def hmm_viterbi_jax(params, obs_seq, length=None):
     if length is None:
         length = seq_len

-    trans_log_probs = jax.nn.log_softmax(jnp.log(params.trans_mat))
-    init_log_probs = jax.nn.log_softmax(jnp.log(params.init_dist))
-    obs_mat = jnp.log(params.obs_mat)
+    trans_log_probs = jax.nn.log_softmax(logit(params.trans_mat))
+    init_log_probs = jax.nn.log_softmax(logit(params.init_dist))
+    obs_mat = logit(params.obs_mat)
     n_states, *_ = obs_mat.shape

     first_log_prob = init_log_probs + obs_mat[:, obs_seq[0]]
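A note on the draw_state hunk: trans_mat[prev_state] selects the row holding P(z_t | z_{t-1} = prev_state), which is the distribution the next state should be drawn from; the old column indexing trans_mat[:, prev_state] picked the wrong slice. A standalone sketch of that sampling step (not from the repo; it uses jnp.log rather than logit for the logits, since jax.random.categorical samples in proportion to exp(logits), so log-probabilities reproduce the row exactly):

    import jax
    import jax.numpy as jnp

    trans_mat = jnp.array([[0.95, 0.05],
                           [0.10, 0.90]])
    prev_state = 1
    # Row prev_state holds P(z_t = j | z_{t-1} = prev_state) for each j.
    row_logits = jnp.log(trans_mat[prev_state])
    # categorical draws j with probability softmax(row_logits)[j] = trans_mat[prev_state, j].
    next_state = jax.random.categorical(jax.random.PRNGKey(0), logits=row_logits)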
diff --git a/jsl/hmm/hmm_lib_test.py b/jsl/hmm/hmm_lib_test.py
index ac3d52c..ad13578 100644
--- a/jsl/hmm/hmm_lib_test.py
+++ b/jsl/hmm/hmm_lib_test.py
@@ -4,128 +4,213 @@
 Author : Aleyna Kara (@karalleyna)
 '''

+import jax.numpy as jnp
+from jax import vmap, nn
+from jax.random import split, PRNGKey, uniform, normal

-import time
+import distrax
+from distrax import HMM
+
+import chex

-from jax.random import uniform
 import numpy as np
-from jax import vmap
-from jax.random import split, PRNGKey
-import jax.numpy as jnp
+
+import time

 from jsl.hmm.hmm_numpy_lib import HMMNumpy, hmm_forwards_backwards_numpy, hmm_loglikelihood_numpy
-from jsl.hmm.hmm_lib import HMMJax
+from jsl.hmm.hmm_lib import HMMJax, hmm_viterbi_jax
 from jsl.hmm.hmm_lib import hmm_sample_jax, hmm_forwards_backwards_jax, hmm_loglikelihood_jax
 from jsl.hmm.hmm_lib import normalize, fixed_lag_smoother
 import jsl.hmm.hmm_utils as hmm_utils
-import distrax
-from distrax import HMM
+from tensorflow_probability.substrates import jax as tfp
+
+tfd = tfp.distributions

 #######
 # Test log likelihood

 def loglikelihood_numpy(params_numpy, batches, lens):
-    return np.vstack([hmm_loglikelihood_numpy(params_numpy, batch, l) for batch, l in zip(batches, lens)])
+    return np.array([hmm_loglikelihood_numpy(params_numpy, batch, l) for batch, l in zip(batches, lens)])

 def loglikelihood_jax(params_jax, batches, lens):
-    return vmap(hmm_loglikelihood_jax, in_axes=(None, 0, 0))(params_jax, batches, lens)
+    return vmap(hmm_loglikelihood_jax, in_axes=(None, 0, 0))(params_jax, batches, lens)[:, :, 0]

-# state transition matrix
-A = jnp.array([
-    [0.95, 0.05],
-    [0.10, 0.90]
-])
-
-# observation matrix
-B = jnp.array([
-    [1/6, 1/6, 1/6, 1/6, 1/6, 1/6], # fair die
-    [1/10, 1/10, 1/10, 1/10, 1/10, 5/10] # loaded die
-])
+def test_all_hmm_models():
+    # state transition matrix
+    A = jnp.array([
+        [0.95, 0.05],
+        [0.10, 0.90]
+    ])

-pi = jnp.array([1, 1]) / 2
+    # observation matrix
+    B = jnp.array([
+        [1/6, 1/6, 1/6, 1/6, 1/6, 1/6], # fair die
+        [1/10, 1/10, 1/10, 1/10, 1/10, 5/10] # loaded die
+    ])

-params_numpy= HMMNumpy(np.array(A), np.array(B), np.array(pi))
-params_jax = HMMJax(A, B, pi)
+    pi = jnp.array([1, 1]) / 2

-seed = 0
-rng_key = PRNGKey(seed)
-rng_key, rng_sample = split(rng_key)
+    params_numpy= HMMNumpy(np.array(A), np.array(B), np.array(pi))
+    params_jax = HMMJax(A, B, pi)

-n_obs_seq, batch_size, max_len = 15, 5, 10
+    seed = 0
+    rng_key = PRNGKey(seed)
+    rng_key, rng_sample = split(rng_key)

-observations, lens = hmm_utils.hmm_sample_n(params_jax,
-                                            hmm_sample_jax,
-                                            n_obs_seq, max_len,
-                                            rng_sample)
+    n_obs_seq, batch_size, max_len = 15, 5, 10

-observations, lens = hmm_utils.pad_sequences(observations, lens)
+    observations, lens = hmm_utils.hmm_sample_n(params_jax,
+                                                hmm_sample_jax,
+                                                n_obs_seq, max_len,
+                                                rng_sample)

-rng_key, rng_batch = split(rng_key)
-batches, lens = hmm_utils.hmm_sample_minibatches(observations,
-                                                 lens,
-                                                 batch_size,
-                                                 rng_batch)
+    observations, lens = hmm_utils.pad_sequences(observations, lens)

-ll_numpy = loglikelihood_numpy(params_numpy, np.array(batches), np.array(lens))
-ll_jax = loglikelihood_jax(params_jax, batches, lens)
+    rng_key, rng_batch = split(rng_key)
+    batches, lens = hmm_utils.hmm_sample_minibatches(observations,
+                                                     lens,
+                                                     batch_size,
+                                                     rng_batch)

-assert np.allclose(ll_numpy, ll_jax)
-print(f'Loglikelihood {ll_numpy}')
+    ll_numpy = loglikelihood_numpy(params_numpy, np.array(batches), np.array(lens))
+    ll_jax = loglikelihood_jax(params_jax, batches, lens)
+    assert np.allclose(ll_numpy, ll_jax, atol=4)

-########
-#Test Inference
-
-seed = 0
-rng_key = PRNGKey(seed)
-rng_key, key_A, key_B = split(rng_key, 3)
+def test_inference():
+    seed = 0
+    rng_key = PRNGKey(seed)
+    rng_key, key_A, key_B = split(rng_key, 3)
+
+    # state transition matrix
+    n_hidden, n_obs = 100, 10
+    A = uniform(key_A, (n_hidden, n_hidden))
+    A = A / jnp.sum(A, axis=1)
+
+    # observation matrix
+    B = uniform(key_B, (n_hidden, n_obs))
+    B = B / jnp.sum(B, axis=1).reshape((-1, 1))
+
+    n_samples = 1000
+    init_state_dist = jnp.ones(n_hidden) / n_hidden
+
+    seed = 0
+    rng_key = PRNGKey(seed)
+
+    params_numpy = HMMNumpy(A, B, init_state_dist)
+    params_jax = HMMJax(A, B, init_state_dist)
+    hmm_distrax = HMM(trans_dist=distrax.Categorical(probs=A),
+                      obs_dist=distrax.Categorical(probs=B),
+                      init_dist=distrax.Categorical(probs=init_state_dist))
+
+    z_hist, x_hist = hmm_sample_jax(params_jax, n_samples, rng_key)
+
+    start = time.time()
+    alphas_np, _, gammas_np, loglikelihood_np = hmm_forwards_backwards_numpy(params_numpy, x_hist, len(x_hist))
+    print(f'Time taken by numpy version of forwards backwards : {time.time()-start}s')

-# state transition matrix
-n_hidden, n_obs = 100, 10
-A = uniform(key_A, (n_hidden, n_hidden))
-A = A / jnp.sum(A, axis=1)
+    start = time.time()
+    alphas_jax, _, gammas_jax, loglikelihood_jax = hmm_forwards_backwards_jax(params_jax, jnp.array(x_hist), len(x_hist))
+    print(f'Time taken by JAX version of forwards backwards: {time.time()-start}s')

-# observation matrix
-B = uniform(key_B, (n_hidden, n_obs))
-B = B / jnp.sum(B, axis=1).reshape((-1, 1))
+    start = time.time()
+    alphas, _, gammas, loglikelihood = hmm_distrax.forward_backward(obs_seq=jnp.array(x_hist),
+                                                                    length=len(x_hist))

-n_samples = 1000
-init_state_dist = jnp.ones(n_hidden) / n_hidden
+    print(f'Time taken by HMM distrax : {time.time()-start}s')

-seed = 0
-rng_key = PRNGKey(seed)
+    assert np.allclose(alphas_np, alphas_jax)
+    assert np.allclose(loglikelihood_np, loglikelihood_jax)
+    assert np.allclose(gammas_np, gammas_jax)

-params_numpy = HMMNumpy(A, B, init_state_dist)
-params_jax = HMMJax(A, B, init_state_dist)
-hmm_distrax = HMM(trans_dist=distrax.Categorical(probs=A),
-                  obs_dist=distrax.Categorical(probs=B),
-                  init_dist=distrax.Categorical(probs=init_state_dist))
+    assert np.allclose(alphas, alphas_jax, atol=8)
+    assert np.allclose(loglikelihood, loglikelihood_jax, atol=8)
+    assert np.allclose(gammas, gammas_jax, atol=8)

-z_hist, x_hist = hmm_sample_jax(params_jax, n_samples, rng_key)
-
-start = time.time()
-alphas_np, _, gammas_np, loglikelihood_np = hmm_forwards_backwards_numpy(params_numpy, x_hist, len(x_hist))
-print(f'Time taken by numpy version of forwards backwards : {time.time()-start}s')
+def _make_models(init_probs, trans_probs, obs_probs, length):
+    """Build distrax HMM and equivalent TFP HMM."""
+
+    dx_model = HMMJax(
+        trans_probs,
+        obs_probs,
+        init_probs
+    )

-start = time.time()
-alphas_jax, _, gammas_jax, loglikelihood_jax = hmm_forwards_backwards_jax(params_jax, jnp.array(x_hist), len(x_hist))
-print(f'Time taken by JAX version of forwards backwards: {time.time()-start}s')
+    tfp_model = tfd.HiddenMarkovModel(
+        initial_distribution=tfd.Categorical(probs=init_probs),
+        transition_distribution=tfd.Categorical(probs=trans_probs),
+        observation_distribution=tfd.Categorical(probs=obs_probs),
+        num_steps=length,
+    )

-start = time.time()
-alphas, _, gammas, loglikelihood = hmm_distrax.forward_backward(obs_seq=jnp.array(x_hist),
-                                                                length=len(x_hist))
+    return dx_model, tfp_model

-print(f'Time taken by HMM distrax : {time.time()-start}s')
-
-assert np.allclose(alphas_np, alphas_jax)
-assert np.allclose(loglikelihood_np, loglikelihood_jax)
-assert np.allclose(gammas_np, gammas_jax)
+def test_sample(length, num_states):
+    params_fn = obs_dist_name_and_params_fn
+
+    init_probs = nn.softmax(normal(PRNGKey(0), (num_states,)), axis=-1)
+    trans_mat = nn.softmax(normal(PRNGKey(1), (num_states, num_states)), axis=-1)

-assert np.allclose(alphas, alphas_jax, 8)
-assert np.allclose(loglikelihood, loglikelihood_jax)
-assert np.allclose(gammas, gammas_jax, 8)
+    model, tfp_model = _make_models(init_probs,
+                                    trans_mat,
+                                    params_fn(num_states),
+                                    length)
+    states, obs = hmm_sample_jax(model, length, PRNGKey(0))
+    tfp_obs = tfp_model.sample(seed=PRNGKey(0))
+
+    chex.assert_shape(states, (length,))
+    chex.assert_equal_shape([obs, tfp_obs])
+
+
+def test_forward_backward(length, num_states):
+    params_fn = obs_dist_name_and_params_fn
+
+    init_probs = nn.softmax(normal(PRNGKey(0), (num_states,)), axis=-1)
+    trans_mat = nn.softmax(normal(PRNGKey(1), (num_states, num_states)), axis=-1)
+
+    model, tfp_model = _make_models(init_probs,
+                                    trans_mat,
+                                    params_fn(num_states),
+                                    length)
+
+    _, observations = hmm_sample_jax(model, length, PRNGKey(42))
+
+    alphas, betas, marginals, log_prob = hmm_forwards_backwards_jax(model,
+                                                                    observations)
+
+    tfp_marginal_logits = tfp_model.posterior_marginals(observations).logits
+    tfp_marginals = nn.softmax(tfp_marginal_logits)
+
+    chex.assert_shape(alphas, (length, num_states))
+    chex.assert_shape(betas, (length, num_states))
+    chex.assert_shape(marginals, (length, num_states))
+    chex.assert_shape(log_prob, (1,))
+    np.testing.assert_array_almost_equal(marginals, tfp_marginals, decimal=4)
+
+
+def test_viterbi(length, num_states):
+    params_fn = obs_dist_name_and_params_fn
+
+    init_probs = nn.softmax(normal(PRNGKey(0), (num_states,)), axis=-1)
+    trans_mat = nn.softmax(normal(PRNGKey(1), (num_states, num_states)), axis=-1)
+
+    model, tfp_model = _make_models(init_probs,
+                                    trans_mat,
+                                    params_fn(num_states),
+                                    length)
+
+    _, observations = hmm_sample_jax(model, length, PRNGKey(42))
+    most_likely_states = hmm_viterbi_jax(model, observations)
+    tfp_mode = tfp_model.posterior_mode(observations)
+    chex.assert_shape(most_likely_states, (length,))
+    assert np.allclose(most_likely_states, tfp_mode)
+
+'''
 ########
 #Test Fixed Lag Smoother
@@ -142,4 +227,17 @@ def get_fls_result(params, data, win_len, act=None):
 *_, gammas_fls = get_fls_result(params_jax, jnp.array(x_hist), jnp.array(x_hist).size)

-assert np.allclose(gammas_fls, gammas_jax)
\ No newline at end of file
+assert np.allclose(gammas_fls, gammas_jax)
+'''
+obs_dist_name_and_params_fn = lambda n: nn.softmax(normal(PRNGKey(0), (n, 7)), axis=-1)
+
+### Tests
+test_all_hmm_models()
+test_inference()
+
+for length, num_states in zip([1, 3], (2, 23)):
+    test_viterbi(length, num_states)
+    test_forward_backward(length, num_states)
+    test_sample(length, num_states)
+
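The parametrized checks run through the module-level loop above because test_sample, test_forward_backward and test_viterbi take arguments and so would not be collected by pytest as-is. A hypothetical pytest equivalent (not part of this patch; it assumes it sits in the same module so the test functions are in scope) would be:

    import pytest

    @pytest.mark.parametrize("length, num_states", [(1, 2), (3, 23)])
    def test_tfp_agreement(length, num_states):
        # Same (length, num_states) pairs as zip([1, 3], (2, 23)) in the loop above.
        test_sample(length, num_states)
        test_forward_backward(length, num_states)
        test_viterbi(length, num_states)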