diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py index 85938aa5b9433..4aa684e70dfdb 100644 --- a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py @@ -125,9 +125,6 @@ def run_in_thread(f: _F) -> _F: @functools.wraps(f) def wrapper(*args, **kwargs): try: - if Variable.get(KILL_SWITCH_VARIABLE_NAME, "false").lower() == "true": - logger.info("DataHub listener disabled by kill switch") - return cast(_F, wrapper) if _RUN_IN_THREAD: # A poor-man's timeout mechanism. # This ensures that we don't hang the task if the extractors @@ -370,6 +367,15 @@ def _extract_lineage( redact_with_exclusions(v) ) + def check_kill_switch(self): + try: + if Variable.get(KILL_SWITCH_VARIABLE_NAME, "false").lower() == "true": + logger.info("DataHub listener disabled by kill switch") + return True + except Exception as e: + raise e + return False + @hookimpl @run_in_thread def on_task_instance_running( @@ -378,6 +384,8 @@ def on_task_instance_running( task_instance: "TaskInstance", session: "Session", # This will always be QUEUED ) -> None: + if self.check_kill_switch(): + return self._set_log_level() # This if statement mirrors the logic in https://github.com/OpenLineage/OpenLineage/pull/508. @@ -488,6 +496,10 @@ def on_task_instance_running( def on_task_instance_finish( self, task_instance: "TaskInstance", status: InstanceRunResult ) -> None: + if self.check_kill_switch(): + return + self._set_log_level() + dagrun: "DagRun" = task_instance.dag_run # type: ignore[attr-defined] if self.config.render_templates: @@ -547,6 +559,9 @@ def on_task_instance_finish( def on_task_instance_success( self, previous_state: None, task_instance: "TaskInstance", session: "Session" ) -> None: + if self.check_kill_switch(): + return + self._set_log_level() logger.debug( @@ -562,6 +577,9 @@ def on_task_instance_success( def on_task_instance_failed( self, previous_state: None, task_instance: "TaskInstance", session: "Session" ) -> None: + if self.check_kill_switch(): + return + self._set_log_level() logger.debug(