From 3d5baa87275e7725df19406726832596c3a0059d Mon Sep 17 00:00:00 2001
From: Ihor Indyk
Date: Thu, 6 Feb 2025 19:11:06 -0800
Subject: [PATCH] Make sure dataset options are propagated through `mp_prefetch`.

PiperOrigin-RevId: 724163526
---
 grain/_src/python/dataset/dataset.py          |  6 ++---
 .../dataset/transformations/prefetch.py       | 23 ++++++++++++++++---
 .../dataset/transformations/prefetch_test.py  | 20 ++++++++++++++++
 3 files changed, 43 insertions(+), 6 deletions(-)

diff --git a/grain/_src/python/dataset/dataset.py b/grain/_src/python/dataset/dataset.py
index 925cdca1..6a24a83a 100644
--- a/grain/_src/python/dataset/dataset.py
+++ b/grain/_src/python/dataset/dataset.py
@@ -1228,18 +1228,18 @@ class WithOptionsIterDataset(IterDataset[T]):
 
   def __init__(self, parent: IterDataset[T], options: base.DatasetOptions):
     super().__init__(parent)
-    self._options = options
+    self.options = options
 
   def __iter__(self) -> DatasetIterator[T]:
     result = self._parent.__iter__()
     # The parent iterator options are merged from the entire subtree. Merge
     # them with the latest options and update the subtree options.
-    options = self._options.merge(result._ctx.dataset_options)
+    options = self.options.merge(result._ctx.dataset_options)
     result._ctx.dataset_options = options
     return result
 
   def __str__(self):
-    return f"WithOptionsIterDataset(options={self._options})"
+    return f"WithOptionsIterDataset(options={self.options})"
 
 
 _ConsistentDatasetType = TypeVar(
diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py
index 1e19c752..c49e408d 100644
--- a/grain/_src/python/dataset/transformations/prefetch.py
+++ b/grain/_src/python/dataset/transformations/prefetch.py
@@ -423,6 +423,17 @@ def serialize(self) -> bytes:
       raise e
 
 
+def _get_dataset_options(ds: dataset.IterDataset) -> base.DatasetOptions:
+  result = base.DatasetOptions()
+  to_visit = [ds]
+  while to_visit:
+    parent = to_visit.pop()
+    if isinstance(parent, dataset.WithOptionsIterDataset):
+      result = result.merge(parent.options)
+    to_visit.extend(parent.parents)
+  return result
+
+
 class MultiprocessPrefetchDatasetIterator(dataset.DatasetIterator[T]):
   """Iterator that performs prefetching using a multiprocessing pool."""
 
@@ -433,6 +444,10 @@ def __init__(
   ):
     super().__init__()
     self._iter_parent = parent
+    # Since the parent iterator is going to be created in each subprocess, and
+    # the options are propagated during iterator creation, we need to manually
+    # propagate them.
+    self._ctx.dataset_options = _get_dataset_options(parent)
     self._multiprocessing_options = multiprocessing_options
     # The underlying iterator producing elements and workers state.
     self._iterator = None
@@ -526,10 +541,12 @@ def _ensure_iterator_initialized(self) -> None:
 
   def _create_iterator_context(self) -> grain_pool.MultiProcessIterator[T]:
     """Creates a `MultiProcessIterator`."""
-
-    get_element_producer_fn = GetElementProducerFn(
-        self._state, self._iter_parent
+    # Apply the latest options to the subprocess dataset. We delay this until
+    # starting subprocesses because child iterators may update them.
+    ds = dataset.WithOptionsIterDataset(
+        self._iter_parent, self._ctx.dataset_options
     )
+    get_element_producer_fn = GetElementProducerFn(self._state, ds)
 
     return grain_pool.MultiProcessIterator(
         get_element_producer_fn,
diff --git a/grain/_src/python/dataset/transformations/prefetch_test.py b/grain/_src/python/dataset/transformations/prefetch_test.py
index 94b7b197..fee62610 100644
--- a/grain/_src/python/dataset/transformations/prefetch_test.py
+++ b/grain/_src/python/dataset/transformations/prefetch_test.py
@@ -606,6 +606,26 @@ def make_iter(i):
     for it in iters:
       _ = next(it)
 
+  def test_options_before_prefetch(self):
+    ds = dataset.MapDataset.source([1, 2, 3]).repeat(1000)
+    ds = ds.to_iter_dataset()
+    ds_options = base.DatasetOptions(filter_raise_threshold_ratio=0.1)
+    ds = dataset.WithOptionsIterDataset(ds, ds_options)
+    ds = ds.mp_prefetch(options.MultiprocessingOptions(num_workers=1))
+    ds = ds.filter(lambda x: x > 2)
+    with self.assertRaises(Exception):
+      list(ds)
+
+  def test_options_after_prefetch(self):
+    ds = dataset.MapDataset.source([1, 2, 3]).repeat(1000)
+    ds = ds.filter(lambda x: x > 2)
+    ds = ds.to_iter_dataset()
+    ds = ds.mp_prefetch(options.MultiprocessingOptions(num_workers=1))
+    ds_options = base.DatasetOptions(filter_raise_threshold_ratio=0.1)
+    ds = dataset.WithOptionsIterDataset(ds, ds_options)
+    with self.assertRaises(Exception):
+      list(ds)
+
 
 class ThreadPrefetchIterDatasetTest(parameterized.TestCase):
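
A minimal usage sketch of the behavior the new tests exercise, mirroring `test_options_before_prefetch`: after this change, `DatasetOptions` declared upstream of `mp_prefetch` are also seen by transformations applied after it in the main process, so the `filter_raise_threshold_ratio` check still fires across the process boundary. The import paths below are an assumption modeled on the internal test module, not something this patch defines; the public `grain.python` aliases may differ.

# Usage sketch only. Import paths are assumed from the internal test module
# and may differ from the public grain.python aliases.
from grain._src.python import options
from grain._src.python.dataset import base
from grain._src.python.dataset import dataset


def main():
  # Options are declared *before* mp_prefetch...
  ds = dataset.MapDataset.source([1, 2, 3]).repeat(1000).to_iter_dataset()
  ds = dataset.WithOptionsIterDataset(
      ds, base.DatasetOptions(filter_raise_threshold_ratio=0.1)
  )
  ds = ds.mp_prefetch(options.MultiprocessingOptions(num_workers=1))
  # ...while this filter is applied after it and runs in the main process.
  # It drops roughly two thirds of all elements, so once the options
  # propagate through mp_prefetch the 10% filtered-out threshold is
  # exceeded and iteration raises.
  ds = ds.filter(lambda x: x > 2)
  try:
    list(ds)
  except Exception as e:  # The exact exception type is not pinned down here.
    print(f"filter threshold enforced across mp_prefetch: {e!r}")


if __name__ == "__main__":
  main()

The reverse direction is covered by `test_options_after_prefetch`: options attached after `mp_prefetch` are forwarded into the worker processes via the `WithOptionsIterDataset` wrapper added in `_create_iterator_context`.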