Skip to content

Commit

Permalink
feat: make worker skip running task when it is completed (#406)
Browse files Browse the repository at this point in the history
* feat: add gokart_worker configurations as same as luigi one

* feat: make worker skip run when a task is completed
  • Loading branch information
hiro-o918 authored Dec 2, 2024
1 parent 29b29e3 commit d3eee04
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 6 deletions.
90 changes: 85 additions & 5 deletions gokart/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@
from luigi.task_register import TaskClassException, load_task
from luigi.task_status import RUNNING

logger = logging.getLogger('luigi-interface')
from gokart.parameter import ExplicitBoolParameter

logger = logging.getLogger(__name__)

# Prevent fork() from being called during a C-level getaddrinfo() which uses a process-global mutex,
# that may not be unlocked in child process, resulting in the process being locked indefinitely.
Expand Down Expand Up @@ -124,6 +126,7 @@ def __init__(
check_unfulfilled_deps: bool = True,
check_complete_on_run: bool = False,
task_completion_cache: Optional[Dict[str, Any]] = None,
skip_if_completed_pre_run: bool = True,
) -> None:
super(TaskProcess, self).__init__()
self.task = task
Expand All @@ -136,12 +139,19 @@ def __init__(
self.check_unfulfilled_deps = check_unfulfilled_deps
self.check_complete_on_run = check_complete_on_run
self.task_completion_cache = task_completion_cache
self.skip_if_completed_pre_run = skip_if_completed_pre_run

# completeness check using the cache
self.check_complete = functools.partial(luigi.worker.check_complete_cached, completion_cache=task_completion_cache)

def _run_task(self) -> Optional[collections.abc.Generator]:
if self.skip_if_completed_pre_run and self.check_complete(self.task):
logger.warning(f'{self.task} is skipped because the task is already completed.')
return None
return self.task.run()

def _run_get_new_deps(self) -> Optional[List[Tuple[str, str, Dict[str, str]]]]:
task_gen = self.task.run()
task_gen = self._run_task()

if not isinstance(task_gen, collections.abc.Generator):
return None
Expand Down Expand Up @@ -308,6 +318,68 @@ def run(self) -> None:
super(ContextManagedTaskProcess, self).run()


class gokart_worker(luigi.Config):
"""Configuration for the gokart worker.
You can set these options of section [gokart_worker] in your luigi.cfg file.
NOTE: use snake_case for this class to match the luigi.Config convention.
"""

id = luigi.Parameter(default='', description='Override the auto-generated worker_id')
ping_interval = luigi.FloatParameter(default=1.0, config_path=dict(section='core', name='worker-ping-interval'))
keep_alive = luigi.BoolParameter(default=False, config_path=dict(section='core', name='worker-keep-alive'))
count_uniques = luigi.BoolParameter(
default=False,
config_path=dict(section='core', name='worker-count-uniques'),
description='worker-count-uniques means that we will keep a ' 'worker alive only if it has a unique pending task, as ' 'well as having keep-alive true',
)
count_last_scheduled = luigi.BoolParameter(
default=False, description='Keep a worker alive only if there are ' 'pending tasks which it was the last to ' 'schedule.'
)
wait_interval = luigi.FloatParameter(default=1.0, config_path=dict(section='core', name='worker-wait-interval'))
wait_jitter = luigi.FloatParameter(default=5.0)

max_keep_alive_idle_duration = luigi.TimeDeltaParameter(default=datetime.timedelta(0))

max_reschedules = luigi.IntParameter(default=1, config_path=dict(section='core', name='worker-max-reschedules'))
timeout = luigi.IntParameter(default=0, config_path=dict(section='core', name='worker-timeout'))
task_limit = luigi.IntParameter(default=None, config_path=dict(section='core', name='worker-task-limit'))
retry_external_tasks = luigi.BoolParameter(
default=False,
config_path=dict(section='core', name='retry-external-tasks'),
description='If true, incomplete external tasks will be ' 'retested for completion while Luigi is running.',
)
send_failure_email = luigi.BoolParameter(default=True, description='If true, send e-mails directly from the worker' 'on failure')
no_install_shutdown_handler = luigi.BoolParameter(default=False, description='If true, the SIGUSR1 shutdown handler will' 'NOT be install on the worker')
check_unfulfilled_deps = luigi.BoolParameter(default=True, description='If true, check for completeness of ' 'dependencies before running a task')
check_complete_on_run = luigi.BoolParameter(
default=False,
description='If true, only mark tasks as done after running if they are complete. '
'Regardless of this setting, the worker will always check if external '
'tasks are complete before marking them as done.',
)
force_multiprocessing = luigi.BoolParameter(default=False, description='If true, use multiprocessing also when ' 'running with 1 worker')
task_process_context = luigi.OptionalParameter(
default=None,
description='If set to a fully qualified class name, the class will '
'be instantiated with a TaskProcess as its constructor parameter and '
'applied as a context manager around its run() call, so this can be '
'used for obtaining high level customizable monitoring or logging of '
'each individual Task run.',
)
cache_task_completion = luigi.BoolParameter(
default=False,
description='If true, cache the response of successful completion checks '
'of tasks assigned to a worker. This can especially speed up tasks with '
'dynamic dependencies but assumes that the completion status does not change '
'after it was true the first time.',
)
skip_if_completed_pre_run: bool = ExplicitBoolParameter(
default=True, description='If true, skip running tasks that are already completed just before the Task is run.'
)


class Worker:
"""
Worker object communicates with a scheduler.
Expand All @@ -319,15 +391,22 @@ class Worker:
"""

def __init__(
self, scheduler: Optional[Scheduler] = None, worker_id: Optional[str] = None, worker_processes: int = 1, assistant: bool = False, **kwargs: Any
self,
scheduler: Optional[Scheduler] = None,
worker_id: Optional[str] = None,
worker_processes: int = 1,
assistant: bool = False,
config: Optional[gokart_worker] = None,
) -> None:
if scheduler is None:
scheduler = Scheduler()

self.worker_processes = int(worker_processes)
self._worker_info = self._generate_worker_info()

self._config = luigi.worker.worker(**kwargs)
if config is None:
self._config = gokart_worker()
else:
self._config = config

worker_id = worker_id or self._config.id or self._generate_worker_id(self._worker_info)

Expand Down Expand Up @@ -836,6 +915,7 @@ def _create_task_process(self, task):
check_unfulfilled_deps=self._config.check_unfulfilled_deps,
check_complete_on_run=self._config.check_complete_on_run,
task_completion_cache=self._task_completion_cache,
skip_if_completed_pre_run=self._config.skip_if_completed_pre_run,
)

def _purge_children(self) -> None:
Expand Down
49 changes: 48 additions & 1 deletion test/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
from unittest.mock import Mock

import luigi
import luigi.worker
import pytest
from luigi import scheduler

import gokart
from gokart.worker import Worker
from gokart.worker import Worker, gokart_worker


class _DummyTask(gokart.TaskOnKart):
Expand All @@ -33,3 +34,49 @@ def test_run(self, monkeypatch: pytest.MonkeyPatch):
assert worker.add(task)
assert worker.run()
mock_run.assert_called_once()


class _DummyTaskToCheckSkip(gokart.TaskOnKart[None]):
task_namespace = __name__

def _run(self): ...

def run(self):
self._run()
self.dump(None)

def complete(self) -> bool:
return False


class TestWorkerSkipIfCompletedPreRun:
@pytest.mark.parametrize(
'skip_if_completed_pre_run,is_completed,expect_skipped',
[
pytest.param(True, True, True, id='skipped when completed and skip_if_completed_pre_run is True'),
pytest.param(True, False, False, id='not skipped when not completed and skip_if_completed_pre_run is True'),
pytest.param(False, True, False, id='not skipped when completed and skip_if_completed_pre_run is False'),
pytest.param(False, False, False, id='not skipped when not completed and skip_if_completed_pre_run is False'),
],
)
def test_skip_task(self, monkeypatch: pytest.MonkeyPatch, skip_if_completed_pre_run: bool, is_completed: bool, expect_skipped: bool):
sch = scheduler.Scheduler()
worker = Worker(scheduler=sch, config=gokart_worker(skip_if_completed_pre_run=skip_if_completed_pre_run))

mock_complete = Mock(return_value=is_completed)
# NOTE: set `complete_check_at_run=False` to avoid using deprecated skip logic.
task = _DummyTaskToCheckSkip(complete_check_at_run=False)
mock_run = Mock()
monkeypatch.setattr(task, '_run', mock_run)

with worker:
assert worker.add(task)
# NOTE: mock `complete` after `add` because `add` calls `complete`
# to check if the task is already completed.
monkeypatch.setattr(task, 'complete', mock_complete)
assert worker.run()

if expect_skipped:
mock_run.assert_not_called()
else:
mock_run.assert_called_once()

0 comments on commit d3eee04

Please sign in to comment.