From 74aa0cebfc98400c0a0a24de940b7468f48b6f52 Mon Sep 17 00:00:00 2001
From: Vladimir Trifonov
Date: Fri, 3 Jan 2025 21:58:36 -0800
Subject: [PATCH 1/7] changes to MappedExamplesIterable to resolve #7345

---
 src/datasets/iterable_dataset.py |  8 ++++----
 tests/test_iterable_dataset.py   | 11 +++++++++++
 2 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py
index 382b20915bf..3ec930d87cd 100644
--- a/src/datasets/iterable_dataset.py
+++ b/src/datasets/iterable_dataset.py
@@ -1045,11 +1045,11 @@ def _iter(self):
                 if self.with_indices:
                     function_args.append([current_idx + i for i in range(len(key_examples_list))])
                 transformed_batch = dict(batch)  # this will be updated with the function output
-                transformed_batch.update(self.function(*function_args, **self.fn_kwargs))
-                # then remove the unwanted columns
+                # remove the unwanted columns
                 if self.remove_columns:
                     for c in self.remove_columns:
                         del transformed_batch[c]
+                transformed_batch.update(self.function(*function_args, **self.fn_kwargs))
                 if transformed_batch:
                     first_col = next(iter(transformed_batch))
                     bad_cols = [
@@ -1088,11 +1088,11 @@ def _iter(self):
                 if self.with_indices:
                     function_args.append(current_idx)
                 transformed_example = dict(example)  # this will be updated with the function output
-                transformed_example.update(self.function(*function_args, **self.fn_kwargs))
-                # then we remove the unwanted columns
+                # remove the unwanted columns
                 if self.remove_columns:
                     for c in self.remove_columns:
                         del transformed_example[c]
+                transformed_example.update(self.function(*function_args, **self.fn_kwargs))
                 current_idx += 1
                 if self._state_dict:
                     self._state_dict["previous_state_example_idx"] += 1
diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py
index 44d3d24fa23..515dd115508 100644
--- a/tests/test_iterable_dataset.py
+++ b/tests/test_iterable_dataset.py
@@ -545,7 +545,9 @@ def test_mapped_examples_iterable_with_indices(n, func, batched, batch_size):
     "n, func, batched, batch_size, remove_columns",
     [
         (3, lambda x: {"id+1": x["id"] + 1}, False, None, ["extra_column"]),  # just add 1 to the id
+        (3, lambda x: {"id+1": x["id"] + 1, "extra_column": "bar"}, False, None, ["extra_column"]),  # just add 1 to the id and add back a remove column
         (25, lambda x: {"id+1": [i + 1 for i in x["id"]]}, True, 10, ["extra_column"]),  # same with bs=10
+        (25, lambda x: {"id+1": [i + 1 for i in x["id"]], "extra_column": ["bar"]*len(x["id"])}, True, 10, ["extra_column"]),  # same with bs=10 and add back a remove column
         (
             50,
             lambda x: {"foo": ["bar"] * np.random.default_rng(x["id"][0]).integers(0, 10)},
@@ -553,8 +555,17 @@ def test_mapped_examples_iterable_with_indices(n, func, batched, batch_size):
             8,
             ["extra_column", "id"],
         ),  # make a duplicate of each example
+        (
+            50,
+            lambda x: (lambda n: {"foo": ["bar"]*n, "extra_column": ["bar"]*n})(np.random.default_rng(x["id"][0]).integers(0, 10)),
+            True,
+            8,
+            ["extra_column", "id"],
+        ),  # make a duplicate of each example and add back a remove column
         (5, lambda x: {"id+1": [i + 1 for i in x["id"]]}, True, None, ["extra_column"]),  # same with bs=None
+        (5, lambda x: {"id+1": [i + 1 for i in x["id"]], "extra_column": ["bar"]*len(x["id"])}, True, None, ["extra_column"]),  # same with bs=None and add back a remove column
         (5, lambda x: {"id+1": [i + 1 for i in x["id"]]}, True, -1, ["extra_column"]),  # same with bs<=0
+        (5, lambda x: {"id+1": [i + 1 for i in x["id"]], "extra_column": ["bar"]*len(x["id"])}, True, -1, ["extra_column"])  # same with bs<=0 and add back a remove column
     ],
 )
 def test_mapped_examples_iterable_remove_columns(n, func, batched, batch_size, remove_columns):
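Not part of the patch series: a dict-level sketch of what the reordering in PATCH 1 changes when the mapping function re-adds a column that is also listed in remove_columns. The column names and values below are made up for illustration.

    # illustration only, not code from the patch
    batch = {"id": [0, 1], "extra_column": ["foo", "foo"]}
    fn_output = {"id+1": [1, 2], "extra_column": ["bar", "bar"]}
    remove_columns = ["extra_column"]

    # previous order: update first, then delete -> the re-added column is lost
    before = dict(batch)
    before.update(fn_output)
    for c in remove_columns:
        del before[c]
    assert "extra_column" not in before

    # order after this patch: drop the input column first, then merge the output
    after = dict(batch)
    for c in remove_columns:
        del after[c]
    after.update(fn_output)
    assert after["extra_column"] == ["bar", "bar"]  # the function's value survives
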
From c952c2e34d29df1fd4b6668253732951494d2de1 Mon Sep 17 00:00:00 2001
From: Vladimir Trifonov
Date: Sat, 4 Jan 2025 14:17:09 -0800
Subject: [PATCH 2/7] changed MappedExamplesIterable and added
 test_iterable_dataset_vs_dataset

---
 src/datasets/iterable_dataset.py | 26 +++++++++++++-------
 tests/test_iterable_dataset.py   | 55 +++++++++++++++++++++++++-------
 2 files changed, 61 insertions(+), 20 deletions(-)

diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py
index 3ec930d87cd..6bff76e5695 100644
--- a/src/datasets/iterable_dataset.py
+++ b/src/datasets/iterable_dataset.py
@@ -79,7 +79,7 @@ def _examples_to_batch(examples: List[Dict[str, Any]]) -> Dict[str, list]:
 
 def _batch_to_examples(batch: Dict[str, list]) -> Iterator[Dict[str, Any]]:
     """Convert a batch (dict of examples) to examples list"""
-    n_examples = len(batch[next(iter(batch))])
+    n_examples = 0 if len(batch)==0 else len(batch[next(iter(batch))])
     for i in range(n_examples):
         yield {col: array[i] for col, array in batch.items()}
 
@@ -1044,12 +1044,16 @@ def _iter(self):
                 function_args = [inputs] if self.input_columns is None else [inputs[col] for col in self.input_columns]
                 if self.with_indices:
                     function_args.append([current_idx + i for i in range(len(key_examples_list))])
-                transformed_batch = dict(batch)  # this will be updated with the function output
-                # remove the unwanted columns
+                inputs_to_merge = dict(batch)
+                processed_inputs = self.function(*function_args, **self.fn_kwargs)
+                # this logic mimics the one in Dataset.map
                 if self.remove_columns:
                     for c in self.remove_columns:
-                        del transformed_batch[c]
-                transformed_batch.update(self.function(*function_args, **self.fn_kwargs))
+                        if c in inputs_to_merge:
+                            del inputs_to_merge[c]
+                        if processed_inputs == inputs and c in processed_inputs:
+                            del processed_inputs[c]
+                transformed_batch = {**inputs_to_merge, **processed_inputs}
                 if transformed_batch:
                     first_col = next(iter(transformed_batch))
                     bad_cols = [
@@ -1087,12 +1091,16 @@ def _iter(self):
                 function_args = [inputs] if self.input_columns is None else [inputs[col] for col in self.input_columns]
                 if self.with_indices:
                     function_args.append(current_idx)
-                transformed_example = dict(example)  # this will be updated with the function output
-                # remove the unwanted columns
+                processed_inputs = self.function(*function_args, **self.fn_kwargs)
+                inputs_to_merge = dict(example)
+                # this logic mimics the one in Dataset.map
                 if self.remove_columns:
                     for c in self.remove_columns:
-                        del transformed_example[c]
-                transformed_example.update(self.function(*function_args, **self.fn_kwargs))
+                        if c in inputs_to_merge:
+                            del inputs_to_merge[c]
+                        if processed_inputs == inputs and c in processed_inputs:
+                            del processed_inputs[c]
+                transformed_example = {**inputs_to_merge, **processed_inputs}
                 current_idx += 1
                 if self._state_dict:
                     self._state_dict["previous_state_example_idx"] += 1
diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py
index 515dd115508..86c27e076bf 100644
--- a/tests/test_iterable_dataset.py
+++ b/tests/test_iterable_dataset.py
@@ -545,9 +545,7 @@ def test_mapped_examples_iterable_with_indices(n, func, batched, batch_size):
     "n, func, batched, batch_size, remove_columns",
     [
         (3, lambda x: {"id+1": x["id"] + 1}, False, None, ["extra_column"]),  # just add 1 to the id
-        (3, lambda x: {"id+1": x["id"] + 1, "extra_column": "bar"}, False, None, ["extra_column"]),  # just add 1 to the id and add back a remove column
         (25, lambda x: {"id+1": [i + 1 for i in x["id"]]}, True, 10, ["extra_column"]),  # same with bs=10
-        (25, lambda x: {"id+1": [i + 1 for i in x["id"]], "extra_column": ["bar"]*len(x["id"])}, True, 10, ["extra_column"]),  # same with bs=10 and add back a remove column
         (
             50,
             lambda x: {"foo": ["bar"] * np.random.default_rng(x["id"][0]).integers(0, 10)},
@@ -555,17 +553,8 @@ def test_mapped_examples_iterable_with_indices(n, func, batched, batch_size):
             8,
             ["extra_column", "id"],
         ),  # make a duplicate of each example
-        (
-            50,
-            lambda x: (lambda n: {"foo": ["bar"]*n, "extra_column": ["bar"]*n})(np.random.default_rng(x["id"][0]).integers(0, 10)),
-            True,
-            8,
-            ["extra_column", "id"],
-        ),  # make a duplicate of each example and add back a remove column
         (5, lambda x: {"id+1": [i + 1 for i in x["id"]]}, True, None, ["extra_column"]),  # same with bs=None
-        (5, lambda x: {"id+1": [i + 1 for i in x["id"]], "extra_column": ["bar"]*len(x["id"])}, True, None, ["extra_column"]),  # same with bs=None and add back a remove column
         (5, lambda x: {"id+1": [i + 1 for i in x["id"]]}, True, -1, ["extra_column"]),  # same with bs<=0
-        (5, lambda x: {"id+1": [i + 1 for i in x["id"]], "extra_column": ["bar"]*len(x["id"])}, True, -1, ["extra_column"])  # same with bs<=0 and add back a remove column
     ],
 )
 def test_mapped_examples_iterable_remove_columns(n, func, batched, batch_size, remove_columns):
@@ -596,6 +585,50 @@ def test_mapped_examples_iterable_remove_columns(n, func, batched, batch_size, r
     assert_load_state_dict_resumes_iteration(ex_iterable)
 
 
+@pytest.mark.parametrize('batched', [False, True])
+@pytest.mark.parametrize('batch_size', [None, 2, 3])
+@pytest.mark.parametrize('input_columns', [None, ["i"]])
+@pytest.mark.parametrize('remove_columns', [None, ["i"]])
+@pytest.mark.parametrize('new_output', [False, True])
+def test_iterable_dataset_vs_dataset(
+    batched, batch_size, input_columns, remove_columns, new_output
+):
+    if input_columns is not None and not new_output:
+        return
+
+    ds1 = Dataset.from_list([{'i': i} for i in range(4)])
+
+    if batched:
+        f1 = lambda i: {'i': [j+1 for j in i]}
+    else:
+        f1 = lambda i: {'i': i+1}
+
+    if input_columns is None:
+        f2 = lambda x: f1(x['i'])
+    else:
+        f2 = f1
+
+    if new_output:
+        f = f2
+    else:
+        def f(x):
+            x['i'] = f2(x)['i']
+            return x
+
+    r = [
+        list(ds2.map(
+            f,
+            batch_size=batch_size,
+            batched = batched,
+            remove_columns=remove_columns,
+            input_columns=input_columns,
+        ))
+        for ds2 in [ds1, ds1.to_iterable_dataset()]
+    ]
+    r[1] = [x for x in r[1] if len(x)>0]
+    assert len(r[0]) == len(r[1])
+    assert all(x==y for x, y in zip(*r))
+
 @pytest.mark.parametrize(
     "n, func, batched, batch_size, fn_kwargs",
     [
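Also not part of the patches: a minimal usage sketch of the behavior the new test_iterable_dataset_vs_dataset test checks, assuming a `datasets` build that includes these commits. remove_columns drops the input column, but a value re-added by the mapping function is kept, so Dataset.map and IterableDataset.map are expected to agree.

    # illustration only; assumes the patches above are applied
    from datasets import Dataset

    ds = Dataset.from_list([{"i": i, "extra_column": "foo"} for i in range(3)])

    def add_back(example):
        # re-adds "extra_column", which is also listed in remove_columns below
        return {"i+1": example["i"] + 1, "extra_column": "bar"}

    eager = ds.map(add_back, remove_columns=["extra_column"])
    lazy = ds.to_iterable_dataset().map(add_back, remove_columns=["extra_column"])

    # expected: both keep the re-added column, e.g. {"i": 0, "i+1": 1, "extra_column": "bar"}
    assert list(eager) == list(lazy)
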
From 16b751b862341808c69d80f36688b7cc6854f27a Mon Sep 17 00:00:00 2001
From: Vladimir Trifonov
Date: Sat, 4 Jan 2025 14:20:45 -0800
Subject: [PATCH 3/7] test_iterable_dataset

---
 tests/test_iterable_dataset.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py
index 86c27e076bf..6d483c484fb 100644
--- a/tests/test_iterable_dataset.py
+++ b/tests/test_iterable_dataset.py
@@ -585,6 +585,7 @@ def test_mapped_examples_iterable_remove_columns(n, func, batched, batch_size, r
     assert_load_state_dict_resumes_iteration(ex_iterable)
 
 
+# issue #7345 and PR #7353
 @pytest.mark.parametrize('batched', [False, True])
 @pytest.mark.parametrize('batch_size', [None, 2, 3])
 @pytest.mark.parametrize('input_columns', [None, ["i"]])
From 6d2eaa77fcdba26673e4e5771ac66661c7ebd99d Mon Sep 17 00:00:00 2001
From: Vladimir Trifonov
Date: Sat, 4 Jan 2025 14:27:51 -0800
Subject: [PATCH 4/7] test_iterable_dataset

---
 tests/test_iterable_dataset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py
index 6d483c484fb..c789e5ae67e 100644
--- a/tests/test_iterable_dataset.py
+++ b/tests/test_iterable_dataset.py
@@ -591,7 +591,7 @@ def test_mapped_examples_iterable_remove_columns(n, func, batched, batch_size, r
 @pytest.mark.parametrize('input_columns', [None, ["i"]])
 @pytest.mark.parametrize('remove_columns', [None, ["i"]])
 @pytest.mark.parametrize('new_output', [False, True])
-def test_iterable_dataset_vs_dataset(
+def test_iterable_dataset_vs_dataset_map(
     batched, batch_size, input_columns, remove_columns, new_output
 ):
     if input_columns is not None and not new_output:
         return

From 45901ec5759ff85b9396308a703df9adddf3fa59 Mon Sep 17 00:00:00 2001
From: vttrifonov
Date: Mon, 6 Jan 2025 21:25:34 -0800
Subject: [PATCH 5/7] Update src/datasets/iterable_dataset.py

Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
---
 src/datasets/iterable_dataset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py
index 6bff76e5695..8076a6929a9 100644
--- a/src/datasets/iterable_dataset.py
+++ b/src/datasets/iterable_dataset.py
@@ -1051,7 +1051,7 @@ def _iter(self):
                     for c in self.remove_columns:
                         if c in inputs_to_merge:
                             del inputs_to_merge[c]
-                        if processed_inputs == inputs and c in processed_inputs:
+                        if processed_inputs is inputs and c in processed_inputs:
                             del processed_inputs[c]
                 transformed_batch = {**inputs_to_merge, **processed_inputs}
                 if transformed_batch:

From ede3df25d0fae354f8bb920a851f2725a49758ee Mon Sep 17 00:00:00 2001
From: vttrifonov
Date: Mon, 6 Jan 2025 21:25:41 -0800
Subject: [PATCH 6/7] Update src/datasets/iterable_dataset.py

Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
---
 src/datasets/iterable_dataset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py
index 8076a6929a9..90b4f2e9fc6 100644
--- a/src/datasets/iterable_dataset.py
+++ b/src/datasets/iterable_dataset.py
@@ -1098,7 +1098,7 @@ def _iter(self):
                     for c in self.remove_columns:
                         if c in inputs_to_merge:
                             del inputs_to_merge[c]
-                        if processed_inputs == inputs and c in processed_inputs:
+                        if processed_inputs is inputs and c in processed_inputs:
                             del processed_inputs[c]
                 transformed_example = {**inputs_to_merge, **processed_inputs}
                 current_idx += 1

From 6f8a44f9aeb2eab540611d0305818d0e0de47c24 Mon Sep 17 00:00:00 2001
From: Vladimir Trifonov
Date: Mon, 6 Jan 2025 21:38:16 -0800
Subject: [PATCH 7/7] ruff happy

---
 src/datasets/iterable_dataset.py |  2 +-
 tests/test_iterable_dataset.py   | 54 ++++++++++++++++++--------------
 2 files changed, 32 insertions(+), 24 deletions(-)

diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py
index 90b4f2e9fc6..22633845bf1 100644
--- a/src/datasets/iterable_dataset.py
+++ b/src/datasets/iterable_dataset.py
@@ -79,7 +79,7 @@ def _examples_to_batch(examples: List[Dict[str, Any]]) -> Dict[str, list]:
 
 def _batch_to_examples(batch: Dict[str, list]) -> Iterator[Dict[str, Any]]:
     """Convert a batch (dict of examples) to examples list"""
-    n_examples = 0 if len(batch)==0 else len(batch[next(iter(batch))])
+    n_examples = 0 if len(batch) == 0 else len(batch[next(iter(batch))])
     for i in range(n_examples):
         yield {col: array[i] for col, array in batch.items()}
 
diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py
index c789e5ae67e..366b850ba17 100644
--- a/tests/test_iterable_dataset.py
+++ b/tests/test_iterable_dataset.py
@@ -586,49 +586,57 @@ def test_mapped_examples_iterable_remove_columns(n, func, batched, batch_size, r
 
 
 # issue #7345 and PR #7353
-@pytest.mark.parametrize('batched', [False, True])
-@pytest.mark.parametrize('batch_size', [None, 2, 3])
-@pytest.mark.parametrize('input_columns', [None, ["i"]])
-@pytest.mark.parametrize('remove_columns', [None, ["i"]])
-@pytest.mark.parametrize('new_output', [False, True])
-def test_iterable_dataset_vs_dataset_map(
-    batched, batch_size, input_columns, remove_columns, new_output
-):
+@pytest.mark.parametrize("batched", [False, True])
+@pytest.mark.parametrize("batch_size", [None, 2])
+@pytest.mark.parametrize("input_columns", [None, ["i"]])
+@pytest.mark.parametrize("remove_columns", [None, ["i"]])
+@pytest.mark.parametrize("new_output", [False, True])
+def test_iterable_dataset_vs_dataset_map(batched, batch_size, input_columns, remove_columns, new_output):
     if input_columns is not None and not new_output:
         return
 
-    ds1 = Dataset.from_list([{'i': i} for i in range(4)])
+    ds1 = Dataset.from_list([{"i": i} for i in range(4)])
 
     if batched:
-        f1 = lambda i: {'i': [j+1 for j in i]}
+
+        def f1(i):
+            return {"i": [j + 1 for j in i]}
     else:
-        f1 = lambda i: {'i': i+1}
+
+        def f1(i):
+            return {"i": i + 1}
 
     if input_columns is None:
-        f2 = lambda x: f1(x['i'])
+
+        def f2(x):
+            return f1(x["i"])
     else:
         f2 = f1
 
     if new_output:
         f = f2
     else:
+
         def f(x):
-            x['i'] = f2(x)['i']
+            x["i"] = f2(x)["i"]
             return x
-    
+
     r = [
-        list(ds2.map(
-            f,
-            batch_size=batch_size,
-            batched = batched,
-            remove_columns=remove_columns,
-            input_columns=input_columns,
-        ))
+        list(
+            ds2.map(
+                f,
+                batch_size=batch_size,
+                batched=batched,
+                remove_columns=remove_columns,
+                input_columns=input_columns,
+            )
+        )
         for ds2 in [ds1, ds1.to_iterable_dataset()]
     ]
-    r[1] = [x for x in r[1] if len(x)>0]
+    r[1] = [x for x in r[1] if len(x) > 0]
     assert len(r[0]) == len(r[1])
-    assert all(x==y for x, y in zip(*r))
+    assert all(x == y for x, y in zip(*r))
+
 
 @pytest.mark.parametrize(
     "n, func, batched, batch_size, fn_kwargs",
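Not part of the patches: the `==` to `is` change in PATCH 5 and PATCH 6 plausibly distinguishes a mapping function that mutates and returns its own input (where the removed columns must still be dropped from the returned object) from one that returns a new dict that merely compares equal (where whatever the function returned should be kept). The commits do not state this rationale, so treat it as an assumption; a small sketch of the distinction:

    # illustration only; hypothetical helper names
    inputs = {"id": 0, "extra_column": "foo"}

    def mutate_in_place(x):
        x["id+1"] = x["id"] + 1
        return x  # returns the very same object as its argument

    def build_new_dict(x):
        return {**x, "id+1": x["id"] + 1}  # equal content, different object

    out_same = mutate_in_place(inputs)
    out_copy = build_new_dict(inputs)

    print(out_same is inputs)    # True
    print(out_copy is inputs)    # False
    print(out_copy == out_same)  # True -> '==' alone cannot tell the two apart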