Combine asset events fetching logic into one SQL query and clean up unnecessary asset-triggered dag data (apache#46721)
* refactor(dag): simplify asset_triggered_dag_info content

it was {dag_id: (min_asset_event_date, max_asset_event_date)};
min_asset_event_date is no longer needed, since asset-triggered runs no longer carry a data interval (see the sketch after these bullets)

* refactor(scheduler_job_runner): merge asset event fetching logic

* refactor(scheduler_job_runner): rename asset_triggered_dag_info to triggered_date_by_dag
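A minimal sketch of the shape change described in the first bullet (the names come from the commit; the dag_id and timestamps are invented for illustration):

```python
# Hypothetical values; only the dict shapes come from the commit.
from datetime import datetime, timedelta

now = datetime(2025, 2, 14, 12, 0)

# Before: {dag_id: (min_asset_event_date, max_asset_event_date)}
asset_triggered_dag_info: dict[str, tuple[datetime, datetime]] = {
    "example_dag": (now - timedelta(hours=1), now),
}

# After: {dag_id: triggered_date}, keeping only the newest queued time.
triggered_date_by_dag: dict[str, datetime] = {
    dag_id: pair[1] for dag_id, pair in asset_triggered_dag_info.items()
}

assert triggered_date_by_dag["example_dag"] == now
```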
Lee-W authored Feb 14, 2025
1 parent d1e617c commit 42a9492
Showing 3 changed files with 37 additions and 39 deletions.
34 changes: 15 additions & 19 deletions airflow/jobs/scheduler_job_runner.py
@@ -27,7 +27,7 @@
 from collections import Counter, defaultdict, deque
 from collections.abc import Collection, Iterable, Iterator
 from contextlib import ExitStack, suppress
-from datetime import timedelta
+from datetime import date, timedelta
 from functools import lru_cache, partial
 from itertools import groupby
 from typing import TYPE_CHECKING, Any, Callable
@@ -1198,17 +1198,17 @@ def _do_scheduling(self, session: Session) -> int:
     @retry_db_transaction
     def _create_dagruns_for_dags(self, guard: CommitProhibitorGuard, session: Session) -> None:
         """Find Dag Models needing DagRuns and Create Dag Runs with retries in case of OperationalError."""
-        query, asset_triggered_dag_info = DagModel.dags_needing_dagruns(session)
+        query, triggered_date_by_dag = DagModel.dags_needing_dagruns(session)
         all_dags_needing_dag_runs = set(query.all())
         asset_triggered_dags = [
-            dag for dag in all_dags_needing_dag_runs if dag.dag_id in asset_triggered_dag_info
+            dag for dag in all_dags_needing_dag_runs if dag.dag_id in triggered_date_by_dag
         ]
         non_asset_dags = all_dags_needing_dag_runs.difference(asset_triggered_dags)
         self._create_dag_runs(non_asset_dags, session)
         if asset_triggered_dags:
             self._create_dag_runs_asset_triggered(
                 dag_models=asset_triggered_dags,
-                asset_triggered_dag_info=asset_triggered_dag_info,
+                triggered_date_by_dag=triggered_date_by_dag,
                 session=session,
             )

@@ -1325,13 +1325,13 @@ def _create_dag_runs(self, dag_models: Collection[DagModel], session: Session) -
     def _create_dag_runs_asset_triggered(
         self,
         dag_models: Collection[DagModel],
-        asset_triggered_dag_info: dict[str, tuple[datetime, datetime]],
+        triggered_date_by_dag: dict[str, datetime],
         session: Session,
     ) -> None:
         """For DAGs that are triggered by assets, create dag runs."""
         triggered_dates: dict[str, DateTime] = {
             dag_id: timezone.coerce_datetime(last_asset_event_time)
-            for dag_id, (_, last_asset_event_time) in asset_triggered_dag_info.items()
+            for dag_id, last_asset_event_time in triggered_date_by_dag.items()
         }

         for dag_model in dag_models:
@@ -1350,30 +1350,26 @@ def _create_dag_runs_asset_triggered(
             latest_dag_version = DagVersion.get_latest_version(dag.dag_id, session=session)

             triggered_date = triggered_dates[dag.dag_id]
-            previous_dag_run = session.scalar(
-                select(DagRun)
+            cte = (
+                select(func.max(DagRun.run_after).label("previous_dag_run_run_after"))
                 .where(
                     DagRun.dag_id == dag.dag_id,
-                    DagRun.run_after < triggered_date,
                     DagRun.run_type == DagRunType.ASSET_TRIGGERED,
+                    DagRun.run_after < triggered_date,
                 )
-                .order_by(DagRun.run_after.desc())
-                .limit(1)
+                .cte()
             )
-            asset_event_filters = [
-                DagScheduleAssetReference.dag_id == dag.dag_id,
-                AssetEvent.timestamp <= triggered_date,
-            ]
-            if previous_dag_run:
-                asset_event_filters.append(AssetEvent.timestamp > previous_dag_run.run_after)
-
             asset_events = session.scalars(
                 select(AssetEvent)
                 .join(
                     DagScheduleAssetReference,
                     AssetEvent.asset_id == DagScheduleAssetReference.asset_id,
                 )
-                .where(*asset_event_filters)
+                .where(
+                    DagScheduleAssetReference.dag_id == dag.dag_id,
+                    AssetEvent.timestamp <= triggered_date,
+                    AssetEvent.timestamp > func.coalesce(cte.c.previous_dag_run_run_after, date.min),
+                )
             ).all()

             dag_run = dag.create_dagrun(
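The old code here issued one query for the previous dag run and a second for the asset events, branching in Python on whether a previous run existed. The new form folds the lookup into a CTE and lets COALESCE supply a floor date when there is none. Below is a self-contained sketch of that pattern against toy tables: Run and Event are hypothetical stand-ins for DagRun and AssetEvent, and SQLite plus the trimmed-down columns are assumptions, not Airflow's schema. The sketch also reads the CTE through a scalar subquery to keep the toy query tidy, whereas the commit references the CTE column directly in the WHERE clause.

```python
# Toy schema; Run/Event stand in for DagRun/AssetEvent (assumption, not
# Airflow's real models).
from datetime import datetime

from sqlalchemy import Column, DateTime, Integer, create_engine, func, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Run(Base):
    """One row per already-created run."""

    __tablename__ = "run"
    id = Column(Integer, primary_key=True)
    run_after = Column(DateTime)


class Event(Base):
    """One row per queued asset event."""

    __tablename__ = "event"
    id = Column(Integer, primary_key=True)
    timestamp = Column(DateTime)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all(
        [
            Run(run_after=datetime(2025, 2, 14, 9)),
            Event(timestamp=datetime(2025, 2, 14, 8)),   # covered by the 09:00 run
            Event(timestamp=datetime(2025, 2, 14, 10)),  # still pending
        ]
    )
    session.flush()

    triggered_date = datetime(2025, 2, 14, 12)

    # Latest prior run as a CTE: MAX() over an empty set yields NULL rather
    # than "no row", which is what makes the one-statement form possible.
    prev_run = (
        select(func.max(Run.run_after).label("prev_run_after"))
        .where(Run.run_after < triggered_date)
        .cte()
    )

    # COALESCE substitutes a floor when no prior run exists, so the same
    # query serves the first run and every later one; datetime.min plays
    # the role of date.min in the commit.
    events = session.scalars(
        select(Event).where(
            Event.timestamp <= triggered_date,
            Event.timestamp
            > func.coalesce(
                select(prev_run.c.prev_run_after).scalar_subquery(),
                datetime.min,
            ),
        )
    ).all()

    # Only the 10:00 event is newer than the previous run.
    assert [e.timestamp for e in events] == [datetime(2025, 2, 14, 10)]
```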
31 changes: 17 additions & 14 deletions airflow/models/dag.py
@@ -2317,7 +2317,7 @@ def deactivate_deleted_dags(
             dm.is_active = False

     @classmethod
-    def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, tuple[datetime, datetime]]]:
+    def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, datetime]]:
         """
         Return (and lock) a list of Dag objects that are due to create a new DagRun.
@@ -2341,26 +2341,29 @@ def dag_ready(dag_id: str, cond: BaseAsset, statuses: dict[AssetUniqueKey, bool]
         adrq_by_dag: dict[str, list[AssetDagRunQueue]] = defaultdict(list)
         for r in session.scalars(select(AssetDagRunQueue)):
             adrq_by_dag[r.target_dag_id].append(r)
-        dag_statuses: dict[str, dict[AssetUniqueKey, bool]] = {}
-        for dag_id, records in adrq_by_dag.items():
-            dag_statuses[dag_id] = {AssetUniqueKey.from_asset(x.asset): True for x in records}
-        ser_dags = SerializedDagModel.get_latest_serialized_dags(dag_ids=list(dag_statuses), session=session)
+
+        dag_statuses: dict[str, dict[AssetUniqueKey, bool]] = {
+            dag_id: {AssetUniqueKey.from_asset(adrq.asset): True for adrq in adrqs}
+            for dag_id, adrqs in adrq_by_dag.items()
+        }
+        ser_dags = SerializedDagModel.get_latest_serialized_dags(dag_ids=list(dag_statuses), session=session)
         for ser_dag in ser_dags:
             dag_id = ser_dag.dag_id
             statuses = dag_statuses[dag_id]
             if not dag_ready(dag_id, cond=ser_dag.dag.timetable.asset_condition, statuses=statuses):
                 del adrq_by_dag[dag_id]
                 del dag_statuses[dag_id]
         del dag_statuses

-        # TODO: make it more readable (rename it or make it attrs, dataclass or etc.)
-        asset_triggered_dag_info: dict[str, tuple[datetime, datetime]] = {}
-        for dag_id, records in adrq_by_dag.items():
-            times = sorted(x.created_at for x in records)
-            asset_triggered_dag_info[dag_id] = (times[0], times[-1])
+        # triggered dates for asset triggered dags
+        triggered_date_by_dag: dict[str, datetime] = {
+            dag_id: max(adrq.created_at for adrq in adrqs) for dag_id, adrqs in adrq_by_dag.items()
+        }
         del adrq_by_dag
-        asset_triggered_dag_ids = set(asset_triggered_dag_info.keys())
+
+        asset_triggered_dag_ids = set(triggered_date_by_dag.keys())
         if asset_triggered_dag_ids:
+            # exclude as max active runs has been reached
             exclusion_list = set(
                 session.scalars(
                     select(DagModel.dag_id)
@@ -2373,8 +2376,8 @@ def dag_ready(dag_id: str, cond: BaseAsset, statuses: dict[AssetUniqueKey, bool]
             )
             if exclusion_list:
                 asset_triggered_dag_ids -= exclusion_list
-                asset_triggered_dag_info = {
-                    k: v for k, v in asset_triggered_dag_info.items() if k not in exclusion_list
+                triggered_date_by_dag = {
+                    k: v for k, v in triggered_date_by_dag.items() if k not in exclusion_list
                 }

         # We limit so that _one_ scheduler doesn't try to do all the creation of dag runs
@@ -2395,7 +2398,7 @@ def dag_ready(dag_id: str, cond: BaseAsset, statuses: dict[AssetUniqueKey, bool]

         return (
             session.scalars(with_row_locks(query, of=cls, session=session, skip_locked=True)),
-            asset_triggered_dag_info,
+            triggered_date_by_dag,
         )

     def calculate_dagrun_date_fields(
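For illustration, a toy before/after of the triggered-date derivation in dags_needing_dagruns. Adrq is a hypothetical stand-in for AssetDagRunQueue, and the records are invented; only the two derivations mirror the diff:

```python
# Adrq is a hypothetical stand-in for AssetDagRunQueue; records are invented.
from dataclasses import dataclass
from datetime import datetime, timedelta


@dataclass
class Adrq:
    created_at: datetime


base = datetime(2025, 2, 14)
adrq_by_dag = {"example_dag": [Adrq(base + timedelta(hours=h)) for h in (2, 0, 1)]}

# Before: sort each dag's records just to read off (first, last).
asset_triggered_dag_info = {}
for dag_id, records in adrq_by_dag.items():
    times = sorted(adrq.created_at for adrq in records)
    asset_triggered_dag_info[dag_id] = (times[0], times[-1])

# After: only the newest queued time is consumed downstream, so a single
# O(n) max() pass replaces the sort and the unused minimum disappears.
triggered_date_by_dag = {
    dag_id: max(adrq.created_at for adrq in adrqs) for dag_id, adrqs in adrq_by_dag.items()
}

assert asset_triggered_dag_info["example_dag"][1] == triggered_date_by_dag["example_dag"]
```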
11 changes: 5 additions & 6 deletions tests/models/test_dag.py
@@ -2377,7 +2377,7 @@ def test__processor_dags_folder(self, session, testing_dag_bundle):
         assert sdm.dag._processor_dags_folder == settings.DAGS_FOLDER

     @pytest.mark.need_serialized_dag
-    def test_dags_needing_dagruns_asset_triggered_dag_info_queued_times(self, session, dag_maker):
+    def test_dags_needing_dagruns_triggered_date_by_dag_queued_times(self, session, dag_maker):
         asset1 = Asset(uri="test://asset1", group="test-group")
         asset2 = Asset(uri="test://asset2", name="test_asset_2", group="test-group")
@@ -2417,11 +2417,10 @@ def test_dags_needing_dagruns_asset_triggered_dag_info_queued_times(self, sessio
         )
         session.flush()

-        query, asset_triggered_dag_info = DagModel.dags_needing_dagruns(session)
-        assert len(asset_triggered_dag_info) == 1
-        assert dag.dag_id in asset_triggered_dag_info
-        first_queued_time, last_queued_time = asset_triggered_dag_info[dag.dag_id]
-        assert first_queued_time == DEFAULT_DATE
+        query, triggered_date_by_dag = DagModel.dags_needing_dagruns(session)
+        assert len(triggered_date_by_dag) == 1
+        assert dag.dag_id in triggered_date_by_dag
+        last_queued_time = triggered_date_by_dag[dag.dag_id]
         assert last_queued_time == DEFAULT_DATE + timedelta(hours=1)

     def test_asset_expression(self, testing_dag_bundle, session: Session) -> None:
