Skip to content

Commit

Permalink
Bump grain version, new test and fixes. (#921)
Browse files Browse the repository at this point in the history
  • Loading branch information
zxybazh authored Jan 14, 2025
1 parent a946f91 commit bd6d32b
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 11 deletions.
9 changes: 2 additions & 7 deletions axlearn/common/input_grain.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
from absl import logging
from array_record.python.array_record_data_source import PathLikeOrFileInstruction
from grain._src.python.data_loader import _determine_worker_count
from grain._src.python.dataset import stats as dataset_stats
from grain._src.python.dataset.transformations import packing
from grain._src.python.dataset.transformations import slice as slice_dataset
from jax.experimental import multihost_utils
Expand Down Expand Up @@ -181,8 +180,7 @@ class _UnbatchDatasetIterator(grain.DatasetIterator):
"""An iterator that unbatches np.arrays along dim=0."""

def __init__(self, parent: grain.DatasetIterator):
super().__init__(stats=None)
self._parent = parent
super().__init__(parent)
# Index within the unbatched inputs.
self._index = 0
self._current_batch = None
Expand Down Expand Up @@ -467,10 +465,8 @@ def __init__(
*,
pad_example: Any,
length: int,
stats: dataset_stats.Stats,
):
super().__init__(stats)
self._parent = parent
super().__init__(parent)
self._pad_example = pad_example
self._length = length
self._i = 0
Expand Down Expand Up @@ -516,7 +512,6 @@ def __iter__(self):
parent_iter,
pad_example=self._pad_example,
length=self._length,
stats=self._stats,
)


Expand Down
9 changes: 7 additions & 2 deletions axlearn/common/input_grain_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
"""Input processing for language modeling using Grain."""

import functools
import logging
import sys
from typing import Optional, Protocol

import numpy as np
from grain._src.python.dataset.transformations.prefetch import MultiprocessPrefetchIterDataset

from axlearn.common import input_grain, input_grain_text
from axlearn.common.config import ConfigOr, maybe_instantiate
Expand Down Expand Up @@ -124,8 +126,11 @@ def text_to_lm_training_input(
split_fn = functools.partial(
_trim_or_pad_and_batch, max_padding_fraction=max_padding_fraction, pad_id=vocab.pad_id
)
# Only repeat if not already infinite.
if len(ds) != sys.maxsize:
if isinstance(ds, MultiprocessPrefetchIterDataset):
        # Dataset types like MultiprocessPrefetchIterDataset have no len() or repeat() method.
        logging.info("Skipping repeat for ds: %s", ds)
elif len(ds) != sys.maxsize:
# Only repeat if not already infinite.
ds = ds.repeat(num_epochs=None)
ds = input_grain_text.tokenize(ds, vocab={"text": vocab}, with_eos=True)
ds = input_grain.rekey(ds, key_map={"target_labels": "text"})
Expand Down
25 changes: 25 additions & 0 deletions axlearn/common/input_grain_lm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,31 @@ def test_training_lm_processor_infinite_dataset(self):
max_padding_fraction=0.0,
)

@pytest.mark.skipif(
not os.path.exists(t5_sentence_piece_vocab_file), reason="Missing testdata."
)
def test_training_lm_processor_dataset_no_len_func(self):
max_len = 32
vocab = seqio.SentencePieceVocabulary(
sentencepiece_model_file=t5_sentence_piece_vocab_file,
)
examples = [{"text": f"test_str_#{i}", "index": i} for i in range(10)]
ds = prefetch_dataset(
maybe_to_iter_dataset(fake_grain_source(examples)),
multiprocessing_options=grain.MultiprocessingOptions(
num_workers=1,
per_worker_buffer_size=1,
enable_profiling=False,
),
)
ds = text_to_lm_training_input(
ds, # check if prefetch dataset breaks the pipeline
vocab=vocab,
max_len=max_len,
window_size=3,
max_padding_fraction=0.0,
)

@parameterized.parameters(
dict(
expected_batches=[
Expand Down
2 changes: 1 addition & 1 deletion axlearn/common/input_grain_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def convert_examples(x, rng: np.random.Generator):

ds = range_dataset(start=1, stop=10)
ds = ds.repeat(None).batch(3)
ds = ds.map(convert_examples, seed=123)
ds = ds.random_map(convert_examples, seed=123)
ds = unbatch(maybe_to_iter_dataset(ds))
ds = iter(ds)
self._test_checkpointing(ds)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ orbax = [
]
# Grain input processing. Currently does not support macos.
grain = [
"grain==0.2.1; platform_machine == 'x86_64'",
"grain==0.2.3; platform_machine == 'x86_64'",
]
# Audio dependencies.
audio = [
Expand Down

0 comments on commit bd6d32b

Please sign in to comment.