Skip to content

Commit

Permalink
in_axis -> in_axes
Browse files Browse the repository at this point in the history
  • Loading branch information
danielward27 committed Mar 12, 2024
1 parent 6b7e608 commit f7fa8d7
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 16 deletions.
2 changes: 1 addition & 1 deletion flowjax/bijections/coupling.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,4 +114,4 @@ def _flat_params_to_transformer(self, params: Array):
dim = self.dim - self.untransformed_dim
transformer_params = jnp.reshape(params, (dim, -1))
transformer = eqx.filter_vmap(self.transformer_constructor)(transformer_params)
return Vmap(transformer, in_axis=eqx.if_array(0))
return Vmap(transformer, in_axes=eqx.if_array(0))
16 changes: 8 additions & 8 deletions flowjax/bijections/jax_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,25 +180,25 @@ def __init__(
self,
bijection: AbstractBijection,
*,
in_axis: int | None | Callable = None,
in_axes: int | None | Callable = None,
axis_size: int | None = None,
in_axis_condition: int | None = None,
in_axes_condition: int | None = None,
):
if in_axis is not None and axis_size is not None:
if in_axes is not None and axis_size is not None:
raise ValueError("Cannot specify both in_axis and axis_size.")

if axis_size is None:
if in_axis is None:
if in_axes is None:
raise ValueError("Either axis_size or in_axis must be provided.")
_check_no_unwrappables(in_axis)
_check_no_unwrappables(in_axes)
axis_size = _infer_axis_size_from_params(
wrappers.unwrap(bijection), in_axis
wrappers.unwrap(bijection), in_axes
)

self.in_axes = (in_axis, 0, in_axis_condition)
self.in_axes = (in_axes, 0, in_axes_condition)
self.bijection = bijection
self.axis_size = axis_size
self.cond_shape = self.get_cond_shape(in_axis_condition)
self.cond_shape = self.get_cond_shape(in_axes_condition)

def vmap(self, f: Callable):
return eqx.filter_vmap(f, in_axes=self.in_axes, axis_size=self.axis_size)
Expand Down
2 changes: 1 addition & 1 deletion flowjax/bijections/masked_autoregressive.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _flat_params_to_transformer(self, params: Array):
dim = self.shape[-1]
transformer_params = jnp.reshape(params, (dim, -1))
transformer = eqx.filter_vmap(self.transformer_constructor)(transformer_params)
return Vmap(transformer, in_axis=eqx.if_array(0))
return Vmap(transformer, in_axes=eqx.if_array(0))


def masked_autoregressive_mlp(
Expand Down
2 changes: 1 addition & 1 deletion flowjax/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def triangular_spline_flow(
def get_splines():
fn = partial(RationalQuadraticSpline, knots=knots, interval=1)
spline = eqx.filter_vmap(fn, axis_size=dim)()
return Vmap(spline, in_axis=eqx.if_array(0))
return Vmap(spline, in_axes=eqx.if_array(0))

def make_layer(key):
lt_key, perm_key, cond_key = jr.split(key, 3)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_bijections/test_bijections.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@
"Vmap (broadcast params)": Vmap(Affine(1, 2), axis_size=10),
"Vmap (vectorize params)": Vmap(
eqx.filter_vmap(Affine)(jnp.ones(3)),
in_axis=eqx.if_array(0),
in_axes=eqx.if_array(0),
),
"Reshape (unconditional)": Reshape(Affine(scale=jnp.arange(1, 5)), (2, 2)),
"Reshape (conditional)": Reshape(
Expand Down
8 changes: 4 additions & 4 deletions tests/test_bijections/test_jax_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_vmap_uneven_init():
bijection = eqx.tree_at(lambda bij: bij.loc, bijection, jnp.arange(3))
in_axis = tree_map(lambda _: None, unwrap(bijection))
in_axis = eqx.tree_at(lambda bij: bij.loc, in_axis, 0, is_leaf=lambda x: x is None)
bijection = Vmap(bijection, in_axis=in_axis)
bijection = Vmap(bijection, in_axes=in_axis)

assert bijection.shape == (3,)
assert bijection.bijection.loc.shape == (3,)
Expand All @@ -29,7 +29,7 @@ def test_vmap_error_with_unwrappable():
bijection = Affine(jnp.zeros(1), jnp.ones(1))
in_axis = tree_map(eqx.is_array, bijection)
with pytest.raises(ValueError, match="unwrappable"):
bijection = Vmap(bijection, in_axis=in_axis)
bijection = Vmap(bijection, in_axes=in_axis)


def test_vmap_condition_only():
Expand All @@ -46,9 +46,9 @@ def test_vmap_condition_only():
ValueError,
match="Either axis_size or in_axis must be provided.",
):
bijection = Vmap(bijection, in_axis_condition=0)
bijection = Vmap(bijection, in_axes_condition=0)

bijection = Vmap(bijection, in_axis_condition=1, axis_size=10)
bijection = Vmap(bijection, in_axes_condition=1, axis_size=10)
assert bijection.shape == (10, 3)
assert bijection.cond_shape == (4, 10)

Expand Down

0 comments on commit f7fa8d7

Please sign in to comment.