diff --git a/README.md b/README.md index a679ed48..97d27d9d 100644 --- a/README.md +++ b/README.md @@ -114,7 +114,7 @@ for the case of an HMM with Gaussian emissions. (See [this notebook](https://github.com/probml/dynamax/blob/main/docs/notebooks/hmm/gaussian_hmm.ipynb) for a runnable version of this code.) -``` {.python} +```python import jax.numpy as jnp import jax.random as jr import matplotlib.pyplot as plt @@ -147,7 +147,7 @@ print(post.smoothed_probs.shape) # (1000, 3) JAX allows you to easily vectorize these operations with `vmap`. For example, you can sample and fit to a batch of emissions as shown below. -``` {.python} +```python from functools import partial from jax import vmap