From 3e21ca80aa522a15aee3f307db9d0810e6449420 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Wed, 3 Jul 2024 21:28:56 -0400 Subject: [PATCH] Allow FilterByValueBlock to handle one or many values This block previously only accepted a single value to filter on. This update makes it handle a list, as well. In that case, it will ensure that the filter matches one of the values in the list. This is an updated implementation of the feature originally proposed in #72. Signed-off-by: Russell Bryant --- src/instructlab/sdg/filterblock.py | 19 +++++++++++++++++-- tests/test_filterblock.py | 13 +++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/src/instructlab/sdg/filterblock.py b/src/instructlab/sdg/filterblock.py index 794a4071..f5551b02 100644 --- a/src/instructlab/sdg/filterblock.py +++ b/src/instructlab/sdg/filterblock.py @@ -13,8 +13,21 @@ class FilterByValueBlock(Block): def __init__( self, filter_column, filter_value, operation, convert_dtype=None, **batch_kwargs ) -> None: + """ + Initializes a new instance of the FilterByValueBlock class. + + Parameters: + - filter_column (str): The name of the column in the dataset to apply the filter on. + - filter_value (any or list of any): The value(s) to filter by. + - operation (callable): A function that takes two arguments (column value and filter value) and returns a boolean indicating whether the row should be included in the filtered dataset. + - convert_dtype (callable, optional): A function to convert the data type of the filter column before applying the filter. Defaults to None. + - **batch_kwargs: Additional kwargs for batch processing. + + Returns: + None + """ super().__init__(block_name=self.__class__.__name__) - self.value = filter_value + self.value = filter_value if isinstance(filter_value, list) else [filter_value] self.column_name = filter_column self.operation = operation self.convert_dtype = convert_dtype @@ -38,6 +51,8 @@ def generate(self, samples) -> Dataset: ) return samples.filter( - lambda x: self.operation(x[self.column_name], self.value), + lambda x: any( + self.operation(x[self.column_name], value) for value in self.value + ), num_proc=self.num_procs, ) diff --git a/tests/test_filterblock.py b/tests/test_filterblock.py index 9ee47c65..7b8b1ce7 100644 --- a/tests/test_filterblock.py +++ b/tests/test_filterblock.py @@ -18,6 +18,12 @@ def setUp(self): operation=operator.eq, convert_dtype=int, ) + self.block_with_list = FilterByValueBlock( + filter_column="age", + filter_value=[30, 35], + operation=operator.eq, + convert_dtype=int, + ) self.dataset = Dataset.from_dict( {"age": ["25", "30", "35", "forty", "45"]}, features=Features({"age": Value("string")}), @@ -29,3 +35,10 @@ def test_generate_mixed_types(self, mock_logger): self.assertEqual(len(filtered_dataset), 1) self.assertEqual(filtered_dataset["age"], [30]) mock_logger.error.assert_called() + + @patch("instructlab.sdg.filterblock.logger") + def test_generate_mixed_types_multi_value(self, mock_logger): + filtered_dataset = self.block_with_list.generate(self.dataset) + self.assertEqual(len(filtered_dataset), 2) + self.assertEqual(filtered_dataset["age"], [30, 35]) + mock_logger.error.assert_called()