Skip to content

Commit

Permalink
Move out kill switch from in thread as it breaks the transaction
Browse files Browse the repository at this point in the history
  • Loading branch information
treff7es committed Dec 12, 2024
1 parent 27791f4 commit 5423fa3
Showing 1 changed file with 21 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,6 @@ def run_in_thread(f: _F) -> _F:
@functools.wraps(f)
def wrapper(*args, **kwargs):
try:
if Variable.get(KILL_SWITCH_VARIABLE_NAME, "false").lower() == "true":
logger.info("DataHub listener disabled by kill switch")
return cast(_F, wrapper)
if _RUN_IN_THREAD:
# A poor-man's timeout mechanism.
# This ensures that we don't hang the task if the extractors
Expand Down Expand Up @@ -370,6 +367,15 @@ def _extract_lineage(
redact_with_exclusions(v)
)

def check_kill_switch(self):
try:
if Variable.get(KILL_SWITCH_VARIABLE_NAME, "false").lower() == "true":
logger.info("DataHub listener disabled by kill switch")
return True
except Exception as e:
raise e
return False

@hookimpl
@run_in_thread
def on_task_instance_running(
Expand All @@ -378,6 +384,8 @@ def on_task_instance_running(
task_instance: "TaskInstance",
session: "Session", # This will always be QUEUED
) -> None:
if self.check_kill_switch():
return
self._set_log_level()

# This if statement mirrors the logic in https://github.com/OpenLineage/OpenLineage/pull/508.
Expand Down Expand Up @@ -488,6 +496,10 @@ def on_task_instance_running(
def on_task_instance_finish(
self, task_instance: "TaskInstance", status: InstanceRunResult
) -> None:
if self.check_kill_switch():
return
self._set_log_level()

dagrun: "DagRun" = task_instance.dag_run # type: ignore[attr-defined]

if self.config.render_templates:
Expand Down Expand Up @@ -547,6 +559,9 @@ def on_task_instance_finish(
def on_task_instance_success(
self, previous_state: None, task_instance: "TaskInstance", session: "Session"
) -> None:
if self.check_kill_switch():
return

self._set_log_level()

logger.debug(
Expand All @@ -562,6 +577,9 @@ def on_task_instance_success(
def on_task_instance_failed(
self, previous_state: None, task_instance: "TaskInstance", session: "Session"
) -> None:
if self.check_kill_switch():
return

self._set_log_level()

logger.debug(
Expand Down

0 comments on commit 5423fa3

Please sign in to comment.