diff --git a/gokart/conflict_prevention_lock/task_lock.py b/gokart/conflict_prevention_lock/task_lock.py index e67bf535..2e6f6b1a 100644 --- a/gokart/conflict_prevention_lock/task_lock.py +++ b/gokart/conflict_prevention_lock/task_lock.py @@ -8,6 +8,8 @@ logger = getLogger(__name__) +# deprecated; task lock will be implemented to worker.py instead of each tasks + class TaskLockParams(NamedTuple): redis_host: Optional[str] diff --git a/gokart/conflict_prevention_lock/task_lock_wrappers.py b/gokart/conflict_prevention_lock/task_lock_wrappers.py index cb7c5d1e..f40831fb 100644 --- a/gokart/conflict_prevention_lock/task_lock_wrappers.py +++ b/gokart/conflict_prevention_lock/task_lock_wrappers.py @@ -6,6 +6,8 @@ logger = getLogger(__name__) +# deprecated; task lock will be implemented to worker.py instead of each tasks + def wrap_dump_with_lock(func: Callable, task_lock_params: TaskLockParams, exist_check: Callable): """Redis lock wrapper function for TargetOnKart.dump(). diff --git a/gokart/task_collision_lock/dependency_lock.py b/gokart/task_collision_lock/dependency_lock.py new file mode 100644 index 00000000..24d34267 --- /dev/null +++ b/gokart/task_collision_lock/dependency_lock.py @@ -0,0 +1,20 @@ +from logging import getLogger + +from gokart.task_collision_lock.task_lock import build_lock_key, get_task_lock +from gokart.utils import flatten + +logger = getLogger(__name__) + + +def check_dependency_lock(task, redis_host: str, redis_port: int) -> None: + required_tasks = flatten(task.requires()) + for required_task in required_tasks: + _check_lock(task=required_task, redis_host=redis_host, redis_port=redis_port) + + +def _check_lock(task, redis_host: str, redis_port: int) -> None: + lock_key = build_lock_key(task=task) + task_lock = get_task_lock(redis_host=redis_host, redis_port=redis_port, lock_key=lock_key) + logger.info(f'Task lock of {lock_key} locked.') + task_lock.release() + logger.info(f'Task lock of {lock_key} released.') diff --git a/gokart/task_collision_lock/run_lock.py b/gokart/task_collision_lock/run_lock.py new file mode 100644 index 00000000..e356f895 --- /dev/null +++ b/gokart/task_collision_lock/run_lock.py @@ -0,0 +1,24 @@ +from logging import getLogger + +from gokart.task_collision_lock.task_lock import build_lock_key, get_task_lock, set_lock_scheduler + +logger = getLogger(__name__) + + +def run_with_lock(task, redis_host: str, redis_port: int): + lock_key = build_lock_key(task=task) + task_lock = get_task_lock(redis_host=redis_host, redis_port=redis_port, lock_key=lock_key) + scheduler = set_lock_scheduler(task_lock=task_lock) + + try: + logger.info(f'Task lock of {lock_key} locked.') + result = task.run() + task_lock.release() + logger.info(f'Task lock of {lock_key} released.') + scheduler.shutdown() + return result + except BaseException as e: + logger.info(f'Task lock of {lock_key} released with BaseException.') + task_lock.release() + scheduler.shutdown() + raise e diff --git a/gokart/task_collision_lock/task_lock.py b/gokart/task_collision_lock/task_lock.py new file mode 100644 index 00000000..ffb5aab2 --- /dev/null +++ b/gokart/task_collision_lock/task_lock.py @@ -0,0 +1,66 @@ +import functools +from logging import getLogger + +import redis +from apscheduler.schedulers.background import BackgroundScheduler + +logger = getLogger(__name__) + + +class TaskLockException(Exception): + pass + + +class RedisClient: + _instances: dict = {} + + def __new__(cls, *args, **kwargs): + key = (args, tuple(sorted(kwargs.items()))) + if cls not in cls._instances: + cls._instances[cls] = {} + if key not in cls._instances[cls]: + cls._instances[cls][key] = super(RedisClient, cls).__new__(cls) + return cls._instances[cls][key] + + def __init__(self, host: str, port: int) -> None: + if not hasattr(self, '_redis_client'): + self._redis_client = redis.Redis(host=host, port=port) + + def get_redis_client(self): + return self._redis_client + + +REDIS_TIMEOUT = 180 +LOCK_EXTEND_SECONDS = 10 + + +def get_task_lock(redis_host: str, redis_port: int, lock_key: str) -> redis.lock.Lock: + redis_client = RedisClient(host=redis_host, port=redis_port).get_redis_client() + task_lock = redis.lock.Lock(redis=redis_client, name=lock_key, timeout=REDIS_TIMEOUT, thread_local=False) + if not task_lock.acquire(blocking=False): + # If lock is already taken by other task, raise TaskLockException immediately. + raise TaskLockException('Lock already taken by other task.') + return task_lock + + +def _extend_lock(task_lock: redis.lock.Lock, redis_timeout: int) -> None: + task_lock.extend(additional_time=redis_timeout, replace_ttl=True) + + +def set_lock_scheduler(task_lock: redis.lock.Lock) -> BackgroundScheduler: + scheduler = BackgroundScheduler() + extend_lock = functools.partial(_extend_lock, task_lock=task_lock, redis_timeout=REDIS_TIMEOUT) + scheduler.add_job( + extend_lock, + 'interval', + seconds=LOCK_EXTEND_SECONDS, + max_instances=999999999, + misfire_grace_time=REDIS_TIMEOUT, + coalesce=False, + ) + scheduler.start() + return scheduler + + +def build_lock_key(task) -> str: + return task.output().path() diff --git a/gokart/worker.py b/gokart/worker.py index 1443d862..7b769464 100644 --- a/gokart/worker.py +++ b/gokart/worker.py @@ -62,6 +62,8 @@ from luigi.task_status import RUNNING from gokart.parameter import ExplicitBoolParameter +from gokart.task_collision_lock.dependency_lock import check_dependency_lock +from gokart.task_collision_lock.run_lock import run_with_lock logger = logging.getLogger(__name__) @@ -127,6 +129,9 @@ def __init__( check_complete_on_run: bool = False, task_completion_cache: Optional[Dict[str, Any]] = None, task_completion_check_at_run: bool = True, + collision_lock_at_run: bool = False, + collision_lock_redis_host: str | None = None, + collision_lock_redis_port: int | None = None, ) -> None: super(TaskProcess, self).__init__() self.task = task @@ -141,6 +146,10 @@ def __init__( self.task_completion_cache = task_completion_cache self.task_completion_check_at_run = task_completion_check_at_run + self.collision_lock_at_run = collision_lock_at_run + self.collision_lock_redis_host = collision_lock_redis_host + self.collision_lock_redis_port = collision_lock_redis_port + # completeness check using the cache self.check_complete = functools.partial(luigi.worker.check_complete_cached, completion_cache=task_completion_cache) @@ -148,7 +157,18 @@ def _run_task(self) -> Optional[collections.abc.Generator]: if self.task_completion_check_at_run and self.check_complete(self.task): logger.warning(f'{self.task} is skipped because the task is already completed.') return None - return self.task.run() + + if not self.collision_lock_at_run: + return self.task.run() + + assert self.collision_lock_redis_host is not None, 'collision_lock_redis_host must be set to use lock.' + assert self.collision_lock_redis_port is not None, 'collision_lock_redis_port must be set to use lock.' + + # Check required tasks are not locked by others, which means they are running it. + check_dependency_lock(task=self.task, redis_host=self.collision_lock_redis_host, redis_port=self.collision_lock_redis_port) + + # Acquire an exclusion lock while running the task to prevent collision with other jobs. + return run_with_lock(task=self.task, redis_host=self.collision_lock_redis_host, redis_port=self.collision_lock_redis_port) def _run_get_new_deps(self) -> Optional[List[Tuple[str, str, Dict[str, str]]]]: task_gen = self._run_task() @@ -379,6 +399,10 @@ class gokart_worker(luigi.Config): default=True, description='If true, tasks completeness will be re-checked just before the run, in case they are finished elsewhere.' ) + collision_lock_at_run: bool = ExplicitBoolParameter(default=False, description='If true, lock the task at run time to prevent collision with other tasks.') + collision_lock_redis_host: str | None = luigi.Parameter(default=None, description='Redis host for task collision lock.') + collision_lock_redis_port: int | None = luigi.IntParameter(default=None, description='Redis port for task collision lock.') + class Worker: """ @@ -916,6 +940,9 @@ def _create_task_process(self, task): check_complete_on_run=self._config.check_complete_on_run, task_completion_cache=self._task_completion_cache, task_completion_check_at_run=self._config.task_completion_check_at_run, + collision_lock_at_run=self._config.collision_lock_at_run, + collision_lock_redis_host=self._config.collision_lock_redis_host, + collision_lock_redis_port=self._config.collision_lock_redis_port, ) def _purge_children(self) -> None: