WIP: add lock #12

Open
wants to merge 1 commit into base: feature/rename_task_completion_check_at_run
2 changes: 2 additions & 0 deletions gokart/conflict_prevention_lock/task_lock.py
@@ -8,6 +8,8 @@

logger = getLogger(__name__)

# deprecated; task lock will be implemented in worker.py instead of in each task


class TaskLockParams(NamedTuple):
    redis_host: Optional[str]
2 changes: 2 additions & 0 deletions gokart/conflict_prevention_lock/task_lock_wrappers.py
@@ -6,6 +6,8 @@

logger = getLogger(__name__)

# deprecated; task lock will be implemented in worker.py instead of in each task


def wrap_dump_with_lock(func: Callable, task_lock_params: TaskLockParams, exist_check: Callable):
    """Redis lock wrapper function for TargetOnKart.dump().
20 changes: 20 additions & 0 deletions gokart/task_collision_lock/dependency_lock.py
@@ -0,0 +1,20 @@
from logging import getLogger

from gokart.task_collision_lock.task_lock import build_lock_key, get_task_lock
from gokart.utils import flatten

logger = getLogger(__name__)


def check_dependency_lock(task, redis_host: str, redis_port: int) -> None:
    required_tasks = flatten(task.requires())
    for required_task in required_tasks:
        _check_lock(task=required_task, redis_host=redis_host, redis_port=redis_port)


def _check_lock(task, redis_host: str, redis_port: int) -> None:
    lock_key = build_lock_key(task=task)
    task_lock = get_task_lock(redis_host=redis_host, redis_port=redis_port, lock_key=lock_key)
    logger.info(f'Task lock of {lock_key} locked.')
    task_lock.release()
    logger.info(f'Task lock of {lock_key} released.')
24 changes: 24 additions & 0 deletions gokart/task_collision_lock/run_lock.py
@@ -0,0 +1,24 @@
from logging import getLogger

from gokart.task_collision_lock.task_lock import build_lock_key, get_task_lock, set_lock_scheduler

logger = getLogger(__name__)


def run_with_lock(task, redis_host: str, redis_port: int):
    lock_key = build_lock_key(task=task)
    task_lock = get_task_lock(redis_host=redis_host, redis_port=redis_port, lock_key=lock_key)
    scheduler = set_lock_scheduler(task_lock=task_lock)

    try:
        logger.info(f'Task lock of {lock_key} locked.')
        result = task.run()
        task_lock.release()
        logger.info(f'Task lock of {lock_key} released.')
        scheduler.shutdown()
        return result
    except BaseException as e:
        logger.info(f'Task lock of {lock_key} released with BaseException.')
        task_lock.release()
        scheduler.shutdown()
        raise e
66 changes: 66 additions & 0 deletions gokart/task_collision_lock/task_lock.py
@@ -0,0 +1,66 @@
import functools
from logging import getLogger

import redis
from apscheduler.schedulers.background import BackgroundScheduler

logger = getLogger(__name__)


class TaskLockException(Exception):
    pass


class RedisClient:
    _instances: dict = {}

    def __new__(cls, *args, **kwargs):
        key = (args, tuple(sorted(kwargs.items())))
        if cls not in cls._instances:
            cls._instances[cls] = {}
        if key not in cls._instances[cls]:
            cls._instances[cls][key] = super(RedisClient, cls).__new__(cls)
        return cls._instances[cls][key]

    def __init__(self, host: str, port: int) -> None:
        if not hasattr(self, '_redis_client'):
            self._redis_client = redis.Redis(host=host, port=port)

    def get_redis_client(self):
        return self._redis_client


REDIS_TIMEOUT = 180
LOCK_EXTEND_SECONDS = 10


def get_task_lock(redis_host: str, redis_port: int, lock_key: str) -> redis.lock.Lock:
    redis_client = RedisClient(host=redis_host, port=redis_port).get_redis_client()
    task_lock = redis.lock.Lock(redis=redis_client, name=lock_key, timeout=REDIS_TIMEOUT, thread_local=False)
    if not task_lock.acquire(blocking=False):
        # If the lock is already taken by another task, raise TaskLockException immediately.
        raise TaskLockException('Lock already taken by another task.')
    return task_lock


def _extend_lock(task_lock: redis.lock.Lock, redis_timeout: int) -> None:
    task_lock.extend(additional_time=redis_timeout, replace_ttl=True)


def set_lock_scheduler(task_lock: redis.lock.Lock) -> BackgroundScheduler:
    scheduler = BackgroundScheduler()
    extend_lock = functools.partial(_extend_lock, task_lock=task_lock, redis_timeout=REDIS_TIMEOUT)
    scheduler.add_job(
        extend_lock,
        'interval',
        seconds=LOCK_EXTEND_SECONDS,
        max_instances=999999999,
        misfire_grace_time=REDIS_TIMEOUT,
        coalesce=False,
    )
    scheduler.start()
    return scheduler


def build_lock_key(task) -> str:
    return task.output().path()
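
A minimal sketch of how the helpers in this new module compose; the Redis host/port and lock key below are illustrative and not taken from this PR, and worker.py itself goes through check_dependency_lock and run_with_lock rather than calling these directly:

from gokart.task_collision_lock.task_lock import get_task_lock, set_lock_scheduler

# get_task_lock() raises TaskLockException immediately if another worker already holds the key,
# so callers never block. While the work runs, the background scheduler re-extends the lock TTL
# every LOCK_EXTEND_SECONDS (10 s), so a live holder keeps the lock for as long as needed while
# a crashed holder loses it after REDIS_TIMEOUT (180 s).
lock = get_task_lock(redis_host='localhost', redis_port=6379, lock_key='resources/my_task.pkl')
scheduler = set_lock_scheduler(task_lock=lock)
try:
    ...  # long-running work goes here
finally:
    lock.release()
    scheduler.shutdown()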
29 changes: 28 additions & 1 deletion gokart/worker.py
@@ -62,6 +62,8 @@
from luigi.task_status import RUNNING

from gokart.parameter import ExplicitBoolParameter
from gokart.task_collision_lock.dependency_lock import check_dependency_lock
from gokart.task_collision_lock.run_lock import run_with_lock

logger = logging.getLogger(__name__)

@@ -127,6 +129,9 @@ def __init__(
        check_complete_on_run: bool = False,
        task_completion_cache: Optional[Dict[str, Any]] = None,
        task_completion_check_at_run: bool = True,
        collision_lock_at_run: bool = False,
        collision_lock_redis_host: Optional[str] = None,
        collision_lock_redis_port: Optional[int] = None,
    ) -> None:
        super(TaskProcess, self).__init__()
        self.task = task
@@ -141,14 +146,29 @@ def __init__(
        self.task_completion_cache = task_completion_cache
        self.task_completion_check_at_run = task_completion_check_at_run

        self.collision_lock_at_run = collision_lock_at_run
        self.collision_lock_redis_host = collision_lock_redis_host
        self.collision_lock_redis_port = collision_lock_redis_port

        # completeness check using the cache
        self.check_complete = functools.partial(luigi.worker.check_complete_cached, completion_cache=task_completion_cache)

    def _run_task(self) -> Optional[collections.abc.Generator]:
        if self.task_completion_check_at_run and self.check_complete(self.task):
            logger.warning(f'{self.task} is skipped because the task is already completed.')
            return None
        return self.task.run()

        if not self.collision_lock_at_run:
            return self.task.run()

        assert self.collision_lock_redis_host is not None, 'collision_lock_redis_host must be set to use lock.'
        assert self.collision_lock_redis_port is not None, 'collision_lock_redis_port must be set to use lock.'

        # Check that required tasks are not locked by other workers, which would mean they are still running.
        check_dependency_lock(task=self.task, redis_host=self.collision_lock_redis_host, redis_port=self.collision_lock_redis_port)

        # Acquire an exclusive lock while running the task to prevent collisions with other jobs.
        return run_with_lock(task=self.task, redis_host=self.collision_lock_redis_host, redis_port=self.collision_lock_redis_port)

    def _run_get_new_deps(self) -> Optional[List[Tuple[str, str, Dict[str, str]]]]:
        task_gen = self._run_task()
@@ -379,6 +399,10 @@ class gokart_worker(luigi.Config):
        default=True, description='If true, tasks completeness will be re-checked just before the run, in case they are finished elsewhere.'
    )

    collision_lock_at_run: bool = ExplicitBoolParameter(default=False, description='If true, lock the task at run time to prevent collision with other tasks.')
    collision_lock_redis_host: Optional[str] = luigi.Parameter(default=None, description='Redis host for task collision lock.')
    collision_lock_redis_port: Optional[int] = luigi.IntParameter(default=None, description='Redis port for task collision lock.')


class Worker:
    """
@@ -916,6 +940,9 @@ def _create_task_process(self, task):
            check_complete_on_run=self._config.check_complete_on_run,
            task_completion_cache=self._task_completion_cache,
            task_completion_check_at_run=self._config.task_completion_check_at_run,
            collision_lock_at_run=self._config.collision_lock_at_run,
            collision_lock_redis_host=self._config.collision_lock_redis_host,
            collision_lock_redis_port=self._config.collision_lock_redis_port,
        )

    def _purge_children(self) -> None:
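
For completeness, one way the new options might be wired up end to end; the gokart_worker section name follows luigi's convention of naming config sections after the Config class, and MyTask plus the Redis connection values are illustrative, not part of this PR:

import gokart
import luigi


class MyTask(gokart.TaskOnKart):
    def run(self):
        self.dump('done')


# luigi.Config subclasses read their values from the config section named after the class,
# so the parameters added to gokart_worker above can be set programmatically before building.
config = luigi.configuration.get_config()
config.set('gokart_worker', 'collision_lock_at_run', 'true')
config.set('gokart_worker', 'collision_lock_redis_host', 'localhost')
config.set('gokart_worker', 'collision_lock_redis_port', '6379')

gokart.build(MyTask())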