Skip to content

Commit

Permalink
docs
Browse files Browse the repository at this point in the history
  • Loading branch information
danielward27 committed Mar 12, 2024
1 parent f7fa8d7 commit de82680
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 19 deletions.
26 changes: 13 additions & 13 deletions flowjax/bijections/jax_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,15 @@ class Vmap(AbstractBijection):
Args:
bijection: The bijection to vectorize.
in_axis: Specify which axes of the bijection parameters to vectorise over. It
in_axes: Specify which axes of the bijection parameters to vectorise over. It
should be a PyTree of ``None``, ``int`` with the tree structure being a
prefix of the bijection, or a callable mapping ``Leaf -> Union[None, int]``.
Defaults to None. Note, if the bijection contains unwrappables, then in_axes
should be specified for the unwrapped structure of the bijection.
axis_size: The size of the new axis. This should be left unspecified if in_axis
axis_size: The size of the new axis. This should be left unspecified if in_axes
is provided, as the size can be inferred from the bijection parameters.
Defaults to None.
in_axis_condition: Optionally define an axis of the conditioning variable to
in_axes_condition: Optionally define an axis of the conditioning variable to
vectorize over. Defaults to None.
Example:
Expand All @@ -133,7 +133,7 @@ class Vmap(AbstractBijection):
... lambda: RationalQuadraticSpline(knots=5, interval=2),
... axis_size=10
... )()
>>> bijection = Vmap(bijection, in_axis=eqx.if_array(0))
>>> bijection = Vmap(bijection, in_axes=eqx.if_array(0))
>>> bijection.shape
(10,)
Expand All @@ -153,11 +153,11 @@ class Vmap(AbstractBijection):
>>> from flowjax.wrappers import unwrap
>>> bijection = Affine(jnp.zeros(()), jnp.ones(()))
>>> bijection = eqx.tree_at(lambda bij: bij.loc, bijection, jnp.arange(3))
>>> in_axis = tree_map(lambda _: None, unwrap(bijection))
>>> in_axis = eqx.tree_at(
... lambda bij: bij.loc, in_axis, 0, is_leaf=lambda x: x is None
>>> in_axes = tree_map(lambda _: None, unwrap(bijection))
>>> in_axes = eqx.tree_at(
... lambda bij: bij.loc, in_axes, 0, is_leaf=lambda x: x is None
... )
>>> bijection = Vmap(bijection, in_axis=in_axis)
>>> bijection = Vmap(bijection, in_axes=in_axes)
>>> bijection.shape
(3,)
>>> bijection.bijection.loc.shape
Expand Down Expand Up @@ -185,11 +185,11 @@ def __init__(
in_axes_condition: int | None = None,
):
if in_axes is not None and axis_size is not None:
raise ValueError("Cannot specify both in_axis and axis_size.")
raise ValueError("Cannot specify both in_axes and axis_size.")

if axis_size is None:
if in_axes is None:
raise ValueError("Either axis_size or in_axis must be provided.")
raise ValueError("Either axis_size or in_axes must be provided.")
_check_no_unwrappables(in_axes)
axis_size = _infer_axis_size_from_params(
wrappers.unwrap(bijection), in_axes
Expand Down Expand Up @@ -243,8 +243,8 @@ def get_cond_shape(self, cond_ax):
)


def _infer_axis_size_from_params(tree, in_axis):
axes = _resolve_vmapped_axes(tree, in_axis)
def _infer_axis_size_from_params(tree, in_axes):
axes = _resolve_vmapped_axes(tree, in_axes)
axis_sizes = tree_leaves(
tree_map(
lambda leaf, ax: leaf.shape[ax] if ax is not None else None,
Expand All @@ -253,7 +253,7 @@ def _infer_axis_size_from_params(tree, in_axis):
),
)
if len(axis_sizes) == 0:
raise ValueError("in_axis did not map to any leaves to vectorize.")
raise ValueError("in_axes did not map to any leaves to vectorize.")
return axis_sizes[0]


Expand Down
12 changes: 6 additions & 6 deletions tests/test_bijections/test_jax_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ def test_vmap_uneven_init():
"Tests adding a batch dimension to a particular leaf (parameter array)."
bijection = Affine(jnp.zeros(()), jnp.ones(()))
bijection = eqx.tree_at(lambda bij: bij.loc, bijection, jnp.arange(3))
in_axis = tree_map(lambda _: None, unwrap(bijection))
in_axis = eqx.tree_at(lambda bij: bij.loc, in_axis, 0, is_leaf=lambda x: x is None)
bijection = Vmap(bijection, in_axes=in_axis)
in_axes = tree_map(lambda _: None, unwrap(bijection))
in_axes = eqx.tree_at(lambda bij: bij.loc, in_axes, 0, is_leaf=lambda x: x is None)
bijection = Vmap(bijection, in_axes=in_axes)

assert bijection.shape == (3,)
assert bijection.bijection.loc.shape == (3,)
Expand All @@ -27,9 +27,9 @@ def test_vmap_uneven_init():

def test_vmap_error_with_unwrappable():
bijection = Affine(jnp.zeros(1), jnp.ones(1))
in_axis = tree_map(eqx.is_array, bijection)
in_axes = tree_map(eqx.is_array, bijection)
with pytest.raises(ValueError, match="unwrappable"):
bijection = Vmap(bijection, in_axes=in_axis)
bijection = Vmap(bijection, in_axes=in_axes)


def test_vmap_condition_only():
Expand All @@ -44,7 +44,7 @@ def test_vmap_condition_only():

with pytest.raises(
ValueError,
match="Either axis_size or in_axis must be provided.",
match="Either axis_size or in_axes must be provided.",
):
bijection = Vmap(bijection, in_axes_condition=0)

Expand Down

0 comments on commit de82680

Please sign in to comment.