Skip to content

Commit

Permalink
change listener API, add basic support for task instance listeners in…
Browse files Browse the repository at this point in the history
… TaskSDK, make OpenLineage provider support Airflow 3's listener interface (apache#45294)

Signed-off-by: Maciej Obuchowski <[email protected]>
  • Loading branch information
mobuchowski authored Feb 11, 2025
1 parent f4dfbf2 commit 0047a68
Show file tree
Hide file tree
Showing 43 changed files with 1,925 additions and 803 deletions.
4 changes: 4 additions & 0 deletions airflow/api_fastapi/execution_api/datamodels/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ class DagRun(StrictBaseModel):
run_after: UtcDateTime
start_date: UtcDateTime
end_date: UtcDateTime | None
clear_number: int
run_type: DagRunType
conf: Annotated[dict[str, Any], Field(default_factory=dict)]
external_trigger: bool = False
Expand All @@ -238,6 +239,9 @@ class TIRunContext(BaseModel):
dag_run: DagRun
"""DAG run information for the task instance."""

task_reschedule_count: Annotated[int, Field(default=0)]
"""How many times the task has been rescheduled."""

max_tries: int
"""Maximum number of tries for the task instance (from DB)."""

Expand Down
36 changes: 31 additions & 5 deletions airflow/api_fastapi/execution_api/routes/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from fastapi import Body, HTTPException, status
from pydantic import JsonValue
from sqlalchemy import update
from sqlalchemy import func, update
from sqlalchemy.exc import NoResultFound, SQLAlchemyError
from sqlalchemy.sql import select

