From 18a141440ce112425a388a8730d3264ba771e6b6 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 14 Nov 2024 03:49:04 +0000 Subject: [PATCH] Add complete interface --- .../runners/worker/statesampler_interface.py | 38 ++++++++++++++++++- .../runners/worker/statesampler_slow.py | 4 +- .../runners/worker/statesampler_stub.py | 14 ++++++- sdks/python/apache_beam/transforms/core.py | 12 +++--- .../apache_beam/transforms/ptransform_test.py | 17 ++++++--- 5 files changed, 68 insertions(+), 17 deletions(-) diff --git a/sdks/python/apache_beam/runners/worker/statesampler_interface.py b/sdks/python/apache_beam/runners/worker/statesampler_interface.py index aed470036ffb..15fe7015cfac 100644 --- a/sdks/python/apache_beam/runners/worker/statesampler_interface.py +++ b/sdks/python/apache_beam/runners/worker/statesampler_interface.py @@ -18,7 +18,41 @@ from abc import ABC, abstractmethod +class ScopedStateInterface(ABC): + @abstractmethod + def sampled_seconds(self) -> float: + pass + + @abstractmethod + def sampled_msecs_int(self) -> int: + pass + + @abstractmethod + def __enter__(self): + pass + + @abstractmethod + def __exit__(self, exc_type, exc_value, traceback): + pass + + class StateSamplerInterface(ABC): @abstractmethod - def update_metric(self, typed_metric_name, value): - raise NotImplementedError + def start(self) -> None: + pass + + @abstractmethod + def stop(self) -> None: + pass + + @abstractmethod + def reset(self) -> None: + pass + + @abstractmethod + def current_state(self) -> ScopedStateInterface: + pass + + @abstractmethod + def update_metric(self, typed_metric_name, value) -> None: + pass diff --git a/sdks/python/apache_beam/runners/worker/statesampler_slow.py b/sdks/python/apache_beam/runners/worker/statesampler_slow.py index eb05ee97a74a..40abc8a7cf64 100644 --- a/sdks/python/apache_beam/runners/worker/statesampler_slow.py +++ b/sdks/python/apache_beam/runners/worker/statesampler_slow.py @@ -21,7 +21,7 @@ from apache_beam.runners import common from apache_beam.utils import counters -from apache_beam.runners.worker.statesampler_interface import StateSamplerInterface +from apache_beam.runners.worker.statesampler_interface import StateSamplerInterface, ScopedStateInterface class StateSampler(StateSamplerInterface): @@ -73,7 +73,7 @@ def reset(self) -> None: pass -class ScopedState(object): +class ScopedState(ScopedStateInterface): def __init__( self, sampler: StateSampler, diff --git a/sdks/python/apache_beam/runners/worker/statesampler_stub.py b/sdks/python/apache_beam/runners/worker/statesampler_stub.py index 563eaed8be43..253e1261d038 100644 --- a/sdks/python/apache_beam/runners/worker/statesampler_stub.py +++ b/sdks/python/apache_beam/runners/worker/statesampler_stub.py @@ -15,7 +15,7 @@ # limitations under the License. # -from apache_beam.runners.worker.statesampler_interface import StateSamplerInterface +from apache_beam.runners.worker.statesampler_interface import StateSamplerInterface, ScopedStateInterface class StubStateSampler(StateSamplerInterface): @@ -30,3 +30,15 @@ def update_metric(self, typed_metric_name, value): def get_recorded_calls(self): return self._update_metric_calls + + def start(self) -> None: + raise NotImplementedError() + + def stop(self) -> None: + raise NotImplementedError() + + def reset(self) -> None: + raise NotImplementedError() + + def current_state(self) -> ScopedStateInterface: + raise NotImplementedError() diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index a11c37ad7660..5c3b2544308f 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -2553,9 +2553,7 @@ def submit(self, process_fn, *args, **kwargs): stub_state_sampler.get_recorded_calls().items() ): tracker.update_metric(typed_metric_name, value) - if results is None: - return - return list(results) + return results class _SubprocessDoFn(DoFn): @@ -2596,10 +2594,10 @@ def _call_remote(self, method, *args, **kwargs): self._pool = concurrent.futures.ProcessPoolExecutor(1) self._pool.submit(self._remote_init, self._serialized_fn).result() try: - return _DeferredStateUpdatingPool( - self._pool, - self._timeout if method == self._remote_process else None).submit( - method, *args, **kwargs) + if (method == self._remote_process): + return _DeferredStateUpdatingPool(self._pool, self._timeout).submit( + method, *args, **kwargs) + return self._pool.submit(method, *args, **kwargs).result(None) except (concurrent.futures.process.BrokenProcessPool, TimeoutError, concurrent.futures._base.TimeoutError): diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py index fe011e964368..33511cee838e 100644 --- a/sdks/python/apache_beam/transforms/ptransform_test.py +++ b/sdks/python/apache_beam/transforms/ptransform_test.py @@ -18,7 +18,6 @@ """Unit tests for the PTransform and descendants.""" # pytype: skip-file - import collections import operator import os @@ -2769,10 +2768,18 @@ def process(self, element): yield element with TestPipeline() as p: - _, _ = ( - (p | beam.Create([1,2,3])) | beam.ParDo(CounterDoFn()) - .with_exception_handling( - use_subprocess=self.use_subprocess, timeout=1)) + good, _ = ( + (p + | beam.Create([1,2,3])) + | beam.ParDo(CounterDoFn()) + .with_exception_handling( + use_subprocess=self.use_subprocess, + timeout=1 if self.use_subprocess else .1 + ) + ) + + assert_that(good, equal_to([1, 2, 3]), label='CheckGood') + results = p.result metric_results1 = results.metrics().query(