From daebd9475f87c7f1f447653f1480599d5a7da984 Mon Sep 17 00:00:00 2001 From: Roy Smart Date: Tue, 5 Nov 2024 17:36:14 -0700 Subject: [PATCH] Added `named_arrays.AbstractArray.cell_centers()`, a method to convert from cell vertices to cell centers. (#93) --- named_arrays/_core.py | 30 ++++++++++++++++++++++++++++ named_arrays/_functions/functions.py | 10 ++++++++++ named_arrays/tests/test_core.py | 27 +++++++++++++++++++++++++ 3 files changed, 67 insertions(+) diff --git a/named_arrays/_core.py b/named_arrays/_core.py index 3dba744..95ed091 100644 --- a/named_arrays/_core.py +++ b/named_arrays/_core.py @@ -455,6 +455,36 @@ def combine_axes( Array with the specified axes combined """ + def cell_centers( + self, + axis: None | str | Sequence[str] = None, + ) -> na.AbstractExplicitArray: + """ + Convert an array from cell vertices to cell centers. + + Parameters + ---------- + axis + The axes of the array to average over. + """ + + if axis is None: + axis = self.axes + elif isinstance(axis, str): + axis = (axis, ) + + result = self.explicit + + shape = result.shape + + for a in axis: + if a in shape: + lower = {a: slice(None, ~0)} + upper = {a: slice(+1, None)} + result = (result[lower] + result[upper]) / 2 + + return result + def volume_cell(self, axis: None | str | Sequence[str]) -> na.AbstractScalar: """ Computes the n-dimensional volume of each cell formed by interpreting diff --git a/named_arrays/_functions/functions.py b/named_arrays/_functions/functions.py index 1906382..a2dc5b1 100644 --- a/named_arrays/_functions/functions.py +++ b/named_arrays/_functions/functions.py @@ -128,6 +128,16 @@ def combine_axes( outputs=outputs.combine_axes(axes=axes, axis_new=axis_new), ) + def cell_centers( + self, + axis: None | str | Sequence[str] = None, + ) -> na.AbstractExplicitArray: + return dataclasses.replace( + self, + inputs=self.inputs.cell_centers(axis), + outputs=self.outputs.cell_centers(axis), + ) + def to_string_array( self, format_value: str = "%.2f", diff --git a/named_arrays/tests/test_core.py b/named_arrays/tests/test_core.py index f2e2b53..ae40da9 100644 --- a/named_arrays/tests/test_core.py +++ b/named_arrays/tests/test_core.py @@ -257,6 +257,33 @@ def test_combine_axes( with pytest.raises(ValueError): array.combine_axes(axes=axes, axis_new=axis_new) + @pytest.mark.parametrize( + argnames="axis", + argvalues=[ + None, + "y", + ("y",), + ("x", "y"), + ] + ) + def test_cell_centers( + self, + array: na.AbstractArray, + axis: None | str | Sequence[str], + ): + if axis is None: + axis_normalized = array.axes + elif isinstance(axis, str): + axis_normalized = (axis, ) + else: + axis_normalized = axis + + result = array.cell_centers(axis) + + for a in axis_normalized: + if a in array.shape: + assert result.shape[a] == array.shape[a] - 1 + @pytest.mark.parametrize( argnames="axis", argvalues=[