Skip to content

Commit

Permalink
Modified the behavior of named_arrays.colorsynth.rgb_and_colorbar()
Browse files Browse the repository at this point in the history
… to allow for the `wavelength` argument to be defined on cell edges as well as cell centers. (#90)
  • Loading branch information
byrdie authored Nov 4, 2024
1 parent 67a862c commit deb1c80
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 37 deletions.
13 changes: 11 additions & 2 deletions named_arrays/_functions/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,10 +721,19 @@ class TestColorsynth(
named_arrays.tests.test_core.AbstractTestAbstractArray.TestNamedArrayFunctions.TestColorsynth,
):

def test_colorbar(self, array: na.AbstractFunctionArray, axis: None | str):
def test_colorbar(
self,
array: na.AbstractFunctionArray,
wavelength: None | na.AbstractArray,
axis: None | str,
):
if isinstance(array.outputs, na.AbstractVectorArray):
return
super().test_colorbar(array=array, axis=axis)
super().test_colorbar(
array=array,
wavelength=wavelength,
axis=axis,
)


@pytest.mark.parametrize("array", _function_arrays())
Expand Down
48 changes: 32 additions & 16 deletions named_arrays/_scalars/scalar_named_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1145,6 +1145,22 @@ def colorsynth_rgb(
except na.ScalarTypeError: # pragma: nocover
return NotImplemented

if axis is None:
axes = tuple(set(na.shape(spd)) | set(na.shape(wavelength)))
if len(axes) != 1:
raise ValueError(
f"If `axis` is `None`, the other arguments must have zero"
f"or one axis, got {axes}."
)
axis = axes[0]

if wavelength is not None:
if axis in spd.shape:
if wavelength.shape[axis] == spd.shape[axis] + 1:
below = {axis: slice(None, ~0)}
above = {axis: slice(+1, None)}
wavelength = (wavelength[below] + wavelength[above]) / 2

shape = na.shape_broadcasted(
spd,
wavelength,
Expand All @@ -1155,14 +1171,6 @@ def colorsynth_rgb(
)

axes = tuple(shape)
if axis is None:
if len(axes) != 1:
raise ValueError(
f"If `axis` is `None`, the broadcasted shape of the other"
f"arguments must have exactly one axis, got {shape=}"
)
else:
axis = axes[0]
axis_ndarray = axes.index(axis)

result_ndarray = colorsynth.rgb(
Expand Down Expand Up @@ -1209,6 +1217,22 @@ def colorsynth_colorbar(
except na.ScalarTypeError: # pragma: nocover
return NotImplemented

if axis is None:
axes = tuple(set(na.shape(spd)) | set(na.shape(wavelength)))
if len(axes) != 1:
raise ValueError(
f"If `axis` is `None`, the other arguments must have zero "
f"or one axis, got {axes}."
)
axis = axes[0]

if wavelength is not None:
if axis in spd.shape:
if wavelength.shape[axis] == spd.shape[axis] + 1:
below = {axis: slice(None, ~0)}
above = {axis: slice(+1, None)}
wavelength = (wavelength[below] + wavelength[above]) / 2

shape = na.shape_broadcasted(
spd,
wavelength,
Expand All @@ -1219,14 +1243,6 @@ def colorsynth_colorbar(
)

axes = tuple(shape)
if axis is None:
if len(axes) != 1:
raise ValueError(
f"If `axis` is `None`, the broadcasted shape of the other"
f"arguments must have exactly one axis, got {shape=}"
)
else:
axis = axes[0]
axis_ndarray = axes.index(axis)

intensity, wavelength, rgb = colorsynth.colorbar(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -615,13 +615,14 @@ def _colorsynth_normalize(
shape = spd.shape

if axis is None:
if len(shape) != 1:
axes = tuple(set(na.shape(spd)) | set(na.shape(wavelength)))
if len(axes) != 1:
raise ValueError(
f"If `axis` is `None`, the shape of `array` should have only"
f"one element, got {shape=}"
f"If `axis` is `None`, the other arguments must have zero "
f"or one axis, got {axes}."
)
else:
axis = next(iter(shape))
axis = axes[0]

if wavelength is None:
wavelength = na.linspace(0, 1, axis=axis, num=shape[axis])
Expand Down
7 changes: 6 additions & 1 deletion named_arrays/_vectors/tests/test_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,12 @@ class TestColorsynth(
):

@pytest.mark.skip
def test_colorbar(self, array: na.AbstractArray, axis: None | str):
def test_colorbar(
self,
array: na.AbstractArray,
wavelength: None | na.AbstractScalar,
axis: None | str,
):
pass # pragma: nocover


Expand Down
50 changes: 36 additions & 14 deletions named_arrays/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1524,56 +1524,78 @@ def test_ndfilter(
assert result.shape == array.shape

@pytest.mark.parametrize("axis", [None, "y"])
@pytest.mark.parametrize(
argnames="wavelength",
argvalues=[
None,
na.linspace(-1, 1, axis="y", num=num_y + 1)
]
)
class TestColorsynth:
def test_rgb(self, array: na.AbstractArray, axis: None | str):
def test_rgb(
self,
array: na.AbstractArray,
wavelength: None | na.AbstractScalar,
axis: None | str,
):
with warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=RuntimeWarning)
if axis is None:
if array.ndim != 1:
if len(set(na.shape(array)) | set(na.shape(wavelength))) != 1:
with pytest.raises(ValueError):
na.colorsynth.rgb(array, axis=axis)
na.colorsynth.rgb(array, wavelength, axis=axis)
return
else:
result = na.colorsynth.rgb(array, axis=axis)
result = na.colorsynth.rgb(array, wavelength, axis=axis)
assert result.size == 3
else:
if array.shape:
result = na.colorsynth.rgb(array, axis=axis)
result = na.colorsynth.rgb(array, wavelength, axis=axis)
assert result.shape[axis] == 3

def test_colorbar(self, array: na.AbstractArray, axis: None | str):
def test_colorbar(
self,
array: na.AbstractArray,
wavelength: None | na.AbstractScalar,
axis: None | str,
):
if axis is None:
if array.ndim != 1:
if len(set(na.shape(array)) | set(na.shape(wavelength))) != 1:
with pytest.raises(ValueError):
na.colorsynth.colorbar(array, axis=axis)
na.colorsynth.colorbar(array, wavelength, axis=axis)
return

if array.shape:
with warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=RuntimeWarning)
result = na.colorsynth.colorbar(array, axis=axis)
result = na.colorsynth.colorbar(array, wavelength, axis=axis)
assert isinstance(result, na.FunctionArray)
assert isinstance(result.inputs, na.Cartesian2dVectorArray)
assert isinstance(result.outputs, na.AbstractArray)

def test_rgb_and_colorbar(self, array: na.AbstractArray, axis: None | str):
def test_rgb_and_colorbar(
self,
array: na.AbstractArray,
wavelength: None | na.AbstractScalar,
axis: None | str,
):
with warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=RuntimeWarning)

if not array.shape:
return

if axis is None:
if array.ndim != 1:
if len(set(na.shape(array)) | set(na.shape(wavelength))) != 1:
return

try:
rgb_expected = na.colorsynth.rgb(array, axis=axis)
colorbar_expected = na.colorsynth.colorbar(array, axis=axis)
rgb_expected = na.colorsynth.rgb(array, wavelength, axis=axis)
colorbar_expected = na.colorsynth.colorbar(array, wavelength, axis=axis)
except TypeError:
return

rgb, colorbar = na.colorsynth.rgb_and_colorbar(array, axis=axis)
rgb, colorbar = na.colorsynth.rgb_and_colorbar(array, wavelength, axis=axis)

assert np.allclose(rgb, rgb_expected, equal_nan=True)
assert np.allclose(colorbar, colorbar_expected, equal_nan=True)
Expand Down

0 comments on commit deb1c80

Please sign in to comment.