Expand Down Expand Up @@ -79,14 +79,23 @@ def ti_run(
ti_id_str = str(task_instance_id)

old = (
select(TI.state, TI.dag_id, TI.run_id, TI.task_id, TI.map_index, TI.next_method, TI.max_tries)
select(
TI.state,
TI.dag_id,
TI.run_id,
TI.task_id,
TI.map_index,
TI.next_method,
TI.try_number,
TI.max_tries,
)
.where(TI.id == ti_id_str)
.with_for_update()
)
try:
(previous_state, dag_id, run_id, task_id, map_index, next_method, max_tries) = session.execute(
old
).one()
(previous_state, dag_id, run_id, task_id, map_index, next_method, try_number, max_tries) = (
session.execute(old).one()
)
except NoResultFound:
log.error("Task Instance %s not found", ti_id_str)
raise HTTPException(
Expand Down Expand Up @@ -147,6 +156,7 @@ def ti_run(
DR.run_after,
DR.start_date,
DR.end_date,
DR.clear_number,
DR.run_type,
DR.conf,
DR.logical_date,
Expand All @@ -171,8 +181,24 @@ def ti_run(
session=session,
)

# Count the TaskReschedule rows recorded for the current try of this task instance.
# The filter must use the ``task_id`` read from the TI row above: ``ti_id_str`` is the
# TI primary key (it is what ``TI.id`` is matched against), so comparing it to
# ``TaskReschedule.task_id`` would not select this task's rows and the count would
# fall through to 0.
task_reschedule_count = (
    session.query(
        func.count(TaskReschedule.id)  # or any other primary key column
    )
    .filter(
        TaskReschedule.dag_id == dag_id,
        TaskReschedule.task_id == task_id,
        TaskReschedule.run_id == run_id,
        # TaskReschedule.map_index == ti.map_index,  # TODO: Handle mapped tasks
        TaskReschedule.try_number == try_number,
    )
    # ``scalar()`` may return None; normalise that to 0 so the API model gets an int.
    .scalar()
    or 0
)

return TIRunContext(
dag_run=dr,
task_reschedule_count=task_reschedule_count,
max_tries=max_tries,
# TODO: Add variables and connections that are needed (and has perms) for the task
variables=[],
Expand Down
44 changes: 15 additions & 29 deletions airflow/example_dags/plugins/event_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@

if TYPE_CHECKING:
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
from airflow.utils.state import TaskInstanceState


# [START howto_listen_ti_running_task]
@hookimpl
def on_task_instance_running(previous_state: TaskInstanceState, task_instance: TaskInstance, session):
def on_task_instance_running(previous_state: TaskInstanceState, task_instance: RuntimeTaskInstance):
"""
This method is called when task state changes to RUNNING.
Through the callback, parameters such as previous_state and the task_instance object can be accessed.
Expand All @@ -39,14 +39,11 @@ def on_task_instance_running(previous_state: TaskInstanceState, task_instance: T
print("Task instance is in running state")
print(" Previous state of the Task instance:", previous_state)

state: TaskInstanceState = task_instance.state
name: str = task_instance.task_id
start_date = task_instance.start_date

dagrun = task_instance.dag_run
dagrun_status = dagrun.state
context = task_instance.get_template_context()

task = task_instance.task
task = context["task"]

if TYPE_CHECKING:
assert task
Expand All @@ -55,16 +52,16 @@ def on_task_instance_running(previous_state: TaskInstanceState, task_instance: T
dag_name = None
if dag:
dag_name = dag.dag_id
print(f"Current task name:{name} state:{state} start_date:{start_date}")
print(f"Dag name:{dag_name} and current dag run status:{dagrun_status}")
print(f"Current task name:{name}")
print(f"Dag name:{dag_name}")


# [END howto_listen_ti_running_task]


# [START howto_listen_ti_success_task]
@hookimpl
def on_task_instance_success(previous_state: TaskInstanceState, task_instance: TaskInstance, session):
def on_task_instance_success(previous_state: TaskInstanceState, task_instance: RuntimeTaskInstance):
"""
This method is called when task state changes to SUCCESS.
Through the callback, parameters such as previous_state and the task_instance object can be accessed.
Expand All @@ -74,14 +71,10 @@ def on_task_instance_success(previous_state: TaskInstanceState, task_instance: T
print("Task instance in success state")
print(" Previous state of the Task instance:", previous_state)

dag_id = task_instance.dag_id
hostname = task_instance.hostname
operator = task_instance.operator
context = task_instance.get_template_context()
operator = context["task"]

dagrun = task_instance.dag_run
queued_at = dagrun.queued_at
print(f"Dag name:{dag_id} queued_at:{queued_at}")
print(f"Task hostname:{hostname} operator:{operator}")
print(f"Task operator:{operator}")


# [END howto_listen_ti_success_task]
Expand All @@ -90,7 +83,7 @@ def on_task_instance_success(previous_state: TaskInstanceState, task_instance: T
# [START howto_listen_ti_failure_task]
@hookimpl
def on_task_instance_failed(
previous_state: TaskInstanceState, task_instance: TaskInstance, error: None | str | BaseException, session
previous_state: TaskInstanceState, task_instance: RuntimeTaskInstance, error: None | str | BaseException
):
"""
This method is called when task state changes to FAILED.
Expand All @@ -100,21 +93,14 @@ def on_task_instance_failed(
"""
print("Task instance in failure state")

start_date = task_instance.start_date
end_date = task_instance.end_date
duration = task_instance.duration

dagrun = task_instance.dag_run

task = task_instance.task
context = task_instance.get_template_context()
task = context["task"]

if TYPE_CHECKING:
assert task

dag = task.dag

print(f"Task start:{start_date} end:{end_date} duration:{duration}")
print(f"Task:{task} dag:{dag} dagrun:{dagrun}")
print("Task start")
print(f"Task:{task}")
if error:
print(f"Failure caused by {error}")

Expand Down
8 changes: 7 additions & 1 deletion airflow/listeners/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,13 @@ class ListenerManager:
"""Manage listener registration and provides hook property for calling them."""

def __init__(self):
from airflow.listeners.spec import asset, dagrun, importerrors, lifecycle, taskinstance
from airflow.listeners.spec import (
asset,
dagrun,
importerrors,
lifecycle,
taskinstance,
)

self.pm = pluggy.PluginManager("airflow")
self.pm.add_hookcall_monitoring(_before_hookcall, _after_hookcall)
Expand Down
15 changes: 4 additions & 11 deletions airflow/listeners/spec/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,33 +22,26 @@
from pluggy import HookspecMarker

if TYPE_CHECKING:
from sqlalchemy.orm.session import Session

from airflow.models.taskinstance import TaskInstance
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
from airflow.utils.state import TaskInstanceState

hookspec = HookspecMarker("airflow")


@hookspec
def on_task_instance_running(
previous_state: TaskInstanceState | None, task_instance: TaskInstance, session: Session | None
):
def on_task_instance_running(previous_state: TaskInstanceState | None, task_instance: RuntimeTaskInstance):
"""Execute when task state changes to RUNNING. previous_state can be None."""


@hookspec
def on_task_instance_success(
previous_state: TaskInstanceState | None, task_instance: TaskInstance, session: Session | None
):
def on_task_instance_success(previous_state: TaskInstanceState | None, task_instance: RuntimeTaskInstance):
"""Execute when task state changes to SUCCESS. previous_state can be None."""


@hookspec
def on_task_instance_failed(
previous_state: TaskInstanceState | None,
task_instance: TaskInstance,
task_instance: RuntimeTaskInstance,
error: None | str | BaseException,
session: Session | None,
):
"""Execute when task state changes to FAIL. previous_state can be None."""
7 changes: 4 additions & 3 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def _run_raw_task(
TaskInstance.save_to_db(ti=ti, session=session)
if ti.state == TaskInstanceState.SUCCESS:
get_listener_manager().hook.on_task_instance_success(
previous_state=TaskInstanceState.RUNNING, task_instance=ti, session=session
previous_state=TaskInstanceState.RUNNING, task_instance=ti
)

return None
Expand Down Expand Up @@ -1873,6 +1873,7 @@ def to_runtime_ti(self, context_from_server) -> RuntimeTaskInstanceProtocol:
max_tries=self.max_tries,
hostname=self.hostname,
_ti_context_from_server=context_from_server,
start_date=self.start_date,
)

return runtime_ti
Expand Down Expand Up @@ -2895,7 +2896,7 @@ def signal_handler(signum, frame):

# Run on_task_instance_running event
get_listener_manager().hook.on_task_instance_running(
previous_state=TaskInstanceState.QUEUED, task_instance=self, session=session
previous_state=TaskInstanceState.QUEUED, task_instance=self
)

def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None) -> str | None:
Expand Down Expand Up @@ -3137,7 +3138,7 @@ def fetch_handle_failure_context(
callbacks = task.on_retry_callback if task else None

get_listener_manager().hook.on_task_instance_failed(
previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error, session=session
previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error
)

return {
Expand Down
2 changes: 2 additions & 0 deletions airflow/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@
"prev_end_date_success",
"reason",
"run_id",
"start_date",
"task",
"task_reschedule_count",
"task_instance",
"task_instance_key_str",
"test_mode",
Expand Down
20 changes: 11 additions & 9 deletions docs/apache-airflow/administration-and-deployment/listeners.rst
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,7 @@ For example if you want to implement a listener that uses the ``error`` field in
...
@hookimpl
def on_task_instance_failed(
self, previous_state, task_instance, error: None | str | BaseException, session
):
def on_task_instance_failed(self, previous_state, task_instance, error: None | str | BaseException):
# Handle error case here
pass
Expand All @@ -177,15 +175,19 @@ For example if you want to implement a listener that uses the ``error`` field in
...
@hookimpl
def on_task_instance_failed(self, previous_state, task_instance, session):
def on_task_instance_failed(self, previous_state, task_instance):
# Handle no error case here
pass
List of changes in the listener interfaces since 2.8.0 when they were introduced:


+-----------------+-----------------------------+---------------------------------------+
| Airflow Version | Affected method | Change |
+=================+=============================+=======================================+
| 2.10.0 | ``on_task_instance_failed`` | An error field added to the interface |
+-----------------+-----------------------------+---------------------------------------+
+-----------------+--------------------------------------------+-------------------------------------------------------------------------+
| Airflow Version | Affected method | Change |
+=================+============================================+=========================================================================+
| 2.10.0 | ``on_task_instance_failed`` | An error field added to the interface |
+-----------------+--------------------------------------------+-------------------------------------------------------------------------+
| 3.0.0 | ``on_task_instance_running``, | ``session`` argument removed from task instance listeners, |
| | ``on_task_instance_success``, | ``task_instance`` object is now an instance of ``RuntimeTaskInstance`` |
| | ``on_task_instance_failed`` | |
+-----------------+--------------------------------------------+-------------------------------------------------------------------------+
2 changes: 2 additions & 0 deletions docs/apache-airflow/templates-ref.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,15 @@ Variable Type Description
| ``None``
``{{ prev_end_date_success }}`` `pendulum.DateTime`_ End date from prior successful :class:`~airflow.models.dagrun.DagRun` (if available).
| ``None``
``{{ start_date }}`` `pendulum.DateTime`_ Datetime when the current task was started.
``{{ inlets }}`` list List of inlets declared on the task.
``{{ inlet_events }}`` dict[str, ...] Access past events of inlet assets. See :doc:`Assets <authoring-and-scheduling/datasets>`. Added in version 2.10.
``{{ outlets }}`` list List of outlets declared on the task.
``{{ outlet_events }}`` dict[str, ...] | Accessors to attach information to asset events that will be emitted by the current task.
| See :doc:`Assets <authoring-and-scheduling/datasets>`. Added in version 2.10.
``{{ dag }}`` DAG The currently running :class:`~airflow.models.dag.DAG`. You can read more about DAGs in :doc:`DAGs <core-concepts/dags>`.
``{{ task }}`` BaseOperator | The currently running :class:`~airflow.models.baseoperator.BaseOperator`. You can read more about Tasks in :doc:`core-concepts/operators`
``{{ task_reschedule_count }}`` int How many times the current task has been rescheduled. Relevant to ``mode="reschedule"`` sensors.
``{{ macros }}`` | A reference to the macros package. See Macros_ below.
``{{ task_instance }}`` TaskInstance The currently running :class:`~airflow.models.taskinstance.TaskInstance`.
``{{ ti }}`` TaskInstance Same as ``{{ task_instance }}``.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@
"pool_slots": 1,
"queue": "default",
"priority_weight": 1,
"start_date": "2023-01-01T00:00:00+00:00",
"map_index": -1,
},
"dag_rel_path": "mock.py",
"log_path": "mock.log",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ def test_queue_workload(self):
pool_slots=1,
queue="default",
priority_weight=1,
start_date=timezone.utcnow(),
),
dag_rel_path="mock.py",
log_path="mock.log",
Expand Down
Loading

0 comments on commit 0047a68

Please sign in to comment.