diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index 93698390f..52c620fe1 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -34,7 +34,7 @@ import numpy as np import jax -from jax import Array, lax, tree_util +from jax import lax, tree_util import jax.numpy as jnp from jax.scipy.special import logsumexp from jax.typing import ArrayLike @@ -581,7 +581,9 @@ def event_shape(self) -> tuple[int, ...]: ... @property def event_dim(self) -> int: ... - def sample(self, key: ArrayLike, sample_shape: tuple[int, ...] = ()) -> Array: ... + def sample( + self, key: ArrayLike, sample_shape: tuple[int, ...] = () + ) -> ArrayLike: ... def log_prob(self, value: ArrayLike) -> ArrayLike: ...