diff --git a/src/instructlab/sdg/filterblock.py b/src/instructlab/sdg/filterblock.py index 794a4071..f5551b02 100644 --- a/src/instructlab/sdg/filterblock.py +++ b/src/instructlab/sdg/filterblock.py @@ -13,8 +13,21 @@ class FilterByValueBlock(Block): def __init__( self, filter_column, filter_value, operation, convert_dtype=None, **batch_kwargs ) -> None: + """ + Initializes a new instance of the FilterByValueBlock class. + + Parameters: + - filter_column (str): The name of the column in the dataset to apply the filter on. + - filter_value (any or list of any): The value(s) to filter by. + - operation (callable): A function that takes two arguments (column value and filter value) and returns a boolean indicating whether the row should be included in the filtered dataset. + - convert_dtype (callable, optional): A function to convert the data type of the filter column before applying the filter. Defaults to None. + - **batch_kwargs: Additional kwargs for batch processing. + + Returns: + None + """ super().__init__(block_name=self.__class__.__name__) - self.value = filter_value + self.value = filter_value if isinstance(filter_value, list) else [filter_value] self.column_name = filter_column self.operation = operation self.convert_dtype = convert_dtype @@ -38,6 +51,8 @@ def generate(self, samples) -> Dataset: ) return samples.filter( - lambda x: self.operation(x[self.column_name], self.value), + lambda x: any( + self.operation(x[self.column_name], value) for value in self.value + ), num_proc=self.num_procs, ) diff --git a/tests/test_filterblock.py b/tests/test_filterblock.py index 9ee47c65..7b8b1ce7 100644 --- a/tests/test_filterblock.py +++ b/tests/test_filterblock.py @@ -18,6 +18,12 @@ def setUp(self): operation=operator.eq, convert_dtype=int, ) + self.block_with_list = FilterByValueBlock( + filter_column="age", + filter_value=[30, 35], + operation=operator.eq, + convert_dtype=int, + ) self.dataset = Dataset.from_dict( {"age": ["25", "30", "35", "forty", "45"]}, features=Features({"age": Value("string")}), @@ -29,3 +35,10 @@ def test_generate_mixed_types(self, mock_logger): self.assertEqual(len(filtered_dataset), 1) self.assertEqual(filtered_dataset["age"], [30]) mock_logger.error.assert_called() + + @patch("instructlab.sdg.filterblock.logger") + def test_generate_mixed_types_multi_value(self, mock_logger): + filtered_dataset = self.block_with_list.generate(self.dataset) + self.assertEqual(len(filtered_dataset), 2) + self.assertEqual(filtered_dataset["age"], [30, 35]) + mock_logger.error.assert_called()