From 649829afc6ab2d92dde92c58561fdc587c408484 Mon Sep 17 00:00:00 2001 From: Roy Smart Date: Tue, 14 Jan 2025 15:54:11 -0700 Subject: [PATCH 1/2] Modified `named_arrays.AbstractFunctionArray.__getitem__()` to support any instance of `named_arrays.AbstractArray`. --- named_arrays/_functions/functions.py | 14 ++++++++------ named_arrays/_functions/tests/test_functions.py | 17 ++++++++++------- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/named_arrays/_functions/functions.py b/named_arrays/_functions/functions.py index 67f0bab..686f483 100644 --- a/named_arrays/_functions/functions.py +++ b/named_arrays/_functions/functions.py @@ -174,12 +174,14 @@ def _getitem( shape_inputs = inputs.shape shape_outputs = outputs.shape - if isinstance(item, na.AbstractFunctionArray): - if np.any(item.inputs != self.inputs): - raise ValueError("boolean advanced index does not have the same inputs as the array") - - item_inputs = item.outputs - item_outputs = item.outputs + if isinstance(item, na.AbstractArray): + if isinstance(item, na.AbstractFunctionArray): + if np.any(item.inputs != self.inputs): + raise ValueError("boolean advanced index does not have the same inputs as the array") + item_inputs = item.outputs + item_outputs = item.outputs + else: + item_inputs = item_outputs = item shape_item_inputs = item_inputs.shape shape_item_outputs = item_outputs.shape diff --git a/named_arrays/_functions/tests/test_functions.py b/named_arrays/_functions/tests/test_functions.py index ac77164..fed4264 100644 --- a/named_arrays/_functions/tests/test_functions.py +++ b/named_arrays/_functions/tests/test_functions.py @@ -131,6 +131,7 @@ def test_length(self, array: na.AbstractFunctionArray): outputs=na.ScalarArrayRange(0, 2, axis='y'), ) ), + na.ScalarLinearSpace(0, 1, axis='y', num=_num_y) > 0.5, na.FunctionArray( inputs=na.ScalarLinearSpace(0, 1, axis='y', num=_num_y), outputs=na.ScalarLinearSpace(0, 1, axis='y', num=_num_y) > 0.5, @@ -157,13 +158,15 @@ def test__getitem__( array[item] return - if isinstance(item, na.AbstractFunctionArray): - if np.any(item.inputs != array.inputs): - with pytest.raises(ValueError): - array[item] - return - - item_outputs = item_inputs = item.outputs + if isinstance(item, na.AbstractArray): + if isinstance(item, na.AbstractFunctionArray): + if np.any(item.inputs != array.inputs): + with pytest.raises(ValueError): + array[item] + return + item_outputs = item_inputs = item.outputs + else: + item_outputs = item_inputs = item elif isinstance(item, dict): item_inputs = { From ab64f7f5794e985261fde02873815136fc28ff53 Mon Sep 17 00:00:00 2001 From: Roy Smart Date: Tue, 14 Jan 2025 16:24:24 -0700 Subject: [PATCH 2/2] coverage --- named_arrays/_functions/tests/test_functions.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/named_arrays/_functions/tests/test_functions.py b/named_arrays/_functions/tests/test_functions.py index fed4264..c74e1b8 100644 --- a/named_arrays/_functions/tests/test_functions.py +++ b/named_arrays/_functions/tests/test_functions.py @@ -136,6 +136,10 @@ def test_length(self, array: na.AbstractFunctionArray): inputs=na.ScalarLinearSpace(0, 1, axis='y', num=_num_y), outputs=na.ScalarLinearSpace(0, 1, axis='y', num=_num_y) > 0.5, ), + na.FunctionArray( + inputs=na.ScalarLinearSpace(0, 2, axis='y', num=_num_y), + outputs=na.ScalarLinearSpace(0, 1, axis='y', num=_num_y) > 0.5, + ), na.FunctionArray( inputs=na.ScalarLinearSpace(0, 1, axis='y', num=_num_y), outputs=na.UniformUncertainScalarArray(