From 250c2f21a3563b984c610e7912d42722d216bafa Mon Sep 17 00:00:00 2001
From: Andrew Audibert
Date: Tue, 11 Feb 2025 04:33:40 -0800
Subject: [PATCH] PyGrain performance and debugging tool

PiperOrigin-RevId: 725577000
---
 MODULE.bazel                               |   1 -
 grain/BUILD                                |   2 -
 grain/_src/python/BUILD                    |   1 -
 grain/_src/python/dataset/BUILD            |   4 -
 grain/_src/python/dataset/dataset.py       |  16 -
 grain/_src/python/dataset/dataset_test.py  |  50 ---
 grain/_src/python/dataset/stats.py         |  17 -
 grain/_src/python/dataset/stats_test.py    | 345 +--------------------
 grain/_src/python/grain_pool.py            |  46 +--
 grain/python/stats/BUILD                   |  18 --
 grain/python/stats/execution_summary.proto |  36 ---
 11 files changed, 5 insertions(+), 531 deletions(-)
 delete mode 100644 grain/python/stats/BUILD
 delete mode 100644 grain/python/stats/execution_summary.proto

diff --git a/MODULE.bazel b/MODULE.bazel
index 77722fbd..fc290873 100644
--- a/MODULE.bazel
+++ b/MODULE.bazel
@@ -23,7 +23,6 @@ bazel_dep(name = "rules_python", version = "0.34.0")
 bazel_dep(name = "pybind11_bazel", version = "2.13.6")
 bazel_dep(name = "abseil-py", version = "2.1.0")
 bazel_dep(name = "abseil-cpp", version = "20240722.0")
-bazel_dep(name = "protobuf", version = "29.0", repo_name = "com_google_protobuf")
 
 python = use_extension("@rules_python//python/extensions:python.bzl", "python")

diff --git a/grain/BUILD b/grain/BUILD
index c1c7a8a7..504c25dc 100644
--- a/grain/BUILD
+++ b/grain/BUILD
@@ -16,7 +16,6 @@ py_library(
         "python/experimental.py",
         "python/fast_proto.py",
     ],
-    data = ["//grain/_src/python/experimental/index_shuffle/python:index_shuffle_module.so"],
     srcs_version = "PY3",
     # Implicit build flag
     visibility = ["//visibility:public"],
@@ -45,6 +44,5 @@ py_library(
         "//grain/_src/python/dataset/transformations:zip",
         "//grain/_src/python/experimental/example_packing:packing",
         "//grain/_src/python/testing:experimental",
-        "//grain/python/stats:execution_summary_py_pb2",
     ],
 )

diff --git a/grain/_src/python/BUILD b/grain/_src/python/BUILD
index 7d847e73..94cffe02 100644
--- a/grain/_src/python/BUILD
+++ b/grain/_src/python/BUILD
@@ -203,7 +203,6 @@ py_library(
         ":options",
         ":record",
         ":shared_memory_array",
-        "//grain/_src/core:config",
         "//grain/_src/core:parallel",
         "//grain/_src/core:tree_lib",
         "@abseil-py//absl/logging",

diff --git a/grain/_src/python/dataset/BUILD b/grain/_src/python/dataset/BUILD
index 4ab644b7..8ee63fa3 100644
--- a/grain/_src/python/dataset/BUILD
+++ b/grain/_src/python/dataset/BUILD
@@ -43,7 +43,6 @@ py_library(
         "//grain/_src/python:grain_pool",
         "//grain/_src/python:options",
         "//grain/_src/python:shared_memory_array",
-        "//grain/python/stats:execution_summary_py_pb2",
         "@abseil-py//absl/logging",
         "@pypi//cloudpickle:pkg",
         "@pypi//numpy:pkg",
@@ -63,7 +62,6 @@ py_test(
         "//grain/_src/core:transforms",
         "//grain/_src/python:options",
         "//grain/_src/python/testing:experimental",
-        "//grain/python/stats:execution_summary_py_pb2",
         "@abseil-py//absl/testing:absltest",
         "@abseil-py//absl/testing:flagsaver",
         "@abseil-py//absl/testing:parameterized",
@@ -107,7 +105,6 @@ py_library(
         "//grain/_src/core:config",
         "//grain/_src/core:monitoring",
         "//grain/_src/core:tree_lib",
-        "//grain/python/stats:execution_summary_py_pb2",
         "@abseil-py//absl/logging",
     ],
 )
@@ -120,7 +117,6 @@ py_test(
         ":dataset",
         ":stats",
         "//grain/_src/core:transforms",
-        "//grain/python/stats:execution_summary_py_pb2",
         "@abseil-py//absl/testing:absltest",
         "@abseil-py//absl/testing:flagsaver",
         "@pypi//cloudpickle:pkg",
diff --git a/grain/_src/python/dataset/dataset.py b/grain/_src/python/dataset/dataset.py
index b9276aef..6a24a83a 100644
--- a/grain/_src/python/dataset/dataset.py
+++ b/grain/_src/python/dataset/dataset.py
@@ -58,7 +58,6 @@ from grain._src.python import options as grain_options
 from grain._src.python.dataset import base
 from grain._src.python.dataset import stats as dataset_stats
-from grain.python.stats import execution_summary_pb2
 import numpy as np
 
 from grain._src.core import monitoring
@@ -1288,18 +1287,3 @@ def apply_transformations(
         f"Transformation type: {transformation} is not supported."
     )
   return ds
-
-
-def get_execution_summary(
-    ds: DatasetIterator,
-) -> execution_summary_pb2.ExecutionSummary:
-  """Returns the execution summary for the dataset."""
-  # pylint: disable=protected-access
-  execution_stats = ds._stats
-  if not isinstance(execution_stats, dataset_stats._ExecutionStats):
-    raise ValueError(
-        "Set `grain_py_debug_mode` or set `execution_tracking_mode` in grain"
-        " options to `STAGE_TIMING` to enable execution statistics collection."
-    )
-  return execution_stats._get_execution_summary()
-  # pylint: enable=protected-access
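
Context for reviewers: get_execution_summary, removed above, was the public entry
point to the collected statistics. A minimal usage sketch, reconstructed from the
tests deleted below (this targets the pre-removal API only and no longer works
after this patch):

    from grain._src.python.dataset import base
    from grain._src.python.dataset import dataset

    ds = dataset.MapDataset.range(10).shuffle(42).to_iter_dataset()
    # Opt in to stage timing, mirroring test_execution_summary_with_no_logging.
    ds = dataset.WithOptionsIterDataset(
        ds,
        base.DatasetOptions(
            execution_tracking_mode=base.ExecutionTrackingMode.STAGE_TIMING
        ),
    )
    it = iter(ds)
    _ = list(it)  # Statistics are collected while iterating.
    summary = dataset.get_execution_summary(it)  # ExecutionSummary proto.
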
diff --git a/grain/_src/python/dataset/dataset_test.py b/grain/_src/python/dataset/dataset_test.py
index 8f85ada9..035b6090 100644
--- a/grain/_src/python/dataset/dataset_test.py
+++ b/grain/_src/python/dataset/dataset_test.py
@@ -31,7 +30,6 @@ from grain._src.python.dataset import dataset
 from grain._src.python.dataset import stats as dataset_stats
 import grain._src.python.testing.experimental as test_util
-from grain.python.stats import execution_summary_pb2
 import numpy as np
 from typing_extensions import override
@@ -908,54 +907,5 @@ def test_conflicting_options(self):
     )
 
 
-class GetExecutionSummaryTest(parameterized.TestCase):
-
-  def test_get_execution_summary_without_collection(self):
-    ds = dataset.MapDataset.range(10).shuffle(42)
-    ds = ds.to_iter_dataset()
-    it = ds.__iter__()
-    with self.assertRaisesRegex(
-        ValueError,
-        "Set `grain_py_debug_mode` or set `execution_tracking_mode` in grain"
-        " options to `STAGE_TIMING` to enable execution statistics collection.",
-    ):
-      dataset.get_execution_summary(it)
-
-  @mock.patch.object(dataset_stats, "_REPORTING_PERIOD_SEC", 0.05)
-  @mock.patch.object(dataset_stats, "_LOG_EXECUTION_SUMMARY_PERIOD_SEC", 0.06)
-  @flagsaver.flagsaver(grain_py_debug_mode=True)
-  def test_execution_summary_with_logging(self):
-    with self.assertLogs(level="INFO") as logs:
-      ds = dataset.MapDataset.range(10).shuffle(42)
-      ds = ds.map(MapTransformAddingOne())
-      ds = ds.to_iter_dataset()
-      it = ds.__iter__()
-      # Get execution summary after iterating through the dataset.
-      _ = list(it)
-      # reporting stats after 0.05 seconds.
-      time.sleep(0.1)
-    log_value = "Grain Dataset Execution Summary"
-    self.assertRegex("".join(logs.output), log_value)
-
-  @mock.patch.object(dataset_stats, "_REPORTING_PERIOD_SEC", 0.05)
-  @mock.patch.object(dataset_stats, "_LOG_EXECUTION_SUMMARY_PERIOD_SEC", 0.06)
-  def test_execution_summary_with_no_logging(self):
-    with self.assertNoLogs(level="INFO"):
-      ds = dataset.MapDataset.range(10).shuffle(42)
-      ds = ds.map(MapTransformAddingOne())
-      ds = ds.to_iter_dataset()
-      ds = dataset.WithOptionsIterDataset(
-          ds,
-          base.DatasetOptions(
-              execution_tracking_mode=base.ExecutionTrackingMode.STAGE_TIMING
-          ),
-      )
-      it = ds.__iter__()
-      # Get execution summary after iterating through the dataset.
-      _ = list(it)
-      # reporting stats after 0.05 seconds.
-      time.sleep(0.1)
-
-
 if __name__ == "__main__":
   absltest.main()

diff --git a/grain/_src/python/dataset/stats.py b/grain/_src/python/dataset/stats.py
index 22ee4704..c4a03d95 100644
--- a/grain/_src/python/dataset/stats.py
+++ b/grain/_src/python/dataset/stats.py
@@ -31,7 +30,6 @@ from grain._src.core import monitoring as grain_monitoring
 from grain._src.core import tree_lib
 from grain._src.python.dataset import base
-from grain.python.stats import execution_summary_pb2
 
 from grain._src.core import monitoring
@@ -713,20 +712,4 @@ def make_stats(
     ),
 ) -> Stats:
   """Produces statistics instance according to the current execution mode."""
-  vis_output_dir = grain_config.config.py_dataset_visualization_output_dir
-  # Only None and "" are supported.
-  if vis_output_dir:
-    raise NotImplementedError(
-        "Saving the dataset graph to a file is not supported yet. Set"
-        " `grain_py_dataset_visualization_output_dir` to empty string to"
-        " produce visualization in the logs."
-    )
-  if grain_config.config.py_debug_mode:
-    # In debug mode, we always log the execution summary.
-    config = dataclasses.replace(config, log_summary=True)
-    return _ExecutionStats(config, parents=parents)
-  if execution_tracking_mode == base.ExecutionTrackingMode.STAGE_TIMING:
-    return _ExecutionStats(config, parents=parents)
-  if vis_output_dir is not None:
-    return _VisualizationStats(config, parents=parents)
   return _NoopStats(config, parents=parents)
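
The make_stats dispatch removed above was the switch point for every mode this
patch deletes. A compact sketch of how each mode was selected before this change,
mirroring the deleted stats_test.py cases below (pre-removal behavior only; the
flags must already be defined by grain's config module for flagsaver to work):

    from absl.testing import flagsaver

    from grain._src.python.dataset import stats

    # Debug mode: stats are collected and the summary is also logged.
    with flagsaver.flagsaver(grain_py_debug_mode=True):
      s = stats.make_stats(stats.StatsConfig(name="demo"), ())
      assert isinstance(s, stats._ExecutionStats)

    # Empty visualization dir: only the dataset graph is built, for log output.
    with flagsaver.flagsaver(grain_py_dataset_visualization_output_dir=""):
      s = stats.make_stats(stats.StatsConfig(name="demo"), ())
      assert isinstance(s, stats._VisualizationStats)

    # Default (dir is None, debug off, no STAGE_TIMING): stats are a no-op.
    with flagsaver.flagsaver(grain_py_dataset_visualization_output_dir=None):
      s = stats.make_stats(stats.StatsConfig(name="demo"), ())
      assert isinstance(s, stats._NoopStats)
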
diff --git a/grain/_src/python/dataset/stats_test.py b/grain/_src/python/dataset/stats_test.py
index a0f9fe99..803e0bd2 100644
--- a/grain/_src/python/dataset/stats_test.py
+++ b/grain/_src/python/dataset/stats_test.py
@@ -23,7 +22,6 @@ from grain._src.core import transforms
 from grain._src.python.dataset import dataset
 from grain._src.python.dataset import stats
-from grain.python.stats import execution_summary_pb2
 
 from absl.testing import absltest
@@ -101,7 +100,7 @@
 "[]"
 ││
-││ MapDatasetIterator(transform= @ .../python/dataset/stats_test.py:524)
+││ MapDatasetIterator(transform= @ .../python/dataset/stats_test.py:525)
 ││
 ╲╱
 {'data': "[]",
@@ -233,347 +232,5 @@ def test_report(self):
       s = s._parents[0]
     s.report()
 
-
-class DebugModeStatsTest(absltest.TestCase):
-
-  def setUp(self):
-    super().setUp()
-    self.enter_context(flagsaver.flagsaver(grain_py_debug_mode=True))
-
-  @mock.patch.object(stats, "_REPORTING_PERIOD_SEC", 0.05)
-  def test_record_stats(self):
-    s = _make_stats_tree(stats.make_stats)
-    self.assertIsInstance(s, stats._ExecutionStats)
-    flat_stats = []
-    to_visit = [s]
-    while to_visit:
-      node = to_visit.pop(0)
-      flat_stats.append(node)
-      to_visit.extend(node._parents)
-
-    reported_self_times = collections.defaultdict(int)
-
-    def mock_report(node):
-      while node._self_times_buffer:
-        reported_self_times[id(node)] += node._self_times_buffer.pop()
-      for p in node._parents:
-        p.report()
-
-    for node in flat_stats:
-      node.report = functools.partial(mock_report, node)
-    for node in flat_stats:
-      with node.record_self_time(offset_ns=10**9):
-        time.sleep(0.5)
-    time.sleep(0.05)
-    self_times = list(reported_self_times.values())
-    self.assertLen(self_times, len(flat_stats))
-    for self_time in self_times:
-      self.assertGreaterEqual(self_time, 1.05 * 10**9)
-
-  @mock.patch.object(stats, "_REPORTING_PERIOD_SEC", 0.05)
-  def test_record_stats_thread_safe(self):
-    s = stats.make_stats(stats.StatsConfig(name="test_stats"), ())
-    reported_self_time = 0
-
-    def mock_report(node):
-      while node._self_times_buffer:
-        nonlocal reported_self_time
-        reported_self_time += node._self_times_buffer.pop()
-      for p in node._parents:
-        p.report()
-
-    s.report = functools.partial(mock_report, s)
-
-    def record_self_time():
-      with s.record_self_time():
-        # Sleep releases GIL, so this will actually execute concurrently.
-        time.sleep(1)
-
-    n_threads = 100
-    recording_threads = []
-    for _ in range(n_threads):
-      t = threading.Thread(target=record_self_time)
-      t.start()
-      recording_threads.append(t)
-    for t in recording_threads:
-      t.join()
-    time.sleep(0.05)
-    self.assertGreaterEqual(reported_self_time, n_threads)
-
-  def test_picklable(self):
-    s = stats.make_stats(stats.StatsConfig(name="test_stats"), ())
-    self.assertIsInstance(s, stats._ExecutionStats)
-    s = cloudpickle.loads(cloudpickle.dumps(s))
-    self.assertIsInstance(s, stats._ExecutionStats)
-    with s.record_self_time():
-      time.sleep(0.5)
-    s = cloudpickle.loads(cloudpickle.dumps(s))
-    self.assertIsInstance(s, stats._ExecutionStats)
-
-  def test_dataset_visualization(self):
-    ds = (
-        dataset.MapDataset.range(10)
-        .seed(42)
-        .shuffle()
-        .slice(slice(1, None, 3))
-        .map_with_index(_add_dummy_metadata)
-        .map(_identity)
-        .repeat(2)
-    )
-    # Visualization graph is constructed while iterating through pipeline.
-    _ = list(ds)
-    self.assertIsInstance(ds._stats, stats._ExecutionStats)
-    self.assertEqual(ds._stats._visualize_dataset_graph(), _MAP_DATASET_REPR)
-
-  def test_pretty_print_execution_summary(self):
-    dummy_summary = execution_summary_pb2.ExecutionSummary()
-    dummy_summary.nodes[0].CopyFrom(
-        execution_summary_pb2.ExecutionSummary.Node(
-            id=0,
-            name="MapDatasetIterator(transform=_MapFnFromPreprocessingBuilder(preprocessing_builder=NextTokenAsTargetTextPreprocessingBuilder))",
-            inputs=[1],
-            wait_time_ratio=0.5,
-            total_processing_time_ns=0,
-            min_processing_time_ns=400_000,
-            max_processing_time_ns=0,
-            num_produced_elements=0,
-            output_spec="[]",
-        )
-    )
-    dummy_summary.nodes[1].CopyFrom(
-        execution_summary_pb2.ExecutionSummary.Node(
-            id=1,
-            name="PrefetchDatasetIterator",
-            inputs=[2],
-            wait_time_ratio=0.5,
-            total_processing_time_ns=400_000,
-            min_processing_time_ns=400,
-            max_processing_time_ns=40000,
-            num_produced_elements=10,
-            output_spec="[]",
-            is_output=True,
-            is_prefetch=True,
-        )
-    )
-    dummy_summary.nodes[2].CopyFrom(
-        execution_summary_pb2.ExecutionSummary.Node(
-            id=2,
-            name="MapMapDataset",
-            inputs=[3, 4],
-            wait_time_ratio=0.375,
-            total_processing_time_ns=400_000_000,
-            min_processing_time_ns=4000,
-            max_processing_time_ns=40_000_000,
-            num_produced_elements=10,
-            output_spec="[]",
-        )
-    )
-    dummy_summary.nodes[3].CopyFrom(
-        execution_summary_pb2.ExecutionSummary.Node(
-            id=3,
-            name="RangeMapDataset",
-            wait_time_ratio=0.125,
-            total_processing_time_ns=4000_000_000,
-            min_processing_time_ns=400_000,
-            max_processing_time_ns=400_000_000,
-            num_produced_elements=10,
-            inputs=[],
-            output_spec="[]",
-        )
-    )
-    dummy_summary.nodes[4].CopyFrom(
-        execution_summary_pb2.ExecutionSummary.Node(
-            id=4,
-            name="RangeMapDataset",
-            total_processing_time_ns=0,
-            wait_time_ratio=0,
-            min_processing_time_ns=400_000,
-            max_processing_time_ns=0,
-            num_produced_elements=0,
-            inputs=[],
-            output_spec="[]",
-        )
-    )
-
-    expected_result = """
-|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| id | name | inputs | percent wait time | total processing time | min processing time | max processing time | avg processing time | num produced elements |
-|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| 4 | RangeMapDataset | [] | 0.00% | N/A | N/A | N/A | N/A | N/A |
-|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| 3 | RangeMapDataset | [] | 12.50% | 4.00s | 400.00us | 400.00ms | 400.00ms | 10 |
-|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| 2 | MapMapDataset | [3, 4] | 37.50% | 400.00ms | 4.00us | 40.00ms | 40.00ms | 10 |
-|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| 1 | PrefetchDatasetIterator | [2] | N/A | 400.00us | 400ns | 40.00us | 40.00us | 10 |
-|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| 0 | MapDatasetIterator(transform=_ | [1] | 50.00% | N/A | N/A | N/A | N/A | N/A |
-| | MapFnFromPreprocessingBuilder( | | | | | | | |
-| | preprocessing_builder=NextToke | | | | | | | |
-| | nAsTargetTextPreprocessingBuil | | | | | | | |
-| | der)) | | | | | | | |
-|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-"""
-    self.assertEqual(
-        expected_result,
-        "\n" + stats._pretty_format_summary(dummy_summary),
-    )
-
-  def test_compute_iterator_wait_time_ratio(self):
-    dummy_summary = execution_summary_pb2.ExecutionSummary()
-    dummy_summary.nodes[0].CopyFrom(
-        execution_summary_pb2.ExecutionSummary.Node(
-            id=0,
-            name="MapDatasetIterator",
-            inputs=[1],
-            total_processing_time_ns=4000_000_000,
-            min_processing_time_ns=400,
-            max_processing_time_ns=40000,
-            num_produced_elements=10,
-            output_spec="[]",
-            is_output=True,
-        )
-    )
-    dummy_summary.nodes[1].CopyFrom(
-        execution_summary_pb2.ExecutionSummary.Node(
-            id=1,
-            name="PrefetchDatasetIterator",
-            inputs=[2],
-            total_processing_time_ns=4000_000_000,
-            min_processing_time_ns=4000,
-            max_processing_time_ns=40_000_000,
-            num_produced_elements=10,
-            output_spec="[]",
-            is_prefetch=True,
-        )
-    )
-    dummy_summary.nodes[2].CopyFrom(
-        execution_summary_pb2.ExecutionSummary.Node(
-            id=2,
-            name="MapMapDataset",
-            inputs=[3],
-            total_processing_time_ns=1000_000_000,
-            min_processing_time_ns=400_000,
-            max_processing_time_ns=400_000_000,
-            num_produced_elements=10,
-            output_spec="[]",
-        )
-    )
-    dummy_summary.nodes[3].CopyFrom(
-        execution_summary_pb2.ExecutionSummary.Node(
-            id=3,
-            name="RangeMapDataset",
-            total_processing_time_ns=3000_000_000,
-            min_processing_time_ns=400_000,
-            max_processing_time_ns=4000_000,
-            num_produced_elements=10,
-            inputs=[],
-            output_spec="[]",
-        )
-    )
-    stats._populate_wait_time_ratio(dummy_summary)
-    self.assertEqual(dummy_summary.nodes[0].wait_time_ratio, 0.5)
-    self.assertEqual(dummy_summary.nodes[1].wait_time_ratio, 0)
-    self.assertEqual(dummy_summary.nodes[2].wait_time_ratio, 0.125)
-    self.assertEqual(dummy_summary.nodes[3].wait_time_ratio, 0.375)
-
-  @flagsaver.flagsaver(grain_py_dataset_visualization_output_dir="TEST_DIR")
-  def test_dataset_visualization_with_output_dir(self):
-    ds = (
-        dataset.MapDataset.range(10)
-        .shuffle(42)
-        .map_with_index(_add_dummy_metadata)
-        .map(_identity)
-    )
-    with self.assertRaisesRegex(
-        NotImplementedError,
-        "Saving the dataset graph to a file is not supported yet.",
-    ):
-      _ = list(ds)
-
-
-class GraphModeStatsTest(absltest.TestCase):
-
-  def setUp(self):
-    super().setUp()
-    self.enter_context(
-        flagsaver.flagsaver(grain_py_dataset_visualization_output_dir="")
-    )
-
-  def test_visualize_map(self):
-    ds = (
-        dataset.MapDataset.range(10)
-        .seed(42)
-        .shuffle()
-        .slice(slice(1, None, 3))
-        .map_with_index(_add_dummy_metadata)
-        .map(_identity)
-        .repeat(2)
-    )
-    # Visualization graph is constructed while iterating through pipeline.
-    _ = list(ds)
-    self.assertIsInstance(ds._stats, stats._VisualizationStats)
-    self.assertEqual(ds._stats._visualize_dataset_graph(), _MAP_DATASET_REPR)
-
-  def test_visualize_iter(self):
-    ds = (
-        dataset.MapDataset.range(10)
-        .shuffle(42)
-        .to_iter_dataset()
-        .seed(42)
-        .map(lambda x: _add_dummy_metadata(2, x))
-        .batch(2)
-    )
-    # Visualization graph is constructed while iterating through pipeline.
-    it = ds.__iter__()
-    _ = list(it)
-    self.assertIsInstance(it._stats, stats._VisualizationStats)
-    self.assertEqual(it._stats._visualize_dataset_graph(), _ITER_DATASET_REPR)
-
-  def test_visualize_with_mix(self):
-    ds1 = dataset.MapDataset.range(10).shuffle(42)
-    ds2 = dataset.MapDataset.range(10).shuffle(43)
-    ds = dataset.MapDataset.mix([ds1, ds2]).map(_AddOne())
-    # Visualization graph is constructed while iterating through pipeline.
-    _ = list(ds)
-    self.assertIsInstance(ds._stats, stats._VisualizationStats)
-    self.assertEqual(ds._stats._visualize_dataset_graph(), _MIX_DATASET_REPR)
-
-  @flagsaver.flagsaver(grain_py_dataset_visualization_output_dir="TEST_DIR")
-  def test_dataset_visualization_with_output_dir(self):
-    ds = (
-        dataset.MapDataset.range(10)
-        .shuffle(42)
-        .map_with_index(_add_dummy_metadata)
-        .map(_identity)
-    )
-    with self.assertRaisesRegex(
-        NotImplementedError,
-        "Saving the dataset graph to a file is not supported yet.",
-    ):
-      _ = list(ds)
-
-  def test_picklable(self):
-    ds = (
-        dataset.MapDataset.range(10)
-        .seed(42)
-        .shuffle()
-        .slice(slice(1, None, 3))
-        .map_with_index(_add_dummy_metadata)
-        .map(_identity)
-        .repeat(2)
-    )
-    ds = cloudpickle.loads(cloudpickle.dumps(ds))
-    # Visualization graph is constructed while iterating through pipeline.
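
A note on the arithmetic encoded by the two removed tests above: in the expected
table, avg processing time is total processing time divided by num produced
elements (for node 3, 4.00s / 10 = 400.00ms), and the percent-wait-time column
sums to 100% across the pipeline. The sketch below reproduces the wait-time split
asserted in test_compute_iterator_wait_time_ratio under one consistent reading of
the numbers; the helper is illustrative, not the removed _populate_wait_time_ratio
implementation:

    # Nodes 2 and 3 sit under the prefetch node (node 1), so their wait-time
    # ratios are carved out of the prefetch subtree's share in proportion to
    # their total processing times.
    iterator_ns = 4_000_000_000  # node 0, MapDatasetIterator
    children_ns = {2: 1_000_000_000, 3: 3_000_000_000}  # under the prefetch node
    subtree_ns = sum(children_ns.values())
    prefetch_share = subtree_ns / (iterator_ns + subtree_ns)  # 0.5
    ratios = {
        i: prefetch_share * ns / subtree_ns for i, ns in children_ns.items()
    }
    assert ratios == {2: 0.125, 3: 0.375}  # matches the removed assertions
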
-    _ = list(ds)
-    self.assertIsInstance(ds._stats, stats._VisualizationStats)
-    self.assertEqual(ds._stats._visualize_dataset_graph(), _MAP_DATASET_REPR)
-
-  @flagsaver.flagsaver(grain_py_dataset_visualization_output_dir=None)
-  def test_dataset_visualization_with_output_dir_none(self):
-    s = stats.make_stats(stats.StatsConfig(name="test_stats"), ())
-    self.assertIsInstance(s, stats._NoopStats)
-
-
 if __name__ == "__main__":
   absltest.main()

diff --git a/grain/_src/python/grain_pool.py b/grain/_src/python/grain_pool.py
index f4473143..53d9ae8e 100644
--- a/grain/_src/python/grain_pool.py
+++ b/grain/_src/python/grain_pool.py
@@ -61,7 +60,6 @@ import cloudpickle
 from grain._src.core import parallel
 from grain._src.core import tree_lib
-from grain._src.core.config import config
 import multiprocessing as mp
 from grain._src.python import grain_logging
 from grain._src.python import multiprocessing_common
@@ -154,35 +153,12 @@ def deserialize(cls, serialized: bytes) -> GetElementProducerFn[T]:
     return obj
 
 
-def parse_debug_flags(debug_flags: dict[str, Any]):
-  """Parses debug flags."""
-  from absl import flags
-  flags.FLAGS["grain_py_debug_mode"].present = True
-  flags.FLAGS["grain_py_dataset_visualization_output_dir"].present = True
-  config.update("py_debug_mode", debug_flags["grain_py_debug_mode"])
-  config.update(
-      "py_dataset_visualization_output_dir",
-      debug_flags["grain_py_dataset_visualization_output_dir"],
-  )
-
-
 def _initialize_and_get_element_producer(
-    args_queue: queues.Queue,
-    *,
-    debug_flags: dict[str, Any],
-    worker_index: int,
-    worker_count: int,
+    args_queue: queues.Queue, *, worker_index: int, worker_count: int
 ) -> Iterator[Any]:
   """Unpickles the element producer from the args queue and closes the queue."""
-  (
-      serialized_flag_parse_fn,
-      serialized_init_fn,
-      serialized_element_producer_fn,
-  ) = args_queue.get()
-  flag_parse_fn: Callable[[Any], None] = cloudpickle.loads(
-      serialized_flag_parse_fn
-  )
-  flag_parse_fn(debug_flags)
+  serialized_init_fn, serialized_element_producer_fn = args_queue.get()
   init_fn: Callable[[], None] = cloudpickle.loads(serialized_init_fn)
   init_fn()
   element_producer_fn: GetElementProducerFn[Any] = (
@@ -206,7 +182,6 @@ def _worker_loop(
     worker_index: int,
     worker_count: int,
     enable_profiling: bool,
-    debug_flags: dict[str, Any],
 ):
   """Code to be run on each child process."""
   out_of_elements = False
@@ -217,10 +192,7 @@
   )
   logging.info("Starting work.")
   element_producer = _initialize_and_get_element_producer(
-      args_queue,
-      debug_flags=debug_flags,
-      worker_index=worker_index,
-      worker_count=worker_count,
+      args_queue, worker_index=worker_index, worker_count=worker_count
   )
   profiling_enabled = enable_profiling and worker_index == 0
   if profiling_enabled:
@@ -351,24 +323,14 @@ def __init__(
         "worker_index": worker_index,
         "worker_count": options.num_workers,
         "enable_profiling": options.enable_profiling,
-        "debug_flags": {
-            "grain_py_debug_mode": config.py_debug_mode,
-            "grain_py_dataset_visualization_output_dir": (
-                config.py_dataset_visualization_output_dir
-            ),
-        },
     }
     # The process kwargs must all be picklable and will be unpickled before
     # absl.app.run() is called. We send arguments via a queue to ensure that
     # they are unpickled after absl.app.run() has been called in the child
     # processes.
     worker_init_fn = lambda: None
-    parse_debug_flags_fn = parse_debug_flags
     worker_init_fn = cloudpickle.dumps(worker_init_fn)
-    parse_debug_flags_fn = cloudpickle.dumps(parse_debug_flags_fn)
-    worker_args_queue.put(
-        (parse_debug_flags_fn, worker_init_fn, get_element_producer_fn)
-    )
+    worker_args_queue.put((worker_init_fn, get_element_producer_fn))
     process = ctx.Process(  # pytype: disable=attribute-error  # re-none
         target=_worker_loop, kwargs=process_kwargs, daemon=True
     )
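
The comment kept above still describes the design: worker arguments travel through
a queue so that they are deserialized only after absl.app.run() has started in the
child, and cloudpickle (unlike plain pickle) can serialize the lambdas involved by
value. A standalone sketch of that handshake, with illustrative names rather than
grain's internals:

    import multiprocessing as mp

    import cloudpickle


    def _worker(args_queue) -> None:
      # Unpickle only inside the child, as _initialize_and_get_element_producer
      # does above.
      serialized_init_fn, serialized_producer_fn = args_queue.get()
      cloudpickle.loads(serialized_init_fn)()  # init_fn()
      producer_fn = cloudpickle.loads(serialized_producer_fn)
      for element in producer_fn():
        print(element)


    if __name__ == "__main__":
      ctx = mp.get_context("spawn")
      args_queue = ctx.Queue()
      # Serialize in the parent; the child decides when to deserialize.
      args_queue.put((
          cloudpickle.dumps(lambda: None),  # worker_init_fn
          cloudpickle.dumps(lambda: range(3)),  # element producer
      ))
      process = ctx.Process(target=_worker, args=(args_queue,), daemon=True)
      process.start()
      process.join()
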
diff --git a/grain/python/stats/BUILD b/grain/python/stats/BUILD
deleted file mode 100644
index 36058ad6..00000000
--- a/grain/python/stats/BUILD
+++ /dev/null
@@ -1,18 +0,0 @@
-load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library")
-load("@com_google_protobuf//bazel:py_proto_library.bzl", "py_proto_library")
-
-default_visibility = ["//grain:__subpackages__"]
-
-package(default_visibility = default_visibility)
-
-proto_library(
-    name = "execution_summary_proto",
-    srcs = ["execution_summary.proto"],
-    # For profiling tooling.
-)
-
-py_proto_library(
-    name = "execution_summary_py_pb2",
-    # For profiling tooling.
-    deps = [":execution_summary_proto"],
-)

diff --git a/grain/python/stats/execution_summary.proto b/grain/python/stats/execution_summary.proto
deleted file mode 100644
index f9d92299..00000000
--- a/grain/python/stats/execution_summary.proto
+++ /dev/null
@@ -1,36 +0,0 @@
-syntax = "proto3";
-
-package grain.python.execution_summary;
-
-message ExecutionSummary {
-  message Node {
-    // Unique ID of the node.
-    int32 id = 2;
-    // Human-readable name of the node.
-    string name = 3;
-    // Node IDs of the parent nodes.
-    repeated int32 inputs = 4;
-    // Ratio of time spent by the pipeline waiting for the given transformation
-    // node.
-    double wait_time_ratio = 5;
-    // Cumulative processing time spent in the node from the start in
-    // nanoseconds.
-    int64 total_processing_time_ns = 6;
-    // Minimum per-element processing time in nanoseconds.
-    int64 min_processing_time_ns = 7;
-    // Maximum per-element processing time in nanoseconds.
-    int64 max_processing_time_ns = 8;
-    // Number of elements produced by the node.
-    int64 num_produced_elements = 9;
-    // Human-readable specification of the produced elements.
-    string output_spec = 10;
-    // Whether the node is the root node.
-    bool is_output = 11;
-    // Whether the node is a prefetch node. Child nodes of prefetch will have
-    // their wait time ratio derived from the ratio of the prefetch node.
-    // Sum of all ratios in a single pipeline is 1.
-    bool is_prefetch = 12;
-  }
-  // Map of node IDs to nodes in the pipeline.
-  map<int32, Node> nodes = 1;
-}
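
For anyone holding summaries serialized before this patch, a small sketch of how
the deleted message was populated and read; it mirrors the dummy summaries in the
removed stats_test.py and requires a pre-removal build of execution_summary_pb2:

    from grain.python.stats import execution_summary_pb2  # deleted above

    summary = execution_summary_pb2.ExecutionSummary()
    # `nodes` is a map<int32, Node> keyed by node id; `inputs` holds parent ids.
    summary.nodes[3].CopyFrom(
        execution_summary_pb2.ExecutionSummary.Node(
            id=3,
            name="RangeMapDataset",
            inputs=[],
            wait_time_ratio=0.125,
            total_processing_time_ns=4_000_000_000,
            num_produced_elements=10,
            output_spec="[]",
        )
    )
    node = summary.nodes[3]
    avg_ns = node.total_processing_time_ns / node.num_produced_elements  # 400ms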