Skip to content
This repository has been archived by the owner on Jul 14, 2024. It is now read-only.

Commit

Permalink
Merge pull request #57 from petergchang/patch-1
Browse files Browse the repository at this point in the history
Change scipy.special.logit  into jnp.log
  • Loading branch information
murphyk authored May 24, 2022
2 parents e4ad565 + 311d559 commit 18c78a4
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions jsl/hmm/hmm_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,11 @@ def draw_state(carry, key):
(t, post_state) = carry

ffbs_dist_t = normalize(trans_mat * alpha[t])[0]
logits = logit(ffbs_dist_t[:, post_state])
logits = jnp.log(ffbs_dist_t[:, post_state])
state = jax.random.categorical(key, logits=logits.flatten(), shape=(1,))
return (t - 1, state), state

logits = logit(alpha[seq_len - 1])
logits = jnp.log(alpha[seq_len - 1])
final_state = jax.random.categorical(rng_init, logits=logits.flatten(), shape=(1,))
_, states = jax.lax.scan(draw_state, (seq_len - 2, final_state), state_keys)
states = jnp.flip(jnp.append(jnp.array([final_state]), states), axis=0)
Expand Down Expand Up @@ -159,11 +159,11 @@ def hmm_sample_jax(params, seq_len, rng_key):

n_states, n_obs = obs_mat.shape

initial_state = jax.random.categorical(rng_key, logits=logit(init_dist), shape=(1,))
initial_state = jax.random.categorical(rng_key, logits=jnp.log(init_dist), shape=(1,))
obs_states = jnp.arange(n_obs)

def draw_state(prev_state, key):
logits = logit(trans_mat[prev_state])
logits = jnp.log(trans_mat[prev_state])
state = jax.random.categorical(key, logits=logits.flatten(), shape=(1,))
return state, state

Expand Down

0 comments on commit 18c78a4

Please sign in to comment.