Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make sure dataset options are propagated through mp_prefetch. #726

Merged
merged 1 commit into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions grain/_src/python/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1228,18 +1228,18 @@ class WithOptionsIterDataset(IterDataset[T]):

def __init__(self, parent: IterDataset[T], options: base.DatasetOptions):
super().__init__(parent)
self._options = options
self.options = options

def __iter__(self) -> DatasetIterator[T]:
result = self._parent.__iter__()
# The parent iterator options are merged from the entire subtree. Merge
# them with the latest options and update the subtree options.
options = self._options.merge(result._ctx.dataset_options)
options = self.options.merge(result._ctx.dataset_options)
result._ctx.dataset_options = options
return result

def __str__(self):
return f"WithOptionsIterDataset(options={self._options})"
return f"WithOptionsIterDataset(options={self.options})"


_ConsistentDatasetType = TypeVar(
Expand Down
23 changes: 20 additions & 3 deletions grain/_src/python/dataset/transformations/prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,17 @@ def serialize(self) -> bytes:
raise e


def _get_dataset_options(ds: dataset.IterDataset) -> base.DatasetOptions:
  """Collects merged `DatasetOptions` from the entire dataset subtree.

  Walks the dataset graph rooted at `ds` and folds the options of every
  `WithOptionsIterDataset` node into a single `DatasetOptions` value via
  `merge`. Nodes are visited depth-first (LIFO), matching the order in
  which options would be discovered during iterator construction.

  Args:
    ds: Root of the dataset subtree to inspect.

  Returns:
    The merged options from all `WithOptionsIterDataset` nodes under `ds`
    (defaults if there are none).
  """
  merged = base.DatasetOptions()
  stack = [ds]
  while stack:
    node = stack.pop()
    if isinstance(node, dataset.WithOptionsIterDataset):
      merged = merged.merge(node.options)
    stack.extend(node.parents)
  return merged


class MultiprocessPrefetchDatasetIterator(dataset.DatasetIterator[T]):
"""Iterator that performs prefetching using a multiprocessing pool."""

Expand All @@ -433,6 +444,10 @@ def __init__(
):
super().__init__()
self._iter_parent = parent
# Since the parent iterator is going to be created in each subprocess, and
# the options are propagated during iterator creation, we need to manually
# propagate them.
self._ctx.dataset_options = _get_dataset_options(parent)
self._multiprocessing_options = multiprocessing_options
# The underlying iterator producing elements and workers state.
self._iterator = None
Expand Down Expand Up @@ -526,10 +541,12 @@ def _ensure_iterator_initialized(self) -> None:

def _create_iterator_context(self) -> grain_pool.MultiProcessIterator[T]:
"""Creates a `MultiProcessIterator`."""

get_element_producer_fn = GetElementProducerFn(
self._state, self._iter_parent
# Apply the latest options to the subprocess dataset. We delay this until
# starting subprocesses because child iterators may update them.
ds = dataset.WithOptionsIterDataset(
self._iter_parent, self._ctx.dataset_options
)
get_element_producer_fn = GetElementProducerFn(self._state, ds)

return grain_pool.MultiProcessIterator(
get_element_producer_fn,
Expand Down
20 changes: 20 additions & 0 deletions grain/_src/python/dataset/transformations/prefetch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,26 @@ def make_iter(i):
for it in iters:
_ = next(it)

def test_options_before_prefetch(self):
  """Options set upstream of `mp_prefetch` reach transforms applied after it.

  A low `filter_raise_threshold_ratio` is installed before the prefetch
  boundary; the filter that discards nearly all elements is added after it.
  If the options were propagated, iterating raises.
  """
  source = dataset.MapDataset.source([1, 2, 3]).repeat(1000).to_iter_dataset()
  opts = base.DatasetOptions(filter_raise_threshold_ratio=0.1)
  with_opts = dataset.WithOptionsIterDataset(source, opts)
  prefetched = with_opts.mp_prefetch(
      options.MultiprocessingOptions(num_workers=1)
  )
  filtered = prefetched.filter(lambda x: x > 2)
  with self.assertRaises(Exception):
    list(filtered)

def test_options_after_prefetch(self):
  """Options set downstream of `mp_prefetch` reach the subprocess pipeline.

  The aggressive filter lives before the prefetch boundary (inside the
  worker processes) while the low `filter_raise_threshold_ratio` is set
  after it. If the options were propagated into the workers, iterating
  raises.
  """
  filtered = (
      dataset.MapDataset.source([1, 2, 3])
      .repeat(1000)
      .filter(lambda x: x > 2)
      .to_iter_dataset()
  )
  prefetched = filtered.mp_prefetch(
      options.MultiprocessingOptions(num_workers=1)
  )
  opts = base.DatasetOptions(filter_raise_threshold_ratio=0.1)
  with_opts = dataset.WithOptionsIterDataset(prefetched, opts)
  with self.assertRaises(Exception):
    list(with_opts)


class ThreadPrefetchIterDatasetTest(parameterized.TestCase):

Expand Down