From ef2fb358433678b322d275c0bdee3239fa6485b2 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> Date: Fri, 14 Jun 2024 16:58:37 +0200 Subject: [PATCH] Fix resuming arrow format (#6964) * fix resuming in arrow format * one more * fix arrow resuming --- docs/source/stream.mdx | 2 +- src/datasets/iterable_dataset.py | 368 +++++++++++++++++++++---------- tests/test_iterable_dataset.py | 202 ++++++++--------- 3 files changed, 357 insertions(+), 215 deletions(-) diff --git a/docs/source/stream.mdx b/docs/source/stream.mdx index 89c3238429b..93f52cf8e07 100644 --- a/docs/source/stream.mdx +++ b/docs/source/stream.mdx @@ -415,6 +415,6 @@ This can be used with the `StatefulDataLoader` from `torchdata`: -Resuming returns exactly where the checkpoint was saved except in two cases: 1) examples from shuffle buffers are lost when resuming and the buffers are refilled with new data and 2) combinations of `.with_format(arrow)` and batched `.map()` may skip one batch. +Resuming returns exactly where the checkpoint was saved except if `.shuffle()` is used: examples from shuffle buffers are lost when resuming and the buffers are refilled with new data. diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index d49ffd47a49..ca1847d6a42 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -114,58 +114,6 @@ def _convert_to_arrow( yield new_key, pa.Table.from_pylist(cast_to_python_objects(examples, only_1d_for_numpy=True)) -def _batch_arrow_tables( - iterable: Iterable[Tuple[Key, pa.Table]], - batch_size: Optional[int], - drop_last_batch: bool = False, -) -> Iterator[Tuple[Key, pa.Table]]: - """Iterate over sub-tables of size `batch_size`. - - Args: - iterable (`Iterable[Tuple[Key, pa.Table]]`): - A tables iterable containing tuples (table_key, table) of type (int/str, pa.Table) - batch_size (`Optional[int]`): - Size of each sub-table to yield. If None or <= 0, yields the full table. - drop_last_batch (`bool`, defaults to `False`): - Drop the last batch if it is smaller than `batch_size`. - """ - if batch_size is None or batch_size <= 0: - yield "all", pa.concat_tables([pa_table for _, pa_table in iterable]) - return - keys_buffer = [] - chunks_buffer = [] - chunks_buffer_size = 0 - for key, pa_table in iterable: - for chunk in pa_table.to_reader(max_chunksize=batch_size): - if len(chunk) == 0: - continue - elif chunks_buffer_size + len(chunk) < batch_size: - keys_buffer.append(key) - chunks_buffer.append(chunk) - chunks_buffer_size += len(chunk) - continue - elif chunks_buffer_size + len(chunk) == batch_size: - keys_buffer.append(key) - chunks_buffer.append(chunk) - new_key = "_".join(str(_key) for _key in keys_buffer) - yield new_key, pa.Table.from_batches(chunks_buffer) - keys_buffer = [] - chunks_buffer = [] - chunks_buffer_size = 0 - else: - cropped_chunk_length = batch_size - chunks_buffer_size - keys_buffer.append(f"{key}[:{cropped_chunk_length}]") - chunks_buffer.append(chunk.slice(0, cropped_chunk_length)) - new_key = "_".join(str(_key) for _key in keys_buffer) - yield new_key, pa.Table.from_batches(chunks_buffer) - keys_buffer = [f"{key}[{cropped_chunk_length}:]"] - chunks_buffer = [chunk.slice(cropped_chunk_length, len(chunk) - cropped_chunk_length)] - chunks_buffer_size = len(chunk) - cropped_chunk_length - if not drop_last_batch and chunks_buffer: - new_key = "_".join(str(_key) for _key in keys_buffer) - yield new_key, pa.Table.from_batches(chunks_buffer) - - class _BaseExamplesIterable: """Base class for the examples iterable used by an IterableDataset""" @@ -204,7 +152,7 @@ def _init_state_dict(self) -> dict: def load_state_dict(self, state_dict: dict) -> dict: def _inner_load_state_dict(state, new_state): if new_state is not None and isinstance(state, dict): - for key in state: + for key in new_state: state[key] = _inner_load_state_dict(state[key], new_state[key]) return state elif new_state is not None and isinstance(state, list): @@ -432,6 +380,131 @@ def shard_data_sources(self, worker_id: int, num_workers: int) -> "ArrowExamples ) +class RebatchedArrowExamplesIterable(_BaseExamplesIterable): + def __init__(self, ex_iterable: _BaseExamplesIterable, batch_size: Optional[int], drop_last_batch: bool = False): + super().__init__() + self.ex_iterable = ex_iterable + self.batch_size = batch_size + self.drop_last_batch = drop_last_batch + + @property + def iter_arrow(self): + return self._iter_arrow + + def _init_state_dict(self) -> dict: + self._state_dict = { + "ex_iterable": self.ex_iterable._init_state_dict(), + "previous_state": None, + "batch_idx": 0, + "num_chunks_since_previous_state": 0, + "cropped_chunk_length": 0, + } + return self._state_dict + + def __iter__(self): + yield from self.ex_iterable + + def _iter_arrow(self) -> Iterator[Tuple[Key, pa.Table]]: + """Iterate over sub-tables of size `batch_size`.""" + if self._state_dict and self._state_dict["previous_state"]: + self.ex_iterable.load_state_dict(self._state_dict["previous_state"]) + if self.ex_iterable.iter_arrow: + iterator = self.ex_iterable.iter_arrow() + else: + iterator = _convert_to_arrow(self.ex_iterable, batch_size=1) + if self.batch_size is None or self.batch_size <= 0: + if self._state_dict and self._state_dict["batch_idx"] > 0: + return + all_pa_table = pa.concat_tables([pa_table for _, pa_table in iterator]) + if self._state_dict: + self._state_dict["batch_idx"] = 1 + yield "all", all_pa_table + return + keys_buffer = [] + chunks_buffer = [] + chunks_buffer_size = 0 + num_chunks_to_skip = self._state_dict["num_chunks_since_previous_state"] if self._state_dict else 0 + chunk_length_to_crop = self._state_dict["cropped_chunk_length"] if self._state_dict else 0 + if self._state_dict: + previous_state = self.ex_iterable.state_dict() + self._state_dict["previous_state"] = previous_state + for key, pa_table in iterator: + for num_chunks_since_previous_state, chunk in enumerate(pa_table.to_reader(max_chunksize=self.batch_size)): + if num_chunks_to_skip > 1: + num_chunks_to_skip -= 1 + continue + elif num_chunks_to_skip == 1 and chunk_length_to_crop == 0: + num_chunks_to_skip -= 1 + continue + elif num_chunks_to_skip == 1 and chunk_length_to_crop > 0: + chunk = chunk.slice(chunk_length_to_crop, len(chunk) - chunk_length_to_crop) + num_chunks_to_skip = 0 + chunk_length_to_crop = 0 + if len(chunk) == 0: + continue + + if chunks_buffer_size + len(chunk) < self.batch_size: + keys_buffer.append(key) + chunks_buffer.append(chunk) + chunks_buffer_size += len(chunk) + continue + elif chunks_buffer_size + len(chunk) == self.batch_size: + keys_buffer.append(key) + chunks_buffer.append(chunk) + new_key = "_".join(str(_key) for _key in keys_buffer) + if self._state_dict: + self._state_dict["batch_idx"] += 1 + self._state_dict["num_chunks_since_previous_state"] += len(chunks_buffer) + self._state_dict["cropped_chunk_length"] = 0 + yield new_key, pa.Table.from_batches(chunks_buffer) + keys_buffer = [] + chunks_buffer = [] + chunks_buffer_size = 0 + if self._state_dict: + self._state_dict["previous_state"] = previous_state + self._state_dict["num_chunks_since_previous_state"] = num_chunks_since_previous_state + 1 + else: + cropped_chunk_length = self.batch_size - chunks_buffer_size + keys_buffer.append(f"{key}[:{cropped_chunk_length}]") + chunks_buffer.append(chunk.slice(0, cropped_chunk_length)) + new_key = "_".join(str(_key) for _key in keys_buffer) + if self._state_dict: + self._state_dict["batch_idx"] += 1 + self._state_dict["num_chunks_since_previous_state"] += len(chunks_buffer) + self._state_dict["cropped_chunk_length"] = cropped_chunk_length + yield new_key, pa.Table.from_batches(chunks_buffer) + keys_buffer = [f"{key}[{cropped_chunk_length}:]"] + chunks_buffer = [chunk.slice(cropped_chunk_length, len(chunk) - cropped_chunk_length)] + chunks_buffer_size = len(chunk) - cropped_chunk_length + if self._state_dict: + self._state_dict["previous_state"] = previous_state + self._state_dict["num_chunks_since_previous_state"] = num_chunks_since_previous_state + if self._state_dict: + previous_state = self.ex_iterable.state_dict() + if not self.drop_last_batch and chunks_buffer: + new_key = "_".join(str(_key) for _key in keys_buffer) + if self._state_dict: + self._state_dict["previous_state"] = previous_state + self._state_dict["batch_idx"] += 1 + self._state_dict["num_chunks_since_previous_state"] = 0 + self._state_dict["cropped_chunk_length"] = 0 + yield new_key, pa.Table.from_batches(chunks_buffer) + + def shuffle_data_sources(self, generator: np.random.Generator) -> "RebatchedArrowExamplesIterable": + return RebatchedArrowExamplesIterable( + self.ex_iterable.shuffle_data_sources(generator), self.batch_size, self.drop_last_batch + ) + + def shard_data_sources(self, worker_id: int, num_workers: int) -> "RebatchedArrowExamplesIterable": + return RebatchedArrowExamplesIterable( + self.ex_iterable.shard_data_sources(worker_id, num_workers), self.batch_size, self.drop_last_batch + ) + + @property + def n_shards(self) -> int: + return self.ex_iterable.n_shards + + class SelectColumnsIterable(_BaseExamplesIterable): def __init__(self, ex_iterable: _BaseExamplesIterable, column_names: List[str]): super().__init__() @@ -841,6 +914,19 @@ def __init__( self.input_columns = input_columns self.fn_kwargs = fn_kwargs or {} self.formatting = formatting + # sanity checks + if formatting and formatting.format_type == "arrow": + # batch_size should match for iter_arrow + if not isinstance(ex_iterable, RebatchedArrowExamplesIterable): + raise ValueError( + "The Arrow-formatted MappedExamplesIterable has underlying iterable" + f"that is a {type(ex_iterable).__name__} instead of a RebatchedArrowExamplesIterable." + ) + elif ex_iterable.batch_size != (batch_size if batched else 1): + raise ValueError( + f"The Arrow-formatted MappedExamplesIterable has batch_size={batch_size if batched else 1} which is" + f"different from {ex_iterable.batch_size=} from its underlying iterable." + ) @property def iter_arrow(self): @@ -858,7 +944,9 @@ def _init_state_dict(self) -> dict: def __iter__(self): if self.formatting and self.formatting.format_type == "arrow": - yield from ArrowExamplesIterable(self._iter_arrow, {}) + formatter = PythonFormatter() + for key, pa_table in self._iter_arrow(max_chunksize=1): + yield key, formatter.format_row(pa_table) else: yield from self._iter() @@ -961,32 +1049,32 @@ def _iter(self): self._state_dict["previous_state_example_idx"] += 1 yield key, transformed_example - def _iter_arrow(self) -> Iterator[Tuple[Key, pa.Table]]: + def _iter_arrow(self, max_chunksize: Optional[int] = None) -> Iterator[Tuple[Key, pa.Table]]: if self.ex_iterable.iter_arrow: - iterator = _batch_arrow_tables( - self.ex_iterable.iter_arrow(), - batch_size=self.batch_size if self.batched else 1, - drop_last_batch=self.drop_last_batch, - ) + iterator = self.ex_iterable.iter_arrow() else: iterator = _convert_to_arrow( self.ex_iterable, batch_size=self.batch_size if self.batched else 1, drop_last_batch=self.drop_last_batch, ) - - current_idx = self._state_dict["previous_state_example_idx"] if self._state_dict else 0 if self._state_dict and self._state_dict["previous_state"]: self.ex_iterable.load_state_dict(self._state_dict["previous_state"]) num_examples_to_skip = self._state_dict["num_examples_since_previous_state"] else: num_examples_to_skip = 0 - if self._state_dict: + if self._state_dict and max_chunksize is not None: self._state_dict["previous_state"] = self.ex_iterable.state_dict() self._state_dict["num_examples_since_previous_state"] = 0 - self._state_dict["previous_state_example_idx"] = current_idx - + current_idx = self._state_dict["previous_state_example_idx"] if self._state_dict else 0 for key, pa_table in iterator: + if ( + self.batched + and self.batch_size is not None + and len(pa_table) < self.batch_size + and self.drop_last_batch + ): + return # first build the batch function_args = [pa_table] if self.input_columns is None else [pa_table[col] for col in self.input_columns] if self.with_indices: @@ -1007,17 +1095,24 @@ def _iter_arrow(self) -> Iterator[Tuple[Key, pa.Table]]: if column in output_table.column_names: output_table = output_table.remove_column(output_table.column_names.index(column)) # return output - current_idx += len(pa_table) - if self._state_dict: - self._state_dict["num_examples_since_previous_state"] += len(pa_table) - if num_examples_to_skip > 0: - num_examples_to_skip -= len(pa_table) - continue - yield key, output_table - if self._state_dict: - self._state_dict["previous_state"] = self.ex_iterable.state_dict() - self._state_dict["num_examples_since_previous_state"] = 0 - self._state_dict["previous_state_example_idx"] = current_idx + if max_chunksize is None: + current_idx += len(pa_table) + if self._state_dict: + self._state_dict["previous_state_example_idx"] += len(pa_table) + yield key, output_table + else: + for i, pa_subtable in enumerate(output_table.to_reader(max_chunksize=max_chunksize)): + current_idx += 1 + if self._state_dict: + self._state_dict["num_examples_since_previous_state"] += 1 + if num_examples_to_skip > 0: + num_examples_to_skip -= 1 + continue + yield f"{key}_{i}", pa_subtable + if self._state_dict: + self._state_dict["previous_state"] = self.ex_iterable.state_dict() + self._state_dict["num_examples_since_previous_state"] = 0 + self._state_dict["previous_state_example_idx"] += len(pa_table) def shuffle_data_sources(self, generator: np.random.Generator) -> "MappedExamplesIterable": """Shuffle the wrapped examples iterable.""" @@ -1081,6 +1176,19 @@ def __init__( self.input_columns = input_columns self.fn_kwargs = fn_kwargs or {} self.formatting = formatting + # sanity checks + if formatting and formatting.format_type == "arrow": + # batch_size should match for iter_arrow + if not isinstance(ex_iterable, RebatchedArrowExamplesIterable): + raise ValueError( + "The Arrow-formatted FilteredExamplesIterable has underlying iterable" + f"that is a {type(ex_iterable).__name__} instead of a RebatchedArrowExamplesIterable." + ) + elif ex_iterable.batch_size != (batch_size if batched else 1): + raise ValueError( + f"The Arrow-formatted FilteredExamplesIterable has batch_size={batch_size if batched else 1} which is" + f"different from {ex_iterable.batch_size=} from its underlying iterable." + ) @property def iter_arrow(self): @@ -1098,7 +1206,9 @@ def _init_state_dict(self) -> dict: def __iter__(self): if self.formatting and self.formatting.format_type == "arrow": - yield from ArrowExamplesIterable(self._iter_arrow, {}) + formatter = PythonFormatter() + for key, pa_table in self._iter_arrow(max_chunksize=1): + yield key, formatter.format_row(pa_table) else: yield from self._iter() @@ -1170,26 +1280,29 @@ def _iter(self): if to_keep: yield key, example - def _iter_arrow(self): + def _iter_arrow(self, max_chunksize: Optional[int] = None): if self.ex_iterable.iter_arrow: - iterator = _batch_arrow_tables( - self.ex_iterable.iter_arrow(), batch_size=self.batch_size if self.batched else 1 - ) + iterator = self.ex_iterable.iter_arrow() else: iterator = _convert_to_arrow(self.ex_iterable, batch_size=self.batch_size if self.batched else 1) - current_idx = self._state_dict["previous_state_example_idx"] if self._state_dict else 0 if self._state_dict and self._state_dict["previous_state"]: self.ex_iterable.load_state_dict(self._state_dict["previous_state"]) num_examples_to_skip = self._state_dict["num_examples_since_previous_state"] else: num_examples_to_skip = 0 - if self._state_dict: + if self._state_dict and max_chunksize is not None: self._state_dict["previous_state"] = self.ex_iterable.state_dict() self._state_dict["num_examples_since_previous_state"] = 0 - self._state_dict["previous_state_example_idx"] = current_idx - + current_idx = self._state_dict["previous_state_example_idx"] if self._state_dict else 0 for key, pa_table in iterator: + if ( + self.batched + and self.batch_size is not None + and len(pa_table) < self.batch_size + and self.drop_last_batch + ): + return # first build the batch function_args = [pa_table] if self.input_columns is None else [pa_table[col] for col in self.input_columns] if self.with_indices: @@ -1199,21 +1312,33 @@ def _iter_arrow(self): function_args.append(current_idx) # then apply the transform mask = self.function(*function_args, **self.fn_kwargs) - # yield the filtered table - current_idx += len(pa_table) - if self._state_dict: - self._state_dict["num_examples_since_previous_state"] += len(pa_table) - if num_examples_to_skip > 0: - num_examples_to_skip -= len(pa_table) - continue + # return output if self.batched: - yield key, pa_table.filter(mask) + output_table = pa_table.filter(mask) elif mask.as_py() if isinstance(mask, pa.BooleanScalar) else mask: - yield key, pa_table - if self._state_dict: - self._state_dict["previous_state"] = self.ex_iterable.state_dict() - self._state_dict["num_examples_since_previous_state"] = 0 - self._state_dict["previous_state_example_idx"] = current_idx + output_table = pa_table + else: + output_table = pa_table.slice(0, 0) + + if max_chunksize is None: + current_idx += len(pa_table) + if self._state_dict: + self._state_dict["previous_state_example_idx"] += len(pa_table) + if len(output_table) > 0: + yield key, output_table + else: + for i, pa_subtable in enumerate(output_table.to_reader(max_chunksize=max_chunksize)): + current_idx += 1 + if self._state_dict: + self._state_dict["num_examples_since_previous_state"] += 1 + if num_examples_to_skip > 0: + num_examples_to_skip -= 1 + continue + yield f"{key}_{i}", pa_subtable + if self._state_dict: + self._state_dict["previous_state"] = self.ex_iterable.state_dict() + self._state_dict["num_examples_since_previous_state"] = 0 + self._state_dict["previous_state_example_idx"] += len(pa_table) def shuffle_data_sources(self, seed: Optional[int]) -> "FilteredExamplesIterable": """Shuffle the wrapped examples iterable.""" @@ -1536,8 +1661,9 @@ def __init__( self._distributed = distributed self._epoch = 0 self._token_per_repo_id: Dict[str, Union[str, bool, None]] = token_per_repo_id or {} - self._state_dict = ex_iterable._init_state_dict() self._starting_state_dict: Optional[dict] = None + self._prepared_ex_iterable = self._prepare_ex_iterable_for_iteration() + self._state_dict = self._prepared_ex_iterable._init_state_dict() _maybe_add_torch_iterable_dataset_parent_class(self.__class__) def state_dict(self) -> dict: @@ -1641,7 +1767,7 @@ def load_state_dict(self, state_dict: dict) -> None: >>> dataloader.load_state_dict(state_dict) # uses ds.load_state_dict() under the hood ``` """ - self._ex_iterable.load_state_dict(state_dict) + self._prepared_ex_iterable.load_state_dict(state_dict) self._starting_state_dict = state_dict def __repr__(self): @@ -1716,7 +1842,7 @@ def _iter_pytorch(self): if self._formatting and (ex_iterable.iter_arrow or self._formatting == "arrow"): if ex_iterable.iter_arrow: - iterator = _batch_arrow_tables(ex_iterable.iter_arrow(), batch_size=1) + iterator = ex_iterable.iter_arrow() else: iterator = _convert_to_arrow(ex_iterable, batch_size=1) for key, pa_table in iterator: @@ -1750,11 +1876,18 @@ def _is_main_process(self): return False return True - def _prepare_ex_iterable_for_iteration(self) -> _BaseExamplesIterable: + def _prepare_ex_iterable_for_iteration( + self, batch_size: int = 1, drop_last_batch: bool = False + ) -> _BaseExamplesIterable: + ex_iterable = self._ex_iterable + if self._formatting and (ex_iterable.iter_arrow or self._formatting.format_type == "arrow"): + ex_iterable = RebatchedArrowExamplesIterable( + ex_iterable, batch_size=batch_size, drop_last_batch=drop_last_batch + ) if self._shuffling: - ex_iterable = self._ex_iterable.shuffle_data_sources(self._effective_generator()) + ex_iterable = ex_iterable.shuffle_data_sources(self._effective_generator()) else: - ex_iterable = self._ex_iterable + ex_iterable = ex_iterable if self._distributed: rank = self._distributed.rank @@ -1779,6 +1912,9 @@ def _prepare_ex_iterable_for_iteration(self) -> _BaseExamplesIterable: ) ex_iterable = StepExamplesIterable(ex_iterable, step=world_size, offset=rank) + self._state_dict = ex_iterable._init_state_dict() + if self._starting_state_dict: + ex_iterable.load_state_dict(self._starting_state_dict) return ex_iterable def __iter__(self): @@ -1792,9 +1928,6 @@ def __iter__(self): return ex_iterable = self._prepare_ex_iterable_for_iteration() - self._state_dict = ex_iterable._init_state_dict() - if self._starting_state_dict: - ex_iterable.load_state_dict(self._starting_state_dict) if self._formatting: formatter = get_formatter(self._formatting.format_type, features=self.features) format_dict = ( @@ -1804,9 +1937,8 @@ def __iter__(self): format_dict = None if self._formatting and (ex_iterable.iter_arrow or self._formatting.format_type == "arrow"): - assert self._state_dict is ex_iterable._state_dict if ex_iterable.iter_arrow: - iterator = _batch_arrow_tables(ex_iterable.iter_arrow(), batch_size=1) + iterator = ex_iterable.iter_arrow() else: iterator = _convert_to_arrow(ex_iterable, batch_size=1) for key, pa_table in iterator: @@ -1839,12 +1971,10 @@ def iter(self, batch_size: int, drop_last_batch: bool = False): else: format_dict = None - ex_iterable = self._prepare_ex_iterable_for_iteration() + ex_iterable = self._prepare_ex_iterable_for_iteration(batch_size=batch_size, drop_last_batch=drop_last_batch) if self._formatting and (ex_iterable.iter_arrow or self._formatting == "arrow"): if ex_iterable.iter_arrow: - iterator = _batch_arrow_tables( - ex_iterable.iter_arrow(), batch_size=batch_size, drop_last_batch=drop_last_batch - ) + iterator = ex_iterable.iter_arrow() else: iterator = _convert_to_arrow(ex_iterable, batch_size=batch_size, drop_last_batch=drop_last_batch) for key, pa_table in iterator: @@ -2092,10 +2222,18 @@ def map( function = identity_func if fn_kwargs is None: fn_kwargs = {} - ex_iterable = MappedExamplesIterable( + ex_iterable = ( TypedExamplesIterable(self._ex_iterable, self._info.features, token_per_repo_id=self._token_per_repo_id) if self._info.features is not None - else self._ex_iterable, + else self._ex_iterable + ) + ex_iterable = ( + RebatchedArrowExamplesIterable(ex_iterable, batch_size=batch_size, drop_last_batch=drop_last_batch) + if self._formatting and self._formatting.format_type == "arrow" + else ex_iterable + ) + ex_iterable = MappedExamplesIterable( + ex_iterable, function=function, with_indices=with_indices, input_columns=input_columns, diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index 67f1806fe12..4ca143799f7 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -29,6 +29,7 @@ IterableDataset, MappedExamplesIterable, RandomlyCyclingMultiSourcesExamplesIterable, + RebatchedArrowExamplesIterable, SelectColumnsIterable, ShuffledDataSourcesArrowExamplesIterable, ShuffledDataSourcesExamplesIterable, @@ -39,7 +40,6 @@ TypedExamplesIterable, VerticallyConcatenatedMultiSourcesExamplesIterable, _BaseExamplesIterable, - _batch_arrow_tables, _batch_to_examples, _convert_to_arrow, _examples_to_batch, @@ -116,29 +116,20 @@ def arrow_file(tmp_path_factory, dataset: IterableDataset): return filename -def assert_load_state_dict_resumes_iteration(ex_iterable: _BaseExamplesIterable, num_allowed_lost_samples=0): +def assert_load_state_dict_resumes_iteration(ex_iterable: _BaseExamplesIterable): ex_iterable._init_state_dict() state_dicts = [ex_iterable.state_dict()] examples = [] for _, example in ex_iterable: state_dicts.append(ex_iterable.state_dict()) examples.append(example) - if num_allowed_lost_samples < 0: - num_allowed_lost_samples = len(examples) for i, state_dict in enumerate(state_dicts): ex_iterable.load_state_dict(state_dict) examples_after_resuming = [example for _, example in ex_iterable] - if num_allowed_lost_samples > 0: - for j in range(num_allowed_lost_samples): - if examples_after_resuming == examples[i + j :]: - break - else: - raise AssertionError(f"Failed to resume iteration, even with {num_allowed_lost_samples=}") - else: - assert examples_after_resuming == examples[i:] + assert examples_after_resuming == examples[i:], f"resuming from idx {i} with {state_dict=}" -def assert_load_state_dict_resumes_arrow_iteration(ex_iterable: _BaseExamplesIterable, num_allowed_lost_samples=0): +def assert_load_state_dict_resumes_arrow_iteration(ex_iterable: _BaseExamplesIterable): assert ex_iterable.iter_arrow is not None ex_iterable._init_state_dict() state_dicts = [ex_iterable.state_dict()] @@ -148,23 +139,12 @@ def assert_load_state_dict_resumes_arrow_iteration(ex_iterable: _BaseExamplesIte state_dicts.append(ex_iterable.state_dict()) examples.extend(pa_table.to_pylist()) indices.append(indices[-1] + len(pa_table)) - if num_allowed_lost_samples < 0: - num_allowed_lost_samples = len(examples) for i, state_dict in zip(indices, state_dicts): ex_iterable.load_state_dict(state_dict) examples_after_resuming = [ example for _, pa_table in ex_iterable.iter_arrow() for example in pa_table.to_pylist() ] - if num_allowed_lost_samples > 0: - for j in range(num_allowed_lost_samples): - if examples_after_resuming == examples[i + j :]: - break - else: - raise AssertionError( - f"Failed to resume iteration from state {i}, even with {num_allowed_lost_samples=}" - ) - else: - assert examples_after_resuming == examples[i:] + assert examples_after_resuming == examples[i:], f"resuming from idx {i} with {state_dict=}" ################################ @@ -199,34 +179,6 @@ def test_convert_to_arrow(batch_size, drop_last_batch): assert full_table.slice(0, num_rows).to_pydict() == reloaded.to_pydict() -@pytest.mark.parametrize( - "tables", - [ - [pa.table({"foo": range(10)})], - [pa.table({"foo": range(0, 5)}), pa.table({"foo": range(5, 10)})], - [pa.table({"foo": [i]}) for i in range(10)], - ], -) -@pytest.mark.parametrize("batch_size", [1, 2, 3, 9, 10, 11, 20]) -@pytest.mark.parametrize("drop_last_batch", [False, True]) -def test_batch_arrow_tables(tables, batch_size, drop_last_batch): - full_table = pa.concat_tables(tables) - num_rows = len(full_table) if not drop_last_batch else len(full_table) // batch_size * batch_size - num_batches = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size - subtables = list( - _batch_arrow_tables(list(enumerate(tables)), batch_size=batch_size, drop_last_batch=drop_last_batch) - ) - assert len(subtables) == num_batches - if drop_last_batch: - assert all(len(subtable) == batch_size for _, subtable in subtables) - else: - assert all(len(subtable) == batch_size for _, subtable in subtables[:-1]) - assert len(subtables[-1][1]) <= batch_size - if num_rows > 0: - reloaded = pa.concat_tables([subtable for _, subtable in subtables]) - assert full_table.slice(0, num_rows).to_pydict() == reloaded.to_pydict() - - ################################ # # _BaseExampleIterable tests @@ -315,6 +267,42 @@ def test_arrow_examples_iterable_shuffle_data_sources(): assert_load_state_dict_resumes_iteration(ex_iterable) +@pytest.mark.parametrize( + "tables", + [ + [pa.table({"foo": range(10)})], + [pa.table({"foo": range(5 * i, 5 * (i + 1))}) for i in range(2)], + [pa.table({"foo": range(5 * i, 5 * (i + 1))}) for i in range(7)], + [pa.table({"foo": [i]}) for i in range(10)], + ], +) +@pytest.mark.parametrize("batch_size", [1, 2, 3, 7, 9, 10, 11, 13, 20]) +@pytest.mark.parametrize("drop_last_batch", [False, True]) +def test_rebatched_arrow_examples_iterable(tables, batch_size, drop_last_batch): + full_table = pa.concat_tables(tables) + num_rows = len(full_table) if not drop_last_batch else len(full_table) // batch_size * batch_size + num_batches = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size + + def gen(tables): + for i, table in enumerate(tables): + yield str(i), table + + ex_iterable = ArrowExamplesIterable(gen, {"tables": tables}) + ex_iterable = RebatchedArrowExamplesIterable(ex_iterable, batch_size=batch_size, drop_last_batch=drop_last_batch) + subtables = list(ex_iterable.iter_arrow()) + assert len(subtables) == num_batches + if drop_last_batch: + assert all(len(subtable) == batch_size for _, subtable in subtables) + else: + assert all(len(subtable) == batch_size for _, subtable in subtables[:-1]) + assert len(subtables[-1][1]) <= batch_size + if num_rows > 0: + reloaded = pa.concat_tables([subtable for _, subtable in subtables]) + assert full_table.slice(0, num_rows).to_pydict() == reloaded.to_pydict() + assert_load_state_dict_resumes_iteration(ex_iterable) + assert_load_state_dict_resumes_arrow_iteration(ex_iterable) + + @pytest.mark.parametrize("seed", [42, 1337, 101010, 123456]) def test_buffer_shuffled_examples_iterable(seed): n, buffer_size = 100, 30 @@ -685,6 +673,7 @@ def test_mapped_examples_iterable_input_columns(n, func, batched, batch_size, in ) def test_mapped_examples_iterable_arrow_format(n, func, batched, batch_size): base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n}) + base_ex_iterable = RebatchedArrowExamplesIterable(base_ex_iterable, batch_size=batch_size if batched else 1) ex_iterable = MappedExamplesIterable( base_ex_iterable, func, @@ -706,13 +695,48 @@ def test_mapped_examples_iterable_arrow_format(n, func, batched, batch_size): expected.extend(func(batch).to_pylist()) assert next(iter(ex_iterable))[1] == expected[0] assert [x for _, x in ex_iterable] == expected - if not batched: - num_allowed_lost_samples = 0 - elif batch_size is None or batch_size < 0: - num_allowed_lost_samples = -1 + assert_load_state_dict_resumes_iteration(ex_iterable) + assert_load_state_dict_resumes_arrow_iteration(ex_iterable) + + +@pytest.mark.parametrize( + "n, func, batched, batch_size", + [ + (3, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), False, None), # just add 1 to the id + (3, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, 1), # same with bs=1 + (5, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, 10), # same with bs=10 + (25, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, 10), # same with bs=10 + (5, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, None), # same with bs=None + (5, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, -1), # same with bs<=0 + (3, lambda t: pa.concat_tables([t] * 2), True, 1), # make a duplicate of each example + ], +) +def test_mapped_examples_iterable_arrow_format_from_arrow_examples_iterable(n, func, batched, batch_size): + base_ex_iterable = ArrowExamplesIterable(generate_tables_fn, {"n": n}) + base_ex_iterable = RebatchedArrowExamplesIterable(base_ex_iterable, batch_size=batch_size if batched else 1) + ex_iterable = MappedExamplesIterable( + base_ex_iterable, + func, + batched=batched, + batch_size=batch_size, + formatting=FormattingConfig(format_type="arrow"), + ) + all_examples = [x for _, x in generate_examples_fn(n=n)] + if batched is False: + expected = [func(pa.Table.from_pylist([x])).to_pylist()[0] for x in all_examples] else: - num_allowed_lost_samples = batch_size + 1 - assert_load_state_dict_resumes_iteration(ex_iterable, num_allowed_lost_samples=num_allowed_lost_samples) + expected = [] + # If batch_size is None or <=0, we use the whole dataset as a single batch + if batch_size is None or batch_size <= 0: + batch_size = len(all_examples) + for batch_offset in range(0, len(all_examples), batch_size): + examples = all_examples[batch_offset : batch_offset + batch_size] + batch = pa.Table.from_pylist(examples) + expected.extend(func(batch).to_pylist()) + assert next(iter(ex_iterable))[1] == expected[0] + assert [x for _, x in ex_iterable] == expected + assert_load_state_dict_resumes_iteration(ex_iterable) + assert_load_state_dict_resumes_arrow_iteration(ex_iterable) @pytest.mark.parametrize( @@ -729,6 +753,7 @@ def test_mapped_examples_iterable_arrow_format(n, func, batched, batch_size): ) def test_mapped_examples_iterable_drop_last_batch_and_arrow_format(n, func, batched, batch_size): base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n}) + base_ex_iterable = RebatchedArrowExamplesIterable(base_ex_iterable, batch_size=batch_size if batched else 1) ex_iterable = MappedExamplesIterable( base_ex_iterable, func, @@ -791,6 +816,7 @@ def test_mapped_examples_iterable_drop_last_batch_and_arrow_format(n, func, batc ) def test_mapped_examples_iterable_with_indices_and_arrow_format(n, func, batched, batch_size): base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n}) + base_ex_iterable = RebatchedArrowExamplesIterable(base_ex_iterable, batch_size=batch_size if batched else 1) ex_iterable = MappedExamplesIterable( base_ex_iterable, func, @@ -813,13 +839,8 @@ def test_mapped_examples_iterable_with_indices_and_arrow_format(n, func, batched expected.extend(func(batch, list(range(batch_offset, batch_offset + len(batch)))).to_pylist()) assert next(iter(ex_iterable))[1] == expected[0] assert [x for _, x in ex_iterable] == expected - if not batched: - num_allowed_lost_samples = 0 - elif batch_size is None or batch_size < 0: - num_allowed_lost_samples = -1 - else: - num_allowed_lost_samples = batch_size + 1 - assert_load_state_dict_resumes_iteration(ex_iterable, num_allowed_lost_samples=num_allowed_lost_samples) + assert_load_state_dict_resumes_iteration(ex_iterable) + assert_load_state_dict_resumes_arrow_iteration(ex_iterable) @pytest.mark.parametrize( @@ -846,6 +867,7 @@ def test_mapped_examples_iterable_with_indices_and_arrow_format(n, func, batched ) def test_mapped_examples_iterable_remove_columns_arrow_format(n, func, batched, batch_size, remove_columns): base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n, "extra_column": "foo"}) + base_ex_iterable = RebatchedArrowExamplesIterable(base_ex_iterable, batch_size=batch_size if batched else 1) ex_iterable = MappedExamplesIterable( base_ex_iterable, func, @@ -874,13 +896,8 @@ def test_mapped_examples_iterable_remove_columns_arrow_format(n, func, batched, ) assert next(iter(ex_iterable))[1] == expected[0] assert [x for _, x in ex_iterable] == expected - if not batched: - num_allowed_lost_samples = 0 - elif batch_size is None or batch_size < 0: - num_allowed_lost_samples = -1 - else: - num_allowed_lost_samples = batch_size + 1 - assert_load_state_dict_resumes_iteration(ex_iterable, num_allowed_lost_samples=num_allowed_lost_samples) + assert_load_state_dict_resumes_iteration(ex_iterable) + assert_load_state_dict_resumes_arrow_iteration(ex_iterable) @pytest.mark.parametrize( @@ -895,6 +912,7 @@ def test_mapped_examples_iterable_remove_columns_arrow_format(n, func, batched, ) def test_mapped_examples_iterable_fn_kwargs_and_arrow_format(n, func, batched, batch_size, fn_kwargs): base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n}) + base_ex_iterable = RebatchedArrowExamplesIterable(base_ex_iterable, batch_size=batch_size if batched else 1) ex_iterable = MappedExamplesIterable( base_ex_iterable, func, @@ -919,13 +937,8 @@ def test_mapped_examples_iterable_fn_kwargs_and_arrow_format(n, func, batched, b expected.extend(func(batch, **fn_kwargs).to_pylist()) assert next(iter(ex_iterable))[1] == expected[0] assert [x for _, x in ex_iterable] == expected - if not batched: - num_allowed_lost_samples = 0 - elif batch_size is None or batch_size < 0: - num_allowed_lost_samples = -1 - else: - num_allowed_lost_samples = batch_size + 1 - assert_load_state_dict_resumes_iteration(ex_iterable, num_allowed_lost_samples=num_allowed_lost_samples) + assert_load_state_dict_resumes_iteration(ex_iterable) + assert_load_state_dict_resumes_arrow_iteration(ex_iterable) @pytest.mark.parametrize( @@ -939,6 +952,7 @@ def test_mapped_examples_iterable_fn_kwargs_and_arrow_format(n, func, batched, b ) def test_mapped_examples_iterable_input_columns_and_arrow_format(n, func, batched, batch_size, input_columns): base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n}) + base_ex_iterable = RebatchedArrowExamplesIterable(base_ex_iterable, batch_size=batch_size if batched else 1) ex_iterable = MappedExamplesIterable( base_ex_iterable, func, @@ -964,13 +978,8 @@ def test_mapped_examples_iterable_input_columns_and_arrow_format(n, func, batche expected.extend(func(*[batch[col] for col in columns_to_input]).to_pylist()) assert next(iter(ex_iterable))[1] == expected[0] assert [x for _, x in ex_iterable] == expected - if not batched: - num_allowed_lost_samples = 0 - elif batch_size is None or batch_size < 0: - num_allowed_lost_samples = -1 - else: - num_allowed_lost_samples = batch_size + 1 - assert_load_state_dict_resumes_iteration(ex_iterable, num_allowed_lost_samples=num_allowed_lost_samples) + assert_load_state_dict_resumes_iteration(ex_iterable) + assert_load_state_dict_resumes_arrow_iteration(ex_iterable) @pytest.mark.parametrize( @@ -1193,20 +1202,22 @@ def test_no_iter_arrow(ex_iterable: _BaseExamplesIterable): # HorizontallyConcatenatedMultiSourcesExamplesIterable([ArrowExamplesIterable(generate_tables_fn, {})]), # not implemented # RandomlyCyclingMultiSourcesExamplesIterable([ArrowExamplesIterable(generate_tables_fn, {})], np.random.default_rng(42)), # not implemented MappedExamplesIterable( - ExamplesIterable(generate_examples_fn, {}), lambda t: t, formatting=FormattingConfig(format_type="arrow") + RebatchedArrowExamplesIterable(ExamplesIterable(generate_examples_fn, {}), batch_size=1), + lambda t: t, + formatting=FormattingConfig(format_type="arrow"), ), MappedExamplesIterable( - ArrowExamplesIterable(generate_tables_fn, {}), + RebatchedArrowExamplesIterable(ArrowExamplesIterable(generate_tables_fn, {}), batch_size=1), lambda t: t, formatting=FormattingConfig(format_type="arrow"), ), FilteredExamplesIterable( - ExamplesIterable(generate_examples_fn, {}), + RebatchedArrowExamplesIterable(ExamplesIterable(generate_examples_fn, {}), batch_size=1), lambda t: True, formatting=FormattingConfig(format_type="arrow"), ), FilteredExamplesIterable( - ArrowExamplesIterable(generate_tables_fn, {}), + RebatchedArrowExamplesIterable(ArrowExamplesIterable(generate_tables_fn, {}), batch_size=1), lambda t: True, formatting=FormattingConfig(format_type="arrow"), ), @@ -1222,14 +1233,7 @@ def test_iter_arrow(ex_iterable: _BaseExamplesIterable): assert ex_iterable.iter_arrow is not None key, pa_table = next(ex_iterable.iter_arrow()) assert isinstance(pa_table, pa.Table) - if ( - isinstance(ex_iterable, (MappedExamplesIterable, FilteredExamplesIterable)) - and ex_iterable.ex_iterable.iter_arrow - ): - num_allowed_lost_samples = len(next(iter(ex_iterable.ex_iterable.iter_arrow()))[1]) - else: - num_allowed_lost_samples = 0 - assert_load_state_dict_resumes_arrow_iteration(ex_iterable, num_allowed_lost_samples=num_allowed_lost_samples) + assert_load_state_dict_resumes_arrow_iteration(ex_iterable) ############################