Skip to content

Commit

Permalink
Merge pull request #81 from russellb/filterblock-multi-value
Browse files Browse the repository at this point in the history
Allow FilterByValueBlock to handle one or many values
  • Loading branch information
russellb authored Jul 8, 2024
2 parents d6091ff + 3e21ca8 commit 17148cd
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
19 changes: 17 additions & 2 deletions src/instructlab/sdg/filterblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,21 @@ class FilterByValueBlock(Block):
def __init__(
self, filter_column, filter_value, operation, convert_dtype=None, **batch_kwargs
) -> None:
"""
Initializes a new instance of the FilterByValueBlock class.
Parameters:
- filter_column (str): The name of the column in the dataset to apply the filter on.
- filter_value (any or list of any): The value(s) to filter by.
- operation (callable): A function that takes two arguments (column value and filter value) and returns a boolean indicating whether the row should be included in the filtered dataset.
- convert_dtype (callable, optional): A function to convert the data type of the filter column before applying the filter. Defaults to None.
- **batch_kwargs: Additional kwargs for batch processing.
Returns:
None
"""
super().__init__(block_name=self.__class__.__name__)
self.value = filter_value
self.value = filter_value if isinstance(filter_value, list) else [filter_value]
self.column_name = filter_column
self.operation = operation
self.convert_dtype = convert_dtype
Expand All @@ -38,6 +51,8 @@ def generate(self, samples) -> Dataset:
)

return samples.filter(
lambda x: self.operation(x[self.column_name], self.value),
lambda x: any(
self.operation(x[self.column_name], value) for value in self.value
),
num_proc=self.num_procs,
)
13 changes: 13 additions & 0 deletions tests/test_filterblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ def setUp(self):
operation=operator.eq,
convert_dtype=int,
)
self.block_with_list = FilterByValueBlock(
filter_column="age",
filter_value=[30, 35],
operation=operator.eq,
convert_dtype=int,
)
self.dataset = Dataset.from_dict(
{"age": ["25", "30", "35", "forty", "45"]},
features=Features({"age": Value("string")}),
Expand All @@ -29,3 +35,10 @@ def test_generate_mixed_types(self, mock_logger):
self.assertEqual(len(filtered_dataset), 1)
self.assertEqual(filtered_dataset["age"], [30])
mock_logger.error.assert_called()

@patch("instructlab.sdg.filterblock.logger")
def test_generate_mixed_types_multi_value(self, mock_logger):
filtered_dataset = self.block_with_list.generate(self.dataset)
self.assertEqual(len(filtered_dataset), 2)
self.assertEqual(filtered_dataset["age"], [30, 35])
mock_logger.error.assert_called()

0 comments on commit 17148cd

Please sign in to comment.