Skip to content
This repository has been archived by the owner on Jul 14, 2024. It is now read-only.

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
karalleyna committed May 18, 2022
1 parent 76d557e commit ddbb938
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 99 deletions.
26 changes: 12 additions & 14 deletions jsl/demos/hmm_casino.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
import distrax
from distrax import HMM
from jsl.hmm.hmm_lib import HMMJax, hmm_forwards_backwards_jax, hmm_sample_jax, hmm_viterbi_jax
from jax.random import PRNGKey


Expand Down Expand Up @@ -94,14 +93,12 @@ def main():

n_samples = 300
init_state_dist = jnp.array([1, 1]) / 2

# hmm = HMM(A, B, init_state_dist)
params = HMMJax(A, B, init_state_dist)

hmm = HMM(trans_dist=distrax.Categorical(probs=A),
init_dist=distrax.Categorical(probs=init_state_dist),
obs_dist=distrax.Categorical(probs=B))

seed = 314
z_hist, x_hist = hmm.sample(seed=PRNGKey(seed), seq_len=n_samples)
seed = 0
z_hist, x_hist = hmm_sample_jax(params, n_samples, PRNGKey(seed))
# z_hist, x_hist = hmm_sample_numpy(params, n_samples, 314)

z_hist_str = "".join((np.array(z_hist) + 1).astype(str))[:60]
Expand All @@ -113,11 +110,11 @@ def main():

# Do inference
# alpha, _, gamma, loglik = hmm_forwards_backwards_numpy(params, x_hist, len(x_hist))
alpha, beta, gamma, loglik = hmm.forward_backward(x_hist)
alpha, beta, gamma, loglik = hmm_forwards_backwards_jax(params, x_hist, len(x_hist))
print(f"Loglikelihood: {loglik}")

# z_map = hmm_viterbi_numpy(params, x_hist)
z_map = hmm.viterbi(x_hist)
z_map = hmm_viterbi_jax(params, x_hist)

dict_figures = {}

Expand All @@ -138,13 +135,13 @@ def main():
plot_inference(z_map, z_hist, ax, map_estimate=True)
ax.set_ylabel("MAP state")
ax.set_title("Viterbi")
dict_figures["hmm_casino_map"] = fig

file_name = "hmm_casino_params"
dict_figures["hmm_casino_map"] = fig
states, observations = ["Fair Dice", "Loaded Dice"], [str(i + 1) for i in range(B.shape[1])]

AA = hmm.trans_dist.probs
assert np.allclose(A, AA)
#AA = hmm.trans_dist.probs
#assert np.allclose(A, AA)

dotfile = hmm_plot_graphviz(A, B, states, observations)
dotfile_dict = {"hmm_casino_graphviz": dotfile}

Expand All @@ -154,6 +151,7 @@ def main():
if __name__ == "__main__":
from jsl.demos.plot_utils import savefig, savedotfile
figs, dotfile = main()

savefig(figs)
savedotfile(dotfile)
plt.show()
11 changes: 6 additions & 5 deletions jsl/hmm/hmm_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# This version is kept for historical purposes.
# Author: Gerardo Duran-Martin (@gerdm), Aleyna Kara (@karalleyna), Kevin Murphy (@murphyk)


from jax import lax
from jax.scipy.special import logit
from functools import partial
Expand Down Expand Up @@ -162,15 +163,15 @@ def hmm_sample_jax(params, seq_len, rng_key):
obs_states = jnp.arange(n_obs)

def draw_state(prev_state, key):
logits = logit(trans_mat[:, prev_state])
logits = logit(trans_mat[prev_state])
state = jax.random.categorical(key, logits=logits.flatten(), shape=(1,))
return state, state

rng_key, rng_state, rng_obs = jax.random.split(rng_key, 3)
keys = jax.random.split(rng_state, seq_len - 1)

final_state, states = jax.lax.scan(draw_state, initial_state, keys)
state_seq = jnp.append(jnp.array([initial_state]), states)
state_seq = jnp.append(initial_state, states)

def draw_obs(z, key):
obs = jax.random.choice(key, a=obs_states, p=obs_mat[z])
Expand Down Expand Up @@ -479,9 +480,9 @@ def hmm_viterbi_jax(params, obs_seq, length=None):
if length is None:
length = seq_len

trans_log_probs = jax.nn.log_softmax(jnp.log(params.trans_mat))
init_log_probs = jax.nn.log_softmax(jnp.log(params.init_dist))
obs_mat = jnp.log(params.obs_mat)
trans_log_probs = jax.nn.log_softmax(logit(params.trans_mat))
init_log_probs = jax.nn.log_softmax(logit(params.init_dist))
obs_mat = logit(params.obs_mat)
n_states, *_ = obs_mat.shape

first_log_prob = init_log_probs + obs_mat[:, obs_seq[0]]
Expand Down
Loading

0 comments on commit ddbb938

Please sign in to comment.