diff --git a/grain/_src/python/dataset/BUILD b/grain/_src/python/dataset/BUILD index fc3245d8..d332d419 100644 --- a/grain/_src/python/dataset/BUILD +++ b/grain/_src/python/dataset/BUILD @@ -83,6 +83,7 @@ py_library( srcs = ["stats.py"], srcs_version = "PY3", deps = [ + ":base", "//grain/_src/core:config", "//grain/_src/core:monitoring", "//grain/_src/core:tree", diff --git a/grain/_src/python/dataset/base.py b/grain/_src/python/dataset/base.py index 6dc8a718..4f78d2c9 100644 --- a/grain/_src/python/dataset/base.py +++ b/grain/_src/python/dataset/base.py @@ -11,14 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Base interfaces for working with LazyDataset. +"""Primitives for working with Dataset APIs. -Classes in this module are shared by LazyDataset classes and transformations. +Classes in this module are shared by Dataset implementations as well as public +Dataset API. """ +from __future__ import annotations import abc +import dataclasses +import enum import typing -from typing import Protocol, TypeVar +from typing import Generic, Protocol, TypeVar T = TypeVar("T") @@ -50,3 +54,129 @@ def __len__(self) -> int: @abc.abstractmethod def __getitem__(self, index: int) -> tuple[int, int]: """Returns constituent dataset index and index within this dataset.""" + + +@enum.unique +class ExecutionTrackingMode(enum.Flag): + """Represents different modes for tracking execution statistics. + + Available modes: + DISABLED: + No execution statistics are measured. This mode is the default. + STAGE_TIMING: + The time taken for each transformation stage to execute is measured and + recorded. This recorded time reflects the duration spent within the + specific transformation to return an element, excluding the time spent in + any parent transformations. The recorded time can be retrieved using + `grain.experimental.get_execution_summary` method. + """ + + DISABLED = enum.auto() + STAGE_TIMING = enum.auto() + + +@dataclasses.dataclass(slots=True, frozen=True) +class _Default(Generic[T]): + """Default options value holder.""" + + value: T + + +@dataclasses.dataclass(kw_only=True, frozen=True) +class DatasetOptions: + """Holds options used by dataset transformations. + + Attributes: + filter_warn_threshold_ratio: If the ratio of filtered out elements is above + this threshold, a warning will be issued. Value `None` disables the + check. The ratio is calculated on non-overlapping windows of 1000 + elements. For instance, with `filter_warn_threshold_ratio=0.9` and 901 + elements out of the first 1000 (or elements 1000...2000) filtered out, a + warning will be issued. + filter_raise_threshold_ratio: If the ratio of filtered out elements is above + this threshold, an exception will be raised. Value `None` disables the + check. + execution_tracking_mode: The collection of execution statistics like total + processing time taken by each transformation, number of elements produced + etc. can be managed through various modes. If `DISABLED`, no statistics + are collected. If `STAGE_TIMING`, the time it takes to process each + transformation is collected. See `ExecutionTrackingMode` for more details.
+ """ + + filter_warn_threshold_ratio: float | None | _Default[float] = _Default(0.9) + filter_raise_threshold_ratio: float | None | _Default[None] = _Default(None) + execution_tracking_mode: ( + ExecutionTrackingMode | _Default[ExecutionTrackingMode] + ) = _Default(ExecutionTrackingMode.DISABLED) + + # Internal fields. + + # Names of fields which were set by the user. + _user_set_fields: set[str] = dataclasses.field( + default_factory=set, init=False + ) + + def __post_init__(self): + # Replace default value objects with actual values. + for field in dataclasses.fields(DatasetOptions): + value = getattr(self, field.name) + if isinstance(value, _Default): + super().__setattr__(field.name, value.value) + elif field.init: + self._user_set_fields.add(field.name) + + def merge(self, other: DatasetOptions | None) -> DatasetOptions: + """Merges these options with the other. + + Explicitly set options in `self` take precedence over options in `other`. + + Args: + other: Options to merge. + + Returns: + Merged options. + """ + if other is None: + return self + + merged = {} + for field in dataclasses.fields(DatasetOptions): + if field.name in self._user_set_fields: + merged[field.name] = getattr(self, field.name) + elif field.name in other._user_set_fields: # pylint: disable=protected-access + merged[field.name] = getattr(other, field.name) + return DatasetOptions(**merged) + + +@dataclasses.dataclass(kw_only=True, frozen=True, slots=True) +class MultiprocessingContext: + """Context of the current process as a part of multiprocessing system.""" + + process_index: int = 0 + process_count: int = 1 + + +@dataclasses.dataclass(kw_only=True, slots=True) +class IteratorContext: + """Context shared by all iterators in a dataset. + + The context is mutable and: + - Should be updated only before or during iterator initialization. + - Attributes should only be used after all iterators in the pipeline are + initialized. In practice, this means during pipeline execution with lazy + initialization mechanisms such as `functools.cached_property`. + """ + + # Dataset transformation options. + dataset_options: DatasetOptions = DatasetOptions() + # Multiprocessing context of the worker process running this iterator. + mp_context: MultiprocessingContext = MultiprocessingContext() + + def merge(self, other: IteratorContext) -> None: + """Merges this context with the other in place.""" + self.dataset_options = self.dataset_options.merge(other.dataset_options) + if self.mp_context != other.mp_context: + raise ValueError( + "Cannot merge contexts from different worker processes:" + f" {self.mp_context} vs {other.mp_context}." 
+ ) diff --git a/grain/_src/python/dataset/base_test.py b/grain/_src/python/dataset/base_test.py index b473207c..10ffc53c 100644 --- a/grain/_src/python/dataset/base_test.py +++ b/grain/_src/python/dataset/base_test.py @@ -30,5 +30,77 @@ def test_protocol(self, source_cls): self.assertIsInstance(source_cls, base.RandomAccessDataSource) +class DatasetOptionsTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict( + testcase_name="no_conflicts", + a=base.DatasetOptions(filter_warn_threshold_ratio=0.1), + b=base.DatasetOptions(filter_raise_threshold_ratio=0.2), + expected=base.DatasetOptions( + filter_warn_threshold_ratio=0.1, + filter_raise_threshold_ratio=0.2, + ), + ), + dict( + testcase_name="all_fields_default", + a=base.DatasetOptions(), + b=base.DatasetOptions( + filter_warn_threshold_ratio=0.4, + filter_raise_threshold_ratio=0.3, + ), + expected=base.DatasetOptions( + filter_warn_threshold_ratio=0.4, + filter_raise_threshold_ratio=0.3, + ), + ), + dict( + testcase_name="field_conflict", + a=base.DatasetOptions(filter_raise_threshold_ratio=0.1), + b=base.DatasetOptions(filter_raise_threshold_ratio=0.2), + expected=base.DatasetOptions( + filter_raise_threshold_ratio=0.1, + ), + ), + ) + def test_merge(self, a, b, expected): + self.assertEqual(a.merge(b), expected) + + +class IteratorContextTest(parameterized.TestCase): + + def test_merge(self): + a = base.IteratorContext( + dataset_options=base.DatasetOptions(filter_warn_threshold_ratio=0.1) + ) + b = base.IteratorContext( + dataset_options=base.DatasetOptions( + filter_warn_threshold_ratio=0.2, filter_raise_threshold_ratio=0.2 + ) + ) + a.merge(b) + self.assertEqual( + a, + base.IteratorContext( + dataset_options=base.DatasetOptions( + filter_warn_threshold_ratio=0.1, + filter_raise_threshold_ratio=0.2, + ) + ), + ) + + def test_merge_with_different_mp_context(self): + a = base.IteratorContext( + mp_context=base.MultiprocessingContext(process_index=0, process_count=1) + ) + b = base.IteratorContext( + mp_context=base.MultiprocessingContext(process_index=1, process_count=2) + ) + with self.assertRaisesRegex( + ValueError, "Cannot merge contexts from different worker processes" + ): + a.merge(b) + + if __name__ == "__main__": absltest.main() diff --git a/grain/_src/python/dataset/dataset.py b/grain/_src/python/dataset/dataset.py index 68dc6787..9ab0a13e 100644 --- a/grain/_src/python/dataset/dataset.py +++ b/grain/_src/python/dataset/dataset.py @@ -41,7 +41,6 @@ import abc import builtins from collections.abc import Callable, Iterable, Iterator, Sequence -import dataclasses import functools from typing import ( Any, @@ -88,6 +87,7 @@ class _Dataset: _MUTATES_ELEMENT_SPEC = True def __init__(self, parents: Sequence[_Dataset]): + super().__init__() # Seeds a `SeedSequence` used to generate default seeds for all # downstream transformations. Set by `_WithOptions{Map|Iter}Dataset`. self._seed_rng_seed = None @@ -676,7 +676,7 @@ def __iter__(self) -> DatasetIterator[T]: return self.to_iter_dataset().__iter__() def _initialize_stats( - self, execution_tracking_mode: dataset_stats.ExecutionTrackingMode + self, execution_tracking_mode: base.ExecutionTrackingMode ) -> dataset_stats.Stats: """Eagerly initializes the stats object with given execution tracking mode. 
@@ -711,7 +711,7 @@ def _initialize_stats( @functools.cached_property def _stats(self) -> dataset_stats.Stats: """Returns the Stats object for recording statistics about this dataset.""" - return self._initialize_stats(dataset_stats.ExecutionTrackingMode.DISABLED) + return self._initialize_stats(base.ExecutionTrackingMode.DISABLED) class _IterDatasetMeta(abc.ABCMeta): @@ -1071,31 +1071,24 @@ def __init__( self, parents: DatasetIterator | Sequence[DatasetIterator] = (), ): + super().__init__() if isinstance(parents, DatasetIterator): self._parents = (parents,) else: self._parents = tuple(parents) - # The options are set in `WithOptionsIterDataset`. - self._options: DatasetOptions | None = None - parent_options = [] - for p in self._parents: - # Not all user iterators call super().__init__ and thus don't have the - # options set. - if (p_options := getattr(p, "_options", None)) is not None: - parent_options.append(p_options) - if parent_options: - self._options = functools.reduce(lambda x, y: x.merge(y), parent_options) - - @property - def _options_with_default(self) -> DatasetOptions: - """Returns options for the pipeline including the given iterator. - - WARNING: must be accessed after all iterators in the pipeline have been - initialized. - """ - # TODO: Relax the requirement to access options after all iterators - # in the pipeline have been initialized. - return getattr(self, "_options", None) or DatasetOptions() + if self._parents: + self._ctx: base.IteratorContext = self._parents[0]._ctx + # Merge the context from all parents. + to_visit = list(self._parents[1:]) + for parent in to_visit: + self._ctx.merge(parent._ctx) + # Update the context in the parent iterator trees. + while to_visit: + current = to_visit.pop() + current._ctx = self._ctx + to_visit.extend(current._parents) + else: + self._ctx: base.IteratorContext = base.IteratorContext() @property def _parent(self) -> DatasetIterator: @@ -1145,19 +1138,16 @@ def start_prefetch(self) -> None: @functools.cached_property def _stats(self): """Returns the Stats object for recording statistics about this iterator.""" - # There may be parent `DatasetIterator` nodes introduced by users that did - # not call super init and thus don't have `_stats`. - parents_stats = [] - if hasattr(self, "_parents"): - for p in self._parents: - parents_stats.append(p._stats) # pylint: disable=protected-access + config = dataset_stats.StatsConfig( + name=str(self), + transform_mutates_spec=self._MUTATES_ELEMENT_SPEC, + ) return dataset_stats.make_stats( - dataset_stats.StatsConfig( - name=str(self), - transform_mutates_spec=self._MUTATES_ELEMENT_SPEC, + config, + [p._stats for p in self._parents], # pylint: disable=protected-access + execution_tracking_mode=( + self._ctx.dataset_options.execution_tracking_mode ), - parents_stats, - execution_tracking_mode=self._options_with_default.execution_tracking_mode, ) @@ -1196,80 +1186,6 @@ def __str__(self): return "WithOptionsIterDataset" -@dataclasses.dataclass(slots=True, frozen=True) -class _Default(Generic[T]): - """Default options value holder.""" - - value: T - - -@dataclasses.dataclass(kw_only=True, frozen=True) -class DatasetOptions: - """Holds options used by dataset transformations. - - Attributes: - filter_warn_threshold_ratio: If the ratio of filtered out elements is above - these thresholds, a warning will be issued. Value `None` disables the - check. The ratio is calculated on non-overlapping windows of 1000 - elements. 
For instance, with `filter_warn_threshold_ratio=0.9` and 901 - elements out of the first 1000 (or elements 1000...2000) filtered out, a - warning will be issued. - filter_raise_threshold_ratio: If the ratio of filtered out elements is above - these thresholds, an exception will be issued. Value `None` disables the - check. - execution_tracking_mode: The collection of execution statistics like total - processing time taken by each transformation, number of elements produced - etc. can be managed through various modes. If `DISABLED`, no statistics - are collected.If `STAGE_TIMING`, the time it takes to process each - transormation is collected. See `ExecutionTrackingMode` for more details. - """ - - filter_warn_threshold_ratio: float | None | _Default[float] = _Default(0.9) - filter_raise_threshold_ratio: float | None | _Default[None] = _Default(None) - execution_tracking_mode: ( - dataset_stats.ExecutionTrackingMode - | _Default[dataset_stats.ExecutionTrackingMode] - ) = _Default(dataset_stats.ExecutionTrackingMode.DISABLED) - - # Internal fields. - - # Names of fields which were set by the user. - _user_set_fields: set[str] = dataclasses.field( - default_factory=set, init=False - ) - - def __post_init__(self): - # Replace default value objects with actual values. - for field in dataclasses.fields(DatasetOptions): - value = getattr(self, field.name) - if isinstance(value, _Default): - super().__setattr__(field.name, value.value) - elif field.init: - self._user_set_fields.add(field.name) - - def merge(self, other: DatasetOptions | None) -> DatasetOptions: - """Merges these options with the other. - - Explicitly set options in `self` take precedence over options in `other`. - - Args: - other: Options to merge. - - Returns: - Merged options. - """ - if other is None: - return self - - merged = {} - for field in dataclasses.fields(DatasetOptions): - if field.name in self._user_set_fields: - merged[field.name] = getattr(self, field.name) - elif field.name in other._user_set_fields: # pylint: disable=protected-access - merged[field.name] = getattr(other, field.name) - return DatasetOptions(**merged) - - class WithOptionsIterDataset(IterDataset[T]): """Applies options to transformations in the pipeline. @@ -1302,7 +1218,7 @@ class WithOptionsIterDataset(IterDataset[T]): ``` """ - def __init__(self, parent: IterDataset[T], options: DatasetOptions): + def __init__(self, parent: IterDataset[T], options: base.DatasetOptions): super().__init__(parent) self._options = options @@ -1310,12 +1226,8 @@ def __iter__(self) -> DatasetIterator[T]: result = self._parent.__iter__() # The parent iterator options are merged from the entire subtree. Merge # them with the latest options and update the subtree options. 
- options = self._options.merge(result._options) - to_visit = [result] - while to_visit: - current = to_visit.pop() - current._options = options - to_visit.extend(current._parents) + options = self._options.merge(result._ctx.dataset_options) + result._ctx.dataset_options = options return result def __str__(self): diff --git a/grain/_src/python/dataset/dataset_test.py b/grain/_src/python/dataset/dataset_test.py index 888681c8..1c846764 100644 --- a/grain/_src/python/dataset/dataset_test.py +++ b/grain/_src/python/dataset/dataset_test.py @@ -814,61 +814,24 @@ def test_unsupported_transform(self): _ = dataset.apply_transformations(ds, TfRandomMapAlwaysAddingOne()) -class DatasetOptionsTest(parameterized.TestCase): - - @parameterized.named_parameters( - dict( - testcase_name="no_conflicts", - a=dataset.DatasetOptions(filter_warn_threshold_ratio=0.1), - b=dataset.DatasetOptions(filter_raise_threshold_ratio=0.2), - expected=dataset.DatasetOptions( - filter_warn_threshold_ratio=0.1, - filter_raise_threshold_ratio=0.2, - ), - ), - dict( - testcase_name="all_fields_default", - a=dataset.DatasetOptions(), - b=dataset.DatasetOptions( - filter_warn_threshold_ratio=0.4, - filter_raise_threshold_ratio=0.3, - ), - expected=dataset.DatasetOptions( - filter_warn_threshold_ratio=0.4, - filter_raise_threshold_ratio=0.3, - ), - ), - dict( - testcase_name="field_conflict", - a=dataset.DatasetOptions(filter_raise_threshold_ratio=0.1), - b=dataset.DatasetOptions(filter_raise_threshold_ratio=0.2), - expected=dataset.DatasetOptions( - filter_raise_threshold_ratio=0.1, - ), - ), - ) - def test_merge(self, a, b, expected): - self.assertEqual(a.merge(b), expected) - - class WithOptionsIterDatasetTest(parameterized.TestCase): def _assert_subtree_options_equal( - self, ds: dataset.IterDataset, expected: dataset.DatasetOptions + self, ds: dataset.IterDataset, expected: base.DatasetOptions ): to_check = [ds.__iter__()] while to_check: next_it = to_check.pop() self.assertEqual( - next_it._options, + next_it._ctx.dataset_options, expected, - f"Options are not equal for {next_it}; actual: {next_it._options}," - f" expected: {expected}.", + f"Options are not equal for {next_it}; actual:" + f" {next_it._ctx.dataset_options}, expected: {expected}.", ) to_check.extend(next_it._parents) def test_propagates_options_in_linear_pipeline(self): - actual_options = dataset.DatasetOptions( + actual_options = base.DatasetOptions( filter_warn_threshold_ratio=0.1, filter_raise_threshold_ratio=0.2, ) @@ -884,7 +847,7 @@ def test_propagates_options_in_linear_pipeline(self): self._assert_subtree_options_equal(ds, actual_options) def test_propagates_options_in_tree_pipeline(self): - actual_options = dataset.DatasetOptions( + actual_options = base.DatasetOptions( filter_warn_threshold_ratio=0.1, filter_raise_threshold_ratio=0.2, ) @@ -911,17 +874,17 @@ def test_conflicting_options(self): .batch(batch_size=2) .filter(lambda x: True) ) - options1 = dataset.DatasetOptions( + options1 = base.DatasetOptions( filter_warn_threshold_ratio=0.1, filter_raise_threshold_ratio=0.2, ) ds = dataset.WithOptionsIterDataset(ds, options1) ds = ds.map(lambda x: x).filter(lambda x: True) - options2 = dataset.DatasetOptions(filter_raise_threshold_ratio=0.4) + options2 = base.DatasetOptions(filter_raise_threshold_ratio=0.4) ds = dataset.WithOptionsIterDataset(ds, options2) self._assert_subtree_options_equal( ds, - dataset.DatasetOptions( + base.DatasetOptions( filter_warn_threshold_ratio=0.1, filter_raise_threshold_ratio=0.4, ), diff --git 
a/grain/_src/python/dataset/stats.py b/grain/_src/python/dataset/stats.py index faafe618..c7fec32a 100644 --- a/grain/_src/python/dataset/stats.py +++ b/grain/_src/python/dataset/stats.py @@ -19,7 +19,6 @@ from collections.abc import Sequence import contextlib import dataclasses -import enum import pprint import sys import threading @@ -31,6 +30,7 @@ from grain._src.core import config as grain_config from grain._src.core import monitoring as grain_monitoring from grain._src.core import tree +from grain._src.python.dataset import base from grain._src.core import monitoring @@ -170,25 +170,6 @@ def _pretty_format_summary( return table.get_pretty_wrapped_summary() # pylint: disable=protected-access -@enum.unique -class ExecutionTrackingMode(enum.Flag): - """Represents different modes for tracking execution statistics. - - Available modes: - DISABLED: - No execution statistics are measured. This mode is the default. - STAGE_TIMING: - The time taken for each transformation stage to execute is measured and - recorded. This recorded time reflects the duration spent within the - specific transformation to return an element, excluding the time spent in - any parent transformations. The recorded time can be retrieved using - `grain.experimental.get_execution_summary` method. - """ - - DISABLED = enum.auto() - STAGE_TIMING = enum.auto() - - class _Table: """Table class for pretty printing tabular data.""" @@ -616,7 +597,9 @@ def report(self): def make_stats( config: StatsConfig, parents: Sequence[Stats], - execution_tracking_mode: ExecutionTrackingMode = ExecutionTrackingMode.DISABLED, + execution_tracking_mode: base.ExecutionTrackingMode = ( + base.ExecutionTrackingMode.DISABLED + ), ) -> Stats: """Produces statistics instance according to the current execution mode.""" return _NoopStats(config, parents=parents) diff --git a/grain/_src/python/dataset/transformations/BUILD b/grain/_src/python/dataset/transformations/BUILD index b77da553..d0e050f1 100644 --- a/grain/_src/python/dataset/transformations/BUILD +++ b/grain/_src/python/dataset/transformations/BUILD @@ -33,6 +33,7 @@ py_test( deps = [ "//grain/_src/core:transforms", "//grain/_src/python/dataset", + "//grain/_src/python/dataset:base", ], ) @@ -45,6 +46,7 @@ py_test( "//grain/_src/core:transforms", "//grain/_src/python:options", "//grain/_src/python/dataset", + "//grain/_src/python/dataset:base", ], ) diff --git a/grain/_src/python/dataset/transformations/filter.py b/grain/_src/python/dataset/transformations/filter.py index 0cec230e..e961ab05 100644 --- a/grain/_src/python/dataset/transformations/filter.py +++ b/grain/_src/python/dataset/transformations/filter.py @@ -143,8 +143,8 @@ def __init__( def _threshold_checker(self): return FilterThresholdChecker( transform_name=str(self), - warn_threshold=self._options_with_default.filter_warn_threshold_ratio, - raise_threshold=self._options_with_default.filter_raise_threshold_ratio, + warn_threshold=self._ctx.dataset_options.filter_warn_threshold_ratio, + raise_threshold=self._ctx.dataset_options.filter_raise_threshold_ratio, ) def __next__(self): diff --git a/grain/_src/python/dataset/transformations/filter_test.py b/grain/_src/python/dataset/transformations/filter_test.py index 7da37bb1..a704c315 100644 --- a/grain/_src/python/dataset/transformations/filter_test.py +++ b/grain/_src/python/dataset/transformations/filter_test.py @@ -18,6 +18,7 @@ from absl.testing import absltest from grain._src.core import transforms +from grain._src.python.dataset import base from grain._src.python.dataset import 
dataset from grain._src.python.dataset.transformations import filter as filter_dataset @@ -151,7 +152,7 @@ def test_filter_all_elements_raises(self): .filter(FilterAllElements()) ) ds = dataset.WithOptionsIterDataset( - ds, dataset.DatasetOptions(filter_raise_threshold_ratio=0.999) + ds, base.DatasetOptions(filter_raise_threshold_ratio=0.999) ) with self.assertRaisesRegex( ValueError, @@ -168,7 +169,7 @@ def setUp(self): filter_dataset._WARN_FILTERED_INTERVAL_SEC = 0.0 def test_validates(self): - default_options = dataset.DatasetOptions() + default_options = base.DatasetOptions() v = filter_dataset.FilterThresholdChecker( "test", default_options.filter_warn_threshold_ratio, @@ -179,7 +180,7 @@ def test_validates(self): v.check(p) def test_warns(self): - default_options = dataset.DatasetOptions() + default_options = base.DatasetOptions() v = filter_dataset.FilterThresholdChecker( "test", default_options.filter_warn_threshold_ratio, diff --git a/grain/_src/python/dataset/transformations/map.py b/grain/_src/python/dataset/transformations/map.py index bd1ba722..69c5aba4 100644 --- a/grain/_src/python/dataset/transformations/map.py +++ b/grain/_src/python/dataset/transformations/map.py @@ -20,7 +20,6 @@ from absl import logging from grain._src.core import transforms from grain._src.python.dataset import dataset -from grain._src.python.dataset.transformations import prefetch import numpy as np @@ -238,8 +237,8 @@ def __next__(self): # execution. The actual index value doesn't matter as long as it is # unique for each process. index_for_rng = ( - prefetch.worker_process_index - + self._index_for_rng * prefetch.worker_process_count + self._ctx.mp_context.process_index + + self._index_for_rng * self._ctx.mp_context.process_count ) _reset_rng_state(self._rng, op_seed=0, index=index_for_rng) element = self._map_fn(element, self._rng) diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py index 2bd5fc2e..1b3b8d09 100644 --- a/grain/_src/python/dataset/transformations/prefetch.py +++ b/grain/_src/python/dataset/transformations/prefetch.py @@ -34,6 +34,7 @@ from grain._src.python import grain_pool from grain._src.python import options as grain_options from grain._src.python import shared_memory_array +from grain._src.python.dataset import base from grain._src.python.dataset import dataset from grain._src.python.dataset import stats as dataset_stats from grain._src.python.dataset.transformations import filter as filter_dataset @@ -41,14 +42,6 @@ T = TypeVar("T") -# Index of the current worker process and total number of processes. If used -# before multiprocess prefetch, must only be used during or after iterator -# initialization. -# TODO: Introduce context shared by all iterators and put these -# variables there. -worker_process_index = 0 -worker_process_count = 1 - @typing.runtime_checkable class SupportsInPlaceSlicing(Protocol): @@ -116,7 +109,7 @@ def __init__( @functools.cached_property def _stats(self): - execution_tracking_mode = self._options_with_default.execution_tracking_mode + execution_tracking_mode = self._ctx.dataset_options.execution_tracking_mode parent_stats = self._map_parent._initialize_stats( # pylint: disable=protected-access execution_tracking_mode ) @@ -136,8 +129,8 @@ def _threshold_checker(self): # here. The validator helps to detect if we discard too many elements. 
return filter_dataset.FilterThresholdChecker( transform_name=str(self), - warn_threshold=self._options_with_default.filter_warn_threshold_ratio, - raise_threshold=self._options_with_default.filter_raise_threshold_ratio, + warn_threshold=self._ctx.dataset_options.filter_warn_threshold_ratio, + raise_threshold=self._ctx.dataset_options.filter_raise_threshold_ratio, ) def __next__(self) -> T: @@ -382,12 +375,12 @@ def __init__( def __call__( self, *, worker_index: int, worker_count: int ) -> Iterator[tuple[T, Optional[dict[str, Any]]]]: - global worker_process_index, worker_process_count - worker_process_index = worker_index - worker_process_count = worker_count if worker_count > 1: _set_slice(self._ds, slice(worker_index, None, worker_count)) it = self._ds.__iter__() + it._ctx.mp_context = base.MultiprocessingContext( + process_index=worker_index, process_count=worker_count + ) # Recover from the last recorded state for the given worker. worker_state = self._state[_WORKERS_STATE][str(worker_index)] if worker_state is not None: diff --git a/grain/_src/python/dataset/transformations/prefetch_test.py b/grain/_src/python/dataset/transformations/prefetch_test.py index 9bcd1eff..d5f1a555 100644 --- a/grain/_src/python/dataset/transformations/prefetch_test.py +++ b/grain/_src/python/dataset/transformations/prefetch_test.py @@ -24,6 +24,7 @@ from grain._src.core import transforms import multiprocessing as mp from grain._src.python import options +from grain._src.python.dataset import base from grain._src.python.dataset import dataset from grain._src.python.dataset.transformations import filter as filter_lazy_dataset from grain._src.python.dataset.transformations import map as map_lazy_dataset @@ -194,7 +195,7 @@ def test_filter_all_elements_raises(self): .filter(FilterAllElements()) .to_iter_dataset() ) - ds_options = dataset.DatasetOptions(filter_raise_threshold_ratio=0.9) + ds_options = base.DatasetOptions(filter_raise_threshold_ratio=0.9) ds = dataset.WithOptionsIterDataset(ds, ds_options) with self.assertRaisesRegex( ValueError, @@ -211,7 +212,7 @@ def test_filter_all_elements_doesnt_raise_with_allow_nones(self): .filter(FilterAllElements()) .to_iter_dataset(allow_nones=True) ) - ds_options = dataset.DatasetOptions(filter_raise_threshold_ratio=0.9) + ds_options = base.DatasetOptions(filter_raise_threshold_ratio=0.9) ds = dataset.WithOptionsIterDataset(ds, ds_options) self.assertEqual(list(ds), [None] * 1000) diff --git a/grain/_src/python/dataset/visualize.py b/grain/_src/python/dataset/visualize.py index 79b8102e..10633031 100644 --- a/grain/_src/python/dataset/visualize.py +++ b/grain/_src/python/dataset/visualize.py @@ -174,20 +174,20 @@ def __init__( mock_output: Value to use as the output of the parent iterator. If None, the actual iterator output will be used. 
""" - self._parent_iter = parent_iter + super().__init__(parent_iter) self._spec_update_fn = spec_update_fn self._mock_output = mock_output def __next__(self) -> T: - result = self._mock_output or self._parent_iter.__next__() + result = self._mock_output or self._parent.__next__() self._spec_update_fn(tree.spec_like(result)) return result def set_state(self, state: dict[str, Any]) -> None: - self._parent_iter.set_state(state) + self._parent.set_state(state) def get_state(self) -> dict[str, Any]: - return self._parent_iter.get_state() + return self._parent.get_state() def _build_visualization_from_tracked_spec( diff --git a/grain/python_experimental.py b/grain/python_experimental.py index 10ec75f7..96602c3c 100644 --- a/grain/python_experimental.py +++ b/grain/python_experimental.py @@ -30,12 +30,14 @@ del epy +from ._src.python.dataset.base import ( + DatasetOptions, + ExecutionTrackingMode, +) from ._src.python.dataset.dataset import ( apply_transformations, - DatasetOptions, WithOptionsIterDataset, ) -from ._src.python.dataset.stats import ExecutionTrackingMode from ._src.python.dataset.transformations.flatmap import ( FlatMapMapDataset, FlatMapIterDataset,