Skip to content

Commit

Permalink
Add further type annotations to categorical hmm
Browse files Browse the repository at this point in the history
  • Loading branch information
gileshd committed Sep 18, 2024
1 parent 43f997c commit d3a0eb0
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions dynamax/hidden_markov_model/models/categorical_hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ class ParamsCategoricalHMM(NamedTuple):
class CategoricalHMMEmissions(HMMEmissions):

def __init__(self,
num_states,
emission_dim,
num_classes,
emission_prior_concentration=1.1):
num_states: int,
emission_dim: int,
num_classes: int,
emission_prior_concentration: Union[float, Float[Array, " num_classes"]]=1.1):
"""_summary_
Args:
Expand All @@ -45,18 +45,22 @@ def __init__(self,
self.emission_prior_concentration = emission_prior_concentration * jnp.ones(num_classes)

@property
def emission_shape(self):
def emission_shape(self) -> Tuple[int]:
return (self.emission_dim,)

def distribution(self, params, state, inputs=None):
def distribution(self, params: ParamsCategoricalHMMEmissions, state: int, inputs=None) -> tfd.Distribution:
return tfd.Independent(
tfd.Categorical(probs=params.probs[state]),
reinterpreted_batch_ndims=1)

def log_prior(self, params):
def log_prior(self, params: ParamsCategoricalHMMEmissions) -> Float:
return tfd.Dirichlet(self.emission_prior_concentration).log_prob(params.probs).sum()

def initialize(self, key=jr.PRNGKey(0), method="prior", emission_probs=None):
def initialize(self,
key:Optional[Array]=jr.PRNGKey(0),
method="prior",
emission_probs:Optional[Float[Array, "num_states emission_dim"]]=None
) -> Tuple[ParamsCategoricalHMMEmissions, ParamsCategoricalHMMEmissions]:
"""Initialize the model parameters and their corresponding properties.
You can either specify parameters manually via the keyword arguments, or you can have
Expand All @@ -77,6 +81,8 @@ def initialize(self, key=jr.PRNGKey(0), method="prior", emission_probs=None):
# Initialize the emission probabilities
if emission_probs is None:
if method.lower() == "prior":
if key is None:
raise ValueError("key must not be None when emission_probs is None")
prior = tfd.Dirichlet(self.emission_prior_concentration)
emission_probs_sample = prior.sample(seed=key, sample_shape=(self.num_states, self.emission_dim))
emission_probs = cast(Float[Array, "num_states emission_dim"], emission_probs_sample)
Expand Down

0 comments on commit d3a0eb0

Please sign in to comment.