diff --git a/axlearn/common/input_grain.py b/axlearn/common/input_grain.py
index ed52f33b..e06d8ae3 100644
--- a/axlearn/common/input_grain.py
+++ b/axlearn/common/input_grain.py
@@ -40,7 +40,6 @@
 from absl import logging
 from array_record.python.array_record_data_source import PathLikeOrFileInstruction
 from grain._src.python.data_loader import _determine_worker_count
-from grain._src.python.dataset import stats as dataset_stats
 from grain._src.python.dataset.transformations import packing
 from grain._src.python.dataset.transformations import slice as slice_dataset
 from jax.experimental import multihost_utils
@@ -181,8 +180,7 @@ class _UnbatchDatasetIterator(grain.DatasetIterator):
     """An iterator that unbatches np.arrays along dim=0."""

     def __init__(self, parent: grain.DatasetIterator):
-        super().__init__(stats=None)
-        self._parent = parent
+        super().__init__(parent)
         # Index within the unbatched inputs.
         self._index = 0
         self._current_batch = None
@@ -467,10 +465,8 @@ def __init__(
         *,
         pad_example: Any,
         length: int,
-        stats: dataset_stats.Stats,
     ):
-        super().__init__(stats)
-        self._parent = parent
+        super().__init__(parent)
         self._pad_example = pad_example
         self._length = length
         self._i = 0
@@ -516,7 +512,6 @@ def __iter__(self):
             parent_iter,
             pad_example=self._pad_example,
             length=self._length,
-            stats=self._stats,
         )
diff --git a/axlearn/common/input_grain_lm.py b/axlearn/common/input_grain_lm.py
index 68b36463..d408cfc1 100644
--- a/axlearn/common/input_grain_lm.py
+++ b/axlearn/common/input_grain_lm.py
@@ -3,10 +3,12 @@
 """Input processing for language modeling using Grain."""

 import functools
+import logging
 import sys
 from typing import Optional, Protocol

 import numpy as np
+from grain._src.python.dataset.transformations.prefetch import MultiprocessPrefetchIterDataset

 from axlearn.common import input_grain, input_grain_text
 from axlearn.common.config import ConfigOr, maybe_instantiate
@@ -124,8 +126,11 @@ def text_to_lm_training_input(
     split_fn = functools.partial(
         _trim_or_pad_and_batch, max_padding_fraction=max_padding_fraction, pad_id=vocab.pad_id
     )
-    # Only repeat if not already infinite.
-    if len(ds) != sys.maxsize:
+    if isinstance(ds, MultiprocessPrefetchIterDataset):
+        # Dataset types like MultiprocessPrefetchIterDataset support neither len() nor repeat().
+        logging.info("Skipping repeat for ds: %s", ds)
+    elif len(ds) != sys.maxsize:
+        # Only repeat if not already infinite.
         ds = ds.repeat(num_epochs=None)
     ds = input_grain_text.tokenize(ds, vocab={"text": vocab}, with_eos=True)
     ds = input_grain.rekey(ds, key_map={"target_labels": "text"})
diff --git a/axlearn/common/input_grain_lm_test.py b/axlearn/common/input_grain_lm_test.py
index deb1d189..5ca0704e 100644
--- a/axlearn/common/input_grain_lm_test.py
+++ b/axlearn/common/input_grain_lm_test.py
@@ -201,6 +201,31 @@ def test_training_lm_processor_infinite_dataset(self):
             max_padding_fraction=0.0,
         )

+    @pytest.mark.skipif(
+        not os.path.exists(t5_sentence_piece_vocab_file), reason="Missing testdata."
+    )
+    def test_training_lm_processor_dataset_no_len_func(self):
+        max_len = 32
+        vocab = seqio.SentencePieceVocabulary(
+            sentencepiece_model_file=t5_sentence_piece_vocab_file,
+        )
+        examples = [{"text": f"test_str_#{i}", "index": i} for i in range(10)]
+        ds = prefetch_dataset(
+            maybe_to_iter_dataset(fake_grain_source(examples)),
+            multiprocessing_options=grain.MultiprocessingOptions(
+                num_workers=1,
+                per_worker_buffer_size=1,
+                enable_profiling=False,
+            ),
+        )
+        ds = text_to_lm_training_input(
+            ds,  # Check that a prefetch dataset does not break the pipeline.
+            vocab=vocab,
+            max_len=max_len,
+            window_size=3,
+            max_padding_fraction=0.0,
+        )
+
     @parameterized.parameters(
         dict(
             expected_batches=[
diff --git a/axlearn/common/input_grain_test.py b/axlearn/common/input_grain_test.py
index 7bf1f866..4cfed0e7 100644
--- a/axlearn/common/input_grain_test.py
+++ b/axlearn/common/input_grain_test.py
@@ -258,7 +258,7 @@ def convert_examples(x, rng: np.random.Generator):

         ds = range_dataset(start=1, stop=10)
         ds = ds.repeat(None).batch(3)
-        ds = ds.map(convert_examples, seed=123)
+        ds = ds.random_map(convert_examples, seed=123)
         ds = unbatch(maybe_to_iter_dataset(ds))
         ds = iter(ds)
         self._test_checkpointing(ds)
diff --git a/pyproject.toml b/pyproject.toml
index 92414e36..7fb7bd4e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -150,7 +150,7 @@ orbax = [
 ]
 # Grain input processing. Currently does not support macos.
 grain = [
-    "grain==0.2.1; platform_machine == 'x86_64'",
+    "grain==0.2.3; platform_machine == 'x86_64'",
 ]
 # Audio dependencies.
 audio = [
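Reviewer note: the `super().__init__(parent)` changes above track the grain 0.2.3 `DatasetIterator` API, where the base class now takes the parent iterator directly (and exposes it as `self._parent`) rather than a `stats` argument. Below is a minimal sketch of the pattern, assuming grain is imported as `import grain.python as grain` as elsewhere in axlearn; the iterator name and its pass-through behavior are illustrative only, not part of this patch:

```python
import grain.python as grain


class _PassthroughIterator(grain.DatasetIterator):
    """Illustrative iterator that yields parent elements unchanged."""

    def __init__(self, parent: grain.DatasetIterator):
        # grain 0.2.3 style: hand the parent iterator to the base class,
        # which stores it as `self._parent`; the old `stats` keyword is gone.
        super().__init__(parent)

    def __next__(self):
        return next(self._parent)

    def get_state(self):
        # Checkpointing delegates to the parent iterator's state.
        return {"parent": self._parent.get_state()}

    def set_state(self, state):
        self._parent.set_state(state["parent"])
```

The same release routes seeded random transforms through `random_map(transform, seed=...)`, whose transform receives an `np.random.Generator`, which is why the checkpointing test switches from `ds.map(convert_examples, seed=123)` to `ds.random_map(convert_examples, seed=123)`.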