Skip to content

Commit

Permalink
Vertex Description of FunctionArray Inputs (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobdparker authored Jan 28, 2025
1 parent 6f0f10e commit 9c97ba1
Show file tree
Hide file tree
Showing 12 changed files with 821 additions and 107 deletions.
1 change: 1 addition & 0 deletions named_arrays/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from . import random
from . import plt
from . import optimize
from . import regridding
from . import transformations
from . import ndfilters
from . import colorsynth
Expand Down
45 changes: 23 additions & 22 deletions named_arrays/_functions/function_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def array_function_default(

if isinstance(where, na.AbstractArray):
if isinstance(where, na.AbstractFunctionArray):
if np.any(where.inputs != inputs):
if not np.all(where.inputs == inputs):
raise na.InputValueError("`where.inputs` must match `a.inputs`")
inputs_where = outputs_where = where.outputs
elif isinstance(where, (na.AbstractScalar, na.AbstractVectorArray)):
Expand All @@ -62,7 +62,6 @@ def array_function_default(
inputs_where = outputs_where = where

shape = na.shape_broadcasted(a, where)
shape_inputs = na.shape_broadcasted(inputs, inputs_where)
shape_outputs = na.shape_broadcasted(outputs, outputs_where)

axis_normalized = tuple(shape) if axis is None else (axis,) if isinstance(axis, str) else axis
Expand All @@ -74,8 +73,6 @@ def array_function_default(
f"got {axis} for `axis`, but `{shape} for `shape`"
)

shape_base = {ax: shape[ax] for ax in axis_normalized}

kwargs = dict(
keepdims=keepdims,
)
Expand All @@ -100,17 +97,19 @@ def array_function_default(
else:
inputs_result = inputs
else:
inputs = inputs.cell_centers(axis=set(axis_normalized)-set(a.axes_center))
shape_inputs = na.shape_broadcasted(inputs, inputs_where)
inputs_result = np.mean(
a=na.broadcast_to(inputs, na.broadcast_shapes(shape_inputs, shape_base)),
axis=axis_normalized,
a=na.broadcast_to(inputs, shape_inputs),
axis=[ax for ax in shape_inputs if ax in axis_normalized],
out=inputs_out,
keepdims=keepdims,
where=inputs_where,
)

outputs_result = func(
a=na.broadcast_to(outputs, na.broadcast_shapes(shape_outputs, shape_base)),
axis=axis_normalized,
a=na.broadcast_to(outputs, shape_outputs),
axis=[ax for ax in shape_outputs if ax in axis_normalized],
out=outputs_out,
**kwargs,
)
Expand Down Expand Up @@ -175,16 +174,16 @@ def array_function_percentile_like(
inputs_result = inputs
else:
inputs_result = np.mean(
a=na.broadcast_to(inputs, na.broadcast_shapes(shape_inputs, shape_base)),
axis=axis_normalized,
a=na.broadcast_to(inputs, shape_inputs),
axis=[ax for ax in shape_inputs if ax in axis_normalized],
out=inputs_out,
keepdims=keepdims,
)

outputs_result = func(
a=na.broadcast_to(outputs, na.broadcast_shapes(shape_outputs, shape_base)),
a=na.broadcast_to(outputs, shape_outputs),
q=q,
axis=axis_normalized,
axis=[ax for ax in shape_outputs if ax in axis_normalized],
out=outputs_out,
**kwargs,
)
Expand Down Expand Up @@ -318,8 +317,12 @@ def broadcast_to(
array: na.AbstractFunctionArray,
shape: dict[str, int]
) -> na.FunctionArray:

axes_vertex = array.axes_vertex
shape_inputs = {ax: shape[ax]+1 if ax in axes_vertex else shape[ax] for ax in shape}

return array.type_explicit(
inputs=na.broadcast_to(array.inputs, shape=shape),
inputs=na.broadcast_to(array.inputs, shape=shape_inputs),
outputs=na.broadcast_to(array.outputs, shape=shape),
)

Expand All @@ -329,17 +332,10 @@ def tranpose(
a: na.AbstractFunctionArray,
axes: None | Sequence[str] = None
) -> na.FunctionArray:

a = a.explicit
a = a.broadcasted
shape = a.shape

axes_normalized = tuple(reversed(shape) if axes is None else axes)
#
# if not set(axes_normalized).issubset(shape):
# raise ValueError(f"`axes` {axes} not a subset of `a.axes`, {a.axes}")

shape_inputs = a.inputs.shape
shape_outputs = a.outputs.shape

return a.type_explicit(
inputs=np.transpose(
Expand Down Expand Up @@ -400,7 +396,10 @@ def reshape(
newshape: dict[str, int],
) -> na.FunctionArray:

a = np.broadcast_to(a, a.shape)
a = a.broadcasted
for ax in newshape:
if ax in a.axes_vertex or (ax not in a.axes and len(a.axes_vertex) != 0):
raise ValueError(f"Cannot reshape along axes vertex {a.axes_vertex}.")

return a.type_explicit(
inputs=np.reshape(a.inputs, newshape=newshape),
Expand Down Expand Up @@ -481,6 +480,8 @@ def repeat(
repeats: int | na.AbstractScalarArray,
axis: str,
) -> na.FunctionArray:
if axis in a.axes_vertex:
raise ValueError(f"Array cannot be repeated along vertex axis {axis}.")

a = a.broadcasted

Expand Down
11 changes: 11 additions & 0 deletions named_arrays/_functions/function_named_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,14 @@ def histogram(
"`weights` must be `None` for `AbstractFunctionArray`"
f"inputs, got {type(weights)}."
)

axis_normalized = tuple(a.shape) if axis is None else (axis,) if isinstance(axis, str) else axis
for ax in axis_normalized:
if ax in a.axes_vertex:
raise ValueError("Taking a histogram of a histogram doesn't work right now.")



return na.histogram(
a=a.inputs,
bins=bins,
Expand Down Expand Up @@ -148,6 +156,9 @@ def pcolormesh(
"`XY` must not be specified."
)

if len(C.axes_vertex) == 1:
raise ValueError("Cannot plot single vertex axis with na.pcolormesh")

return na.plt.pcolormesh(
C.inputs,
C=C.outputs,
Expand Down
Loading

0 comments on commit 9c97ba1

Please sign in to comment.