
Commit

Introduce IteratorContext shared by all DatasetIterators in a pipeline.

PiperOrigin-RevId: 712921067
iindyk authored and copybara-github committed Jan 7, 2025
1 parent 47099eb commit aef53f4
Showing 14 changed files with 275 additions and 216 deletions.
1 change: 1 addition & 0 deletions grain/_src/python/dataset/BUILD
@@ -83,6 +83,7 @@ py_library(
srcs = ["stats.py"],
srcs_version = "PY3",
deps = [
":base",
"//grain/_src/core:config",
"//grain/_src/core:monitoring",
"//grain/_src/core:tree",
136 changes: 133 additions & 3 deletions grain/_src/python/dataset/base.py
@@ -11,14 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base interfaces for working with LazyDataset.
"""Primitives for working with Dataset APIs.
Classes in this module are shared by LazyDataset classes and transformations.
Classes in this module are shared by Dataset implementations as well as the
public Dataset API.
"""
from __future__ import annotations

import abc
import dataclasses
import enum
import typing
from typing import Protocol, TypeVar
from typing import Generic, Protocol, TypeVar


T = TypeVar("T")
@@ -50,3 +54,129 @@ def __len__(self) -> int:
@abc.abstractmethod
def __getitem__(self, index: int) -> tuple[int, int]:
"""Returns constituent dataset index and index within this dataset."""


@enum.unique
class ExecutionTrackingMode(enum.Flag):
"""Represents different modes for tracking execution statistics.
Available modes:
DISABLED:
No execution statistics are measured. This mode is the default.
STAGE_TIMING:
The time taken for each transformation stage to execute is measured and
recorded. This recorded time reflects the duration spent within the
specific transformation to return an element, excluding the time spent in
any parent transformations. The recorded time can be retrieved using
`grain.experimental.get_execution_summary` method.
"""

DISABLED = enum.auto()
STAGE_TIMING = enum.auto()


@dataclasses.dataclass(slots=True, frozen=True)
class _Default(Generic[T]):
"""Default options value holder."""

value: T


@dataclasses.dataclass(kw_only=True, frozen=True)
class DatasetOptions:
"""Holds options used by dataset transformations.
Attributes:
filter_warn_threshold_ratio: If the ratio of filtered out elements is above
these thresholds, a warning will be issued. Value `None` disables the
check. The ratio is calculated on non-overlapping windows of 1000
elements. For instance, with `filter_warn_threshold_ratio=0.9` and 901
elements out of the first 1000 (or elements 1000...2000) filtered out, a
warning will be issued.
filter_raise_threshold_ratio: If the ratio of filtered out elements is above
these thresholds, an exception will be issued. Value `None` disables the
check.
execution_tracking_mode: The collection of execution statistics like total
processing time taken by each transformation, number of elements produced
etc. can be managed through various modes. If `DISABLED`, no statistics
are collected.If `STAGE_TIMING`, the time it takes to process each
transormation is collected. See `ExecutionTrackingMode` for more details.
"""

filter_warn_threshold_ratio: float | None | _Default[float] = _Default(0.9)
filter_raise_threshold_ratio: float | None | _Default[None] = _Default(None)
execution_tracking_mode: (
ExecutionTrackingMode | _Default[ExecutionTrackingMode]
) = _Default(ExecutionTrackingMode.DISABLED)

# Internal fields.

# Names of fields which were set by the user.
_user_set_fields: set[str] = dataclasses.field(
default_factory=set, init=False
)

def __post_init__(self):
# Replace default value objects with actual values.
for field in dataclasses.fields(DatasetOptions):
value = getattr(self, field.name)
if isinstance(value, _Default):
super().__setattr__(field.name, value.value)
elif field.init:
self._user_set_fields.add(field.name)

def merge(self, other: DatasetOptions | None) -> DatasetOptions:
"""Merges these options with the other.
Explicitly set options in `self` take precedence over options in `other`.
Args:
other: Options to merge.
Returns:
Merged options.
"""
if other is None:
return self

merged = {}
for field in dataclasses.fields(DatasetOptions):
if field.name in self._user_set_fields:
merged[field.name] = getattr(self, field.name)
elif field.name in other._user_set_fields: # pylint: disable=protected-access
merged[field.name] = getattr(other, field.name)
return DatasetOptions(**merged)


@dataclasses.dataclass(kw_only=True, frozen=True, slots=True)
class MultiprocessingContext:
"""Context of the current process as a part of multiprocessing system."""

process_index: int = 0
process_count: int = 1


@dataclasses.dataclass(kw_only=True, slots=True)
class IteratorContext:
"""Context shared by all iterators in a dataset.
The context is mutable and:
- Should be updated only before or during iterator initialization.
- Attributes should only be used after all iterators in the pipeline are
initialized. In practice, this means during pipeline execution with lazy
initialization mechanisms such as `functools.cached_property`.
"""

# Dataset transformation options.
dataset_options: DatasetOptions = DatasetOptions()
# Multiprocessing context of the worker process running this iterator.
mp_context: MultiprocessingContext = MultiprocessingContext()

def merge(self, other: IteratorContext) -> None:
"""Merges this context with the other in place."""
self.dataset_options = self.dataset_options.merge(other.dataset_options)
if self.mp_context != other.mp_context:
raise ValueError(
"Cannot merge contexts from different worker processes:"
f" {self.mp_context} vs {other.mp_context}."
)
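
To illustrate the merge semantics introduced above — fields explicitly set by the user take precedence, with `self` winning over `other`, and untouched fields keep the library defaults — here is a minimal sketch. It assumes the module is importable as `grain._src.python.dataset.base`; the option values are made up for illustration.

from grain._src.python.dataset import base

# Options set explicitly on an outer stage of a hypothetical pipeline.
outer = base.DatasetOptions(filter_warn_threshold_ratio=0.5)
# Options set on an inner stage; it also sets the raise threshold.
inner = base.DatasetOptions(
    filter_warn_threshold_ratio=0.8,
    filter_raise_threshold_ratio=0.99,
)

# Explicitly set fields in `outer` win; fields left at their defaults fall
# back to whatever `inner` set explicitly.
merged = outer.merge(inner)
assert merged.filter_warn_threshold_ratio == 0.5
assert merged.filter_raise_threshold_ratio == 0.99

# Fields that neither side set keep the library default (tracking disabled).
assert merged.execution_tracking_mode == base.ExecutionTrackingMode.DISABLED

The `_Default` wrapper is what makes this work: `__post_init__` records which fields were passed by the caller, so `merge` can distinguish an explicitly set value (even an explicit `None`) from an untouched default.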
72 changes: 72 additions & 0 deletions grain/_src/python/dataset/base_test.py
@@ -30,5 +30,77 @@ def test_protocol(self, source_cls):
self.assertIsInstance(source_cls, base.RandomAccessDataSource)


class DatasetOptionsTest(parameterized.TestCase):

@parameterized.named_parameters(
dict(
testcase_name="no_conflicts",
a=base.DatasetOptions(filter_warn_threshold_ratio=0.1),
b=base.DatasetOptions(filter_raise_threshold_ratio=0.2),
expected=base.DatasetOptions(
filter_warn_threshold_ratio=0.1,
filter_raise_threshold_ratio=0.2,
),
),
dict(
testcase_name="all_fields_default",
a=base.DatasetOptions(),
b=base.DatasetOptions(
filter_warn_threshold_ratio=0.4,
filter_raise_threshold_ratio=0.3,
),
expected=base.DatasetOptions(
filter_warn_threshold_ratio=0.4,
filter_raise_threshold_ratio=0.3,
),
),
dict(
testcase_name="field_conflict",
a=base.DatasetOptions(filter_raise_threshold_ratio=0.1),
b=base.DatasetOptions(filter_raise_threshold_ratio=0.2),
expected=base.DatasetOptions(
filter_raise_threshold_ratio=0.1,
),
),
)
def test_merge(self, a, b, expected):
self.assertEqual(a.merge(b), expected)


class IteratorContextTest(parameterized.TestCase):

def test_merge(self):
a = base.IteratorContext(
dataset_options=base.DatasetOptions(filter_warn_threshold_ratio=0.1)
)
b = base.IteratorContext(
dataset_options=base.DatasetOptions(
filter_warn_threshold_ratio=0.2, filter_raise_threshold_ratio=0.2
)
)
a.merge(b)
self.assertEqual(
a,
base.IteratorContext(
dataset_options=base.DatasetOptions(
filter_warn_threshold_ratio=0.1,
filter_raise_threshold_ratio=0.2,
)
),
)

def test_merge_with_different_mp_context(self):
a = base.IteratorContext(
mp_context=base.MultiprocessingContext(process_index=0, process_count=1)
)
b = base.IteratorContext(
mp_context=base.MultiprocessingContext(process_index=1, process_count=2)
)
with self.assertRaisesRegex(
ValueError, "Cannot merge contexts from different worker processes"
):
a.merge(b)


if __name__ == "__main__":
absltest.main()
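
The tests above exercise merging directly. In a pipeline, each iterator would typically fold its stage's options into one shared `IteratorContext` during initialization; the helper below is a hypothetical sketch of that wiring (`build_shared_context` is invented for illustration and is not part of this commit).

from grain._src.python.dataset import base

def build_shared_context(per_stage_options):
  """Folds per-stage options into a single IteratorContext, outermost first."""
  ctx = base.IteratorContext()
  for options in per_stage_options:
    ctx.merge(base.IteratorContext(dataset_options=options))
  return ctx

# Hypothetical two-stage pipeline: the outer stage lowers the filter warning
# threshold, the inner stage turns on per-stage timing.
ctx = build_shared_context([
    base.DatasetOptions(filter_warn_threshold_ratio=0.5),
    base.DatasetOptions(
        execution_tracking_mode=base.ExecutionTrackingMode.STAGE_TIMING
    ),
])
assert ctx.dataset_options.filter_warn_threshold_ratio == 0.5
assert (ctx.dataset_options.execution_tracking_mode
        == base.ExecutionTrackingMode.STAGE_TIMING)

Because `merge` updates the context in place, every iterator holding a reference to it observes the combined options once initialization finishes, which matches the contract stated in the `IteratorContext` docstring.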