Introduce IteratorContext shared by all DatasetIterators in a pipeline. #678

Merged 1 commit on Jan 7, 2025
1 change: 1 addition & 0 deletions grain/_src/python/dataset/BUILD
@@ -83,6 +83,7 @@ py_library(
srcs = ["stats.py"],
srcs_version = "PY3",
deps = [
":base",
"//grain/_src/core:config",
"//grain/_src/core:monitoring",
"//grain/_src/core:tree",
136 changes: 133 additions & 3 deletions grain/_src/python/dataset/base.py
@@ -11,14 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base interfaces for working with LazyDataset.
"""Primitives for working with Dataset APIs.

Classes in this module are shared by LazyDataset classes and transformations.
Classes in this module are shared by Dataset implementations as well as the
public Dataset API.
"""
from __future__ import annotations

import abc
import dataclasses
import enum
import typing
from typing import Protocol, TypeVar
from typing import Generic, Protocol, TypeVar


T = TypeVar("T")
@@ -50,3 +54,129 @@ def __len__(self) -> int:
@abc.abstractmethod
def __getitem__(self, index: int) -> tuple[int, int]:
"""Returns constituent dataset index and index within this dataset."""


@enum.unique
class ExecutionTrackingMode(enum.Flag):
"""Represents different modes for tracking execution statistics.

Available modes:
DISABLED:
No execution statistics are measured. This mode is the default.
STAGE_TIMING:
The time taken for each transformation stage to execute is measured and
recorded. This recorded time reflects the duration spent within the
specific transformation to return an element, excluding the time spent in
any parent transformations. The recorded time can be retrieved using the
`grain.experimental.get_execution_summary` method.
"""

DISABLED = enum.auto()
STAGE_TIMING = enum.auto()
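
For illustration, here is a minimal sketch of selecting a tracking mode through the `DatasetOptions` container introduced further below; retrieving the recorded timings (via `grain.experimental.get_execution_summary`, as the docstring above notes) is outside the scope of this diff.

from grain._src.python.dataset import base

# Hedged sketch: request per-stage timing; DISABLED remains the default mode.
options = base.DatasetOptions(
    execution_tracking_mode=base.ExecutionTrackingMode.STAGE_TIMING
)
assert options.execution_tracking_mode is base.ExecutionTrackingMode.STAGE_TIMING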


@dataclasses.dataclass(slots=True, frozen=True)
class _Default(Generic[T]):
"""Default options value holder."""

value: T


@dataclasses.dataclass(kw_only=True, frozen=True)
class DatasetOptions:
"""Holds options used by dataset transformations.

Attributes:
filter_warn_threshold_ratio: If the ratio of filtered-out elements is above
this threshold, a warning will be issued. Value `None` disables the check.
The ratio is calculated on non-overlapping windows of 1000 elements. For
instance, with `filter_warn_threshold_ratio=0.9` and 901 elements out of the
first 1000 (or elements 1000...2000) filtered out, a warning will be issued.
filter_raise_threshold_ratio: If the ratio of filtered-out elements is above
this threshold, an exception will be raised. Value `None` disables the check.
execution_tracking_mode: Controls the collection of execution statistics such
as the total processing time taken by each transformation and the number of
elements produced. If `DISABLED`, no statistics are collected. If
`STAGE_TIMING`, the time it takes to process each transformation is collected.
See `ExecutionTrackingMode` for more details.
"""

filter_warn_threshold_ratio: float | None | _Default[float] = _Default(0.9)
filter_raise_threshold_ratio: float | None | _Default[None] = _Default(None)
execution_tracking_mode: (
ExecutionTrackingMode | _Default[ExecutionTrackingMode]
) = _Default(ExecutionTrackingMode.DISABLED)

# Internal fields.

# Names of fields which were set by the user.
_user_set_fields: set[str] = dataclasses.field(
default_factory=set, init=False
)

def __post_init__(self):
# Replace default value objects with actual values.
for field in dataclasses.fields(DatasetOptions):
value = getattr(self, field.name)
if isinstance(value, _Default):
super().__setattr__(field.name, value.value)
elif field.init:
self._user_set_fields.add(field.name)

def merge(self, other: DatasetOptions | None) -> DatasetOptions:
"""Merges these options with the other.

Explicitly set options in `self` take precedence over options in `other`.

Args:
other: Options to merge.

Returns:
Merged options.
"""
if other is None:
return self

merged = {}
for field in dataclasses.fields(DatasetOptions):
if field.name in self._user_set_fields:
merged[field.name] = getattr(self, field.name)
elif field.name in other._user_set_fields: # pylint: disable=protected-access
merged[field.name] = getattr(other, field.name)
return DatasetOptions(**merged)
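
To make the precedence concrete, here is a small sketch of the merge semantics implemented above (mirroring the tests further down): fields set explicitly on `self` win, fields set only on `other` fill the gaps, and fields set on neither side keep their defaults.

from grain._src.python.dataset import base

a = base.DatasetOptions(filter_warn_threshold_ratio=0.1)
b = base.DatasetOptions(
    filter_warn_threshold_ratio=0.5, filter_raise_threshold_ratio=0.2
)
merged = a.merge(b)
assert merged.filter_warn_threshold_ratio == 0.1   # `a` set this explicitly.
assert merged.filter_raise_threshold_ratio == 0.2  # only `b` set this.
# Neither side set the tracking mode, so the default survives the merge.
assert merged.execution_tracking_mode == base.ExecutionTrackingMode.DISABLED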


@dataclasses.dataclass(kw_only=True, frozen=True, slots=True)
class MultiprocessingContext:
"""Context of the current process as a part of multiprocessing system."""

process_index: int = 0
process_count: int = 1


@dataclasses.dataclass(kw_only=True, slots=True)
class IteratorContext:
"""Context shared by all iterators in a dataset.

The context is mutable:
- It should be updated only before or during iterator initialization.
- Its attributes should only be read after all iterators in the pipeline are
initialized. In practice, this means reading them during pipeline execution,
relying on lazy initialization mechanisms such as `functools.cached_property`.
"""

# Dataset transformation options.
dataset_options: DatasetOptions = DatasetOptions()
# Multiprocessing context of the worker process running this iterator.
mp_context: MultiprocessingContext = MultiprocessingContext()

def merge(self, other: IteratorContext) -> None:
"""Merges this context with the other in place."""
self.dataset_options = self.dataset_options.merge(other.dataset_options)
if self.mp_context != other.mp_context:
raise ValueError(
"Cannot merge contexts from different worker processes:"
f" {self.mp_context} vs {other.mp_context}."
)
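
A hypothetical usage sketch (the `DatasetIterator` wiring that actually shares this object is not part of this file): each iterator contributes its options into one shared context before iteration starts, and the merged options are only read afterwards.

from grain._src.python.dataset import base

ctx = base.IteratorContext()
# Two iterators in the same pipeline contribute their options during setup.
ctx.merge(base.IteratorContext(
    dataset_options=base.DatasetOptions(filter_warn_threshold_ratio=0.9)
))
ctx.merge(base.IteratorContext(
    dataset_options=base.DatasetOptions(
        execution_tracking_mode=base.ExecutionTrackingMode.STAGE_TIMING
    )
))
assert ctx.dataset_options.filter_warn_threshold_ratio == 0.9
assert (ctx.dataset_options.execution_tracking_mode
        == base.ExecutionTrackingMode.STAGE_TIMING)
# Contexts from different worker processes (mismatched mp_context) raise ValueError.
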
72 changes: 72 additions & 0 deletions grain/_src/python/dataset/base_test.py
@@ -30,5 +30,77 @@ def test_protocol(self, source_cls):
self.assertIsInstance(source_cls, base.RandomAccessDataSource)


class DatasetOptionsTest(parameterized.TestCase):

@parameterized.named_parameters(
dict(
testcase_name="no_conflicts",
a=base.DatasetOptions(filter_warn_threshold_ratio=0.1),
b=base.DatasetOptions(filter_raise_threshold_ratio=0.2),
expected=base.DatasetOptions(
filter_warn_threshold_ratio=0.1,
filter_raise_threshold_ratio=0.2,
),
),
dict(
testcase_name="all_fields_default",
a=base.DatasetOptions(),
b=base.DatasetOptions(
filter_warn_threshold_ratio=0.4,
filter_raise_threshold_ratio=0.3,
),
expected=base.DatasetOptions(
filter_warn_threshold_ratio=0.4,
filter_raise_threshold_ratio=0.3,
),
),
dict(
testcase_name="field_conflict",
a=base.DatasetOptions(filter_raise_threshold_ratio=0.1),
b=base.DatasetOptions(filter_raise_threshold_ratio=0.2),
expected=base.DatasetOptions(
filter_raise_threshold_ratio=0.1,
),
),
)
def test_merge(self, a, b, expected):
self.assertEqual(a.merge(b), expected)


class IteratorContextTest(parameterized.TestCase):

def test_merge(self):
a = base.IteratorContext(
dataset_options=base.DatasetOptions(filter_warn_threshold_ratio=0.1)
)
b = base.IteratorContext(
dataset_options=base.DatasetOptions(
filter_warn_threshold_ratio=0.2, filter_raise_threshold_ratio=0.2
)
)
a.merge(b)
self.assertEqual(
a,
base.IteratorContext(
dataset_options=base.DatasetOptions(
filter_warn_threshold_ratio=0.1,
filter_raise_threshold_ratio=0.2,
)
),
)

def test_merge_with_different_mp_context(self):
a = base.IteratorContext(
mp_context=base.MultiprocessingContext(process_index=0, process_count=1)
)
b = base.IteratorContext(
mp_context=base.MultiprocessingContext(process_index=1, process_count=2)
)
with self.assertRaisesRegex(
ValueError, "Cannot merge contexts from different worker processes"
):
a.merge(b)


if __name__ == "__main__":
absltest.main()