Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Internal change #659

Merged
merged 1 commit into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 54 additions & 16 deletions grain/_src/python/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,21 +675,43 @@ def to_iter_dataset(
def __iter__(self) -> DatasetIterator[T]:
  """Returns an iterator by first converting this dataset to an `IterDataset`."""
  return iter(self.to_iter_dataset())

def _initialize_stats(
    self, execution_tracking_mode: dataset_stats.ExecutionTrackingMode
) -> dataset_stats.Stats:
  """Eagerly initializes the stats object with the given execution tracking mode.

  Sets the `_stats` attribute with the specified execution tracking mode,
  bypassing the `_stats` cached property (assigning the attribute populates
  the `functools.cached_property` cache). This is beneficial when we want to
  initialize the stats object eagerly from `PrefetchDatasetIterator`, using
  the appropriate execution tracking mode from the grain options.

  Args:
    execution_tracking_mode: The execution tracking mode to use for the stats
      object.

  Returns:
    The initialized stats object.
  """
  # There may be parent `MapDataset` nodes introduced by users that did not
  # call super init and thus don't have `_parents`.
  parents_stats = []
  if hasattr(self, "_parents"):
    for p in self._parents:
      # Recurse so the entire parent tree shares one tracking mode.
      parents_stats.append(p._initialize_stats(execution_tracking_mode))  # pylint: disable=protected-access
  self._stats = dataset_stats.make_stats(
      dataset_stats.StatsConfig(
          name=str(self), transform_mutates_spec=self._MUTATES_ELEMENT_SPEC
      ),
      parents_stats,
      execution_tracking_mode=execution_tracking_mode,
  )
  return self._stats

@functools.cached_property
def _stats(self) -> dataset_stats.Stats:
  """Lazily builds the stats object with execution tracking disabled by default."""
  disabled = dataset_stats.ExecutionTrackingMode.DISABLED
  return self._initialize_stats(disabled)


class _IterDatasetMeta(abc.ABCMeta):
Expand Down Expand Up @@ -1073,7 +1095,7 @@ def _options_with_default(self) -> DatasetOptions:
"""
# TODO: Relax the requirement to access options after all iterators
# in the pipeline have been initialized.
return self._options or DatasetOptions()
return getattr(self, "_options", None) or DatasetOptions()

@property
def _parent(self) -> DatasetIterator:
Expand Down Expand Up @@ -1131,9 +1153,11 @@ def _stats(self):
parents_stats.append(p._stats) # pylint: disable=protected-access
return dataset_stats.make_stats(
dataset_stats.StatsConfig(
name=str(self), transform_mutates_spec=self._MUTATES_ELEMENT_SPEC
name=str(self),
transform_mutates_spec=self._MUTATES_ELEMENT_SPEC,
),
parents_stats,
execution_tracking_mode=self._options_with_default.execution_tracking_mode,
)


Expand Down Expand Up @@ -1181,17 +1205,31 @@ class _Default(Generic[T]):

@dataclasses.dataclass(kw_only=True, frozen=True)
class DatasetOptions:
"""Holds options used by dataset transformations."""

# If the ratio of filtered out elements is above these thresholds, a warning
# or an exception will be issued, respectively. Value `None` disables the
# check.
# The ratio is calculated on non-overlapping windows of 1000 elements. For
# instancce, with `filter_warn_threshold_ratio=0.9` and 901 elements out of
# the first 1000 (or elements 1000...2000) filtered out, a warning will be
# issued.
"""Holds options used by dataset transformations.

Attributes:
filter_warn_threshold_ratio: If the ratio of filtered out elements is above
these thresholds, a warning will be issued. Value `None` disables the
check. The ratio is calculated on non-overlapping windows of 1000
elements. For instance, with `filter_warn_threshold_ratio=0.9` and 901
elements out of the first 1000 (or elements 1000...2000) filtered out, a
warning will be issued.
filter_raise_threshold_ratio: If the ratio of filtered out elements is above
these thresholds, an exception will be issued. Value `None` disables the
check.
execution_tracking_mode: The collection of execution statistics like total
processing time taken by each transformation, number of elements produced
etc. can be managed through various modes. If `DISABLED`, no statistics
are collected. If `STAGE_TIMING`, the time it takes to process each
transformation is collected. See `ExecutionTrackingMode` for more details.
"""

filter_warn_threshold_ratio: float | None | _Default[float] = _Default(0.9)
filter_raise_threshold_ratio: float | None | _Default[None] = _Default(None)
execution_tracking_mode: (
dataset_stats.ExecutionTrackingMode
| _Default[dataset_stats.ExecutionTrackingMode]
) = _Default(dataset_stats.ExecutionTrackingMode.DISABLED)

# Internal fields.

Expand Down
35 changes: 30 additions & 5 deletions grain/_src/python/dataset/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from collections.abc import Sequence
import contextlib
import dataclasses
import enum
import pprint
import sys
import threading
Expand Down Expand Up @@ -52,7 +53,7 @@
T = TypeVar("T")
# Time between two consecutive monitoring reports.
_REPORTING_PERIOD_SEC = 10
_LOG_EXECTION_SUMMARY_PERIOD_SEC = 60
_LOG_EXECUTION_SUMMARY_PERIOD_SEC = 60
# Stop reporting if there has been no statistics updates for this long.
_REPORTING_TIMEOUT_SEC = 120

Expand Down Expand Up @@ -169,6 +170,25 @@ def _pretty_format_summary(
return table.get_pretty_wrapped_summary() # pylint: disable=protected-access


@enum.unique
class ExecutionTrackingMode(enum.Flag):
  """Modes controlling how execution statistics are tracked.

  Available modes:
    DISABLED:
      No execution statistics are measured. This mode is the default.
    STAGE_TIMING:
      Measures and records the time taken by each transformation stage to
      execute. The recorded time is the duration spent within the specific
      transformation to return an element, excluding time spent in any
      parent transformations, and can be retrieved via the
      `grain.experimental.get_execution_summary` method.
  """

  DISABLED = enum.auto()
  STAGE_TIMING = enum.auto()


class _Table:
"""Table class for pretty printing tabular data."""

Expand Down Expand Up @@ -272,6 +292,8 @@ class StatsConfig:
# Whether this transformation mutates the element spec. This is used to
# determine element spec of the current transformation.
transform_mutates_spec: bool = True
# Whether to log the execution summary.
log_summary: bool = False


class Stats(abc.ABC):
Expand Down Expand Up @@ -495,7 +517,7 @@ def _reporting_loop(self):
def _logging_execution_summary_loop(self):
"""Logs the execution summary periodically."""
while self._should_report():
time.sleep(_LOG_EXECTION_SUMMARY_PERIOD_SEC)
time.sleep(_LOG_EXECUTION_SUMMARY_PERIOD_SEC)
# A node can be marked as non-output after the corresponding
# transformation started processing elements -- we do not control the
# initialization time.
Expand Down Expand Up @@ -565,8 +587,7 @@ def record_self_time(self, offset_ns: int = 0):
target=self._reporting_loop, daemon=True
)
self._reporting_thread.start()

if self._logging_thread is None:
if self._config.log_summary and self._logging_thread is None:
with self._logging_thread_init_lock:
if self._logging_thread is None:
self._logging_thread = threading.Thread(
Expand All @@ -592,6 +613,10 @@ def report(self):
p.report()


def make_stats(
    config: StatsConfig,
    parents: Sequence[Stats],
    execution_tracking_mode: ExecutionTrackingMode = ExecutionTrackingMode.DISABLED,
) -> Stats:
  """Produces a statistics instance according to the current execution mode.

  Args:
    config: Configuration for the produced stats object.
    parents: Stats objects of the parent transformations.
    execution_tracking_mode: The execution tracking mode requested by the
      pipeline options. Currently unused by this no-op implementation.

  Returns:
    A no-op stats instance wired to `parents`.
  """
  del execution_tracking_mode  # Unused: this build always returns no-op stats.
  return _NoopStats(config, parents=parents)
10 changes: 8 additions & 2 deletions grain/_src/python/dataset/transformations/prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,18 @@ def __init__(

@functools.cached_property
def _stats(self):
  """Builds the iterator stats connected to the `MapDataset` parent stats.

  Eagerly initializes the parent `MapDataset` stats tree with the execution
  tracking mode taken from the iterator options, so the whole stats tree
  shares a single tracking mode rather than the DISABLED default the parent's
  cached `_stats` property would otherwise use.
  """
  execution_tracking_mode = self._options_with_default.execution_tracking_mode
  parent_stats = self._map_parent._initialize_stats(  # pylint: disable=protected-access
      execution_tracking_mode
  )
  # Connect to `MapDataset` parent stats.
  return dataset_stats.make_stats(
      dataset_stats.StatsConfig(
          name=str(self),
          transform_mutates_spec=self._MUTATES_ELEMENT_SPEC,
      ),
      (parent_stats,),
      execution_tracking_mode,
  )

@functools.cached_property
Expand Down
1 change: 1 addition & 0 deletions grain/python_experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
DatasetOptions,
WithOptionsIterDataset,
)
from ._src.python.dataset.stats import ExecutionTrackingMode
from ._src.python.dataset.transformations.flatmap import (
FlatMapMapDataset,
FlatMapIterDataset,
Expand Down
Loading