diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py
index 4d81e464e68d3..e35e8573a6dd6 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -27,7 +27,7 @@
 from collections import Counter, defaultdict, deque
 from collections.abc import Collection, Iterable, Iterator
 from contextlib import ExitStack, suppress
-from datetime import timedelta
+from datetime import date, timedelta
 from functools import lru_cache, partial
 from itertools import groupby
 from typing import TYPE_CHECKING, Any, Callable
@@ -1198,17 +1198,17 @@ def _do_scheduling(self, session: Session) -> int:
     @retry_db_transaction
     def _create_dagruns_for_dags(self, guard: CommitProhibitorGuard, session: Session) -> None:
         """Find Dag Models needing DagRuns and Create Dag Runs with retries in case of OperationalError."""
-        query, asset_triggered_dag_info = DagModel.dags_needing_dagruns(session)
+        query, triggered_date_by_dag = DagModel.dags_needing_dagruns(session)
         all_dags_needing_dag_runs = set(query.all())
         asset_triggered_dags = [
-            dag for dag in all_dags_needing_dag_runs if dag.dag_id in asset_triggered_dag_info
+            dag for dag in all_dags_needing_dag_runs if dag.dag_id in triggered_date_by_dag
         ]
         non_asset_dags = all_dags_needing_dag_runs.difference(asset_triggered_dags)
         self._create_dag_runs(non_asset_dags, session)
         if asset_triggered_dags:
             self._create_dag_runs_asset_triggered(
                 dag_models=asset_triggered_dags,
-                asset_triggered_dag_info=asset_triggered_dag_info,
+                triggered_date_by_dag=triggered_date_by_dag,
                 session=session,
             )
@@ -1325,13 +1325,13 @@ def _create_dag_runs(self, dag_models: Collection[DagModel], session: Session) -
     def _create_dag_runs_asset_triggered(
         self,
         dag_models: Collection[DagModel],
-        asset_triggered_dag_info: dict[str, tuple[datetime, datetime]],
+        triggered_date_by_dag: dict[str, datetime],
         session: Session,
     ) -> None:
         """For DAGs that are triggered by assets, create dag runs."""
         triggered_dates: dict[str, DateTime] = {
             dag_id: timezone.coerce_datetime(last_asset_event_time)
-            for dag_id, (_, last_asset_event_time) in asset_triggered_dag_info.items()
+            for dag_id, last_asset_event_time in triggered_date_by_dag.items()
         }
 
         for dag_model in dag_models:
@@ -1350,30 +1350,26 @@ def _create_dag_runs_asset_triggered(
             latest_dag_version = DagVersion.get_latest_version(dag.dag_id, session=session)
 
             triggered_date = triggered_dates[dag.dag_id]
-            previous_dag_run = session.scalar(
-                select(DagRun)
+            cte = (
+                select(func.max(DagRun.run_after).label("previous_dag_run_run_after"))
                 .where(
                     DagRun.dag_id == dag.dag_id,
-                    DagRun.run_after < triggered_date,
                     DagRun.run_type == DagRunType.ASSET_TRIGGERED,
+                    DagRun.run_after < triggered_date,
                 )
-                .order_by(DagRun.run_after.desc())
-                .limit(1)
+                .cte()
             )
-            asset_event_filters = [
-                DagScheduleAssetReference.dag_id == dag.dag_id,
-                AssetEvent.timestamp <= triggered_date,
-            ]
-            if previous_dag_run:
-                asset_event_filters.append(AssetEvent.timestamp > previous_dag_run.run_after)
-
             asset_events = session.scalars(
                 select(AssetEvent)
                 .join(
                     DagScheduleAssetReference,
                     AssetEvent.asset_id == DagScheduleAssetReference.asset_id,
                 )
-                .where(*asset_event_filters)
+                .where(
+                    DagScheduleAssetReference.dag_id == dag.dag_id,
+                    AssetEvent.timestamp <= triggered_date,
+                    AssetEvent.timestamp > func.coalesce(cte.c.previous_dag_run_run_after, date.min),
+                )
             ).all()
 
             dag_run = dag.create_dagrun(
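Note on the hunk above: folding the previous-run lookup into a one-row CTE and bounding the event window with `func.coalesce(..., date.min)` replaces two round trips (and the `if previous_dag_run:` branch) with a single SELECT. Below is a minimal standalone sketch of that query pattern under stated assumptions: the `runs`/`events` tables and column names are throwaway stand-ins for `DagRun`/`AssetEvent`, not Airflow's real models.

```python
from datetime import date, datetime

from sqlalchemy import (
    Column, DateTime, Integer, MetaData, String, Table, create_engine, func, select,
)

metadata = MetaData()
# Illustrative stand-ins for DagRun and AssetEvent.
runs = Table(
    "runs",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("dag_id", String),
    Column("run_after", DateTime),
)
events = Table(
    "events",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("dag_id", String),
    Column("timestamp", DateTime),
)

engine = create_engine("sqlite://")
metadata.create_all(engine)
with engine.begin() as conn:
    conn.execute(
        events.insert(),
        [
            {"dag_id": "d1", "timestamp": datetime(2025, 1, 1)},
            {"dag_id": "d1", "timestamp": datetime(2025, 1, 3)},
        ],
    )

triggered_date = datetime(2025, 1, 2)

# One-row CTE holding the latest previous run (NULL when no run exists yet).
prev = (
    select(func.max(runs.c.run_after).label("previous_run_after"))
    .where(runs.c.dag_id == "d1", runs.c.run_after < triggered_date)
    .cte()
)

# coalesce() turns the missing-previous-run case into a lower bound of
# date.min, so the "first run ever" and "later run" paths share one query.
# Referencing prev.c in the WHERE clause cross-joins the one-row CTE, the
# same implicit-join shape the hunk uses (SQLAlchemy's FROM linter may warn).
stmt = select(events).where(
    events.c.dag_id == "d1",
    events.c.timestamp <= triggered_date,
    events.c.timestamp > func.coalesce(prev.c.previous_run_after, date.min),
)

with engine.connect() as conn:
    print(conn.execute(stmt).all())  # only the 2025-01-01 event qualifies
```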
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 37ff854c8da33..cff6c857c0376 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -2317,7 +2317,7 @@ def deactivate_deleted_dags(
             dm.is_active = False
 
     @classmethod
-    def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, tuple[datetime, datetime]]]:
+    def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, datetime]]:
         """
         Return (and lock) a list of Dag objects that are due to create a new DagRun.
 
@@ -2341,11 +2341,12 @@ def dag_ready(dag_id: str, cond: BaseAsset, statuses: dict[AssetUniqueKey, bool]
         adrq_by_dag: dict[str, list[AssetDagRunQueue]] = defaultdict(list)
         for r in session.scalars(select(AssetDagRunQueue)):
             adrq_by_dag[r.target_dag_id].append(r)
-        dag_statuses: dict[str, dict[AssetUniqueKey, bool]] = {}
-        for dag_id, records in adrq_by_dag.items():
-            dag_statuses[dag_id] = {AssetUniqueKey.from_asset(x.asset): True for x in records}
-        ser_dags = SerializedDagModel.get_latest_serialized_dags(dag_ids=list(dag_statuses), session=session)
+        dag_statuses: dict[str, dict[AssetUniqueKey, bool]] = {
+            dag_id: {AssetUniqueKey.from_asset(adrq.asset): True for adrq in adrqs}
+            for dag_id, adrqs in adrq_by_dag.items()
+        }
+        ser_dags = SerializedDagModel.get_latest_serialized_dags(dag_ids=list(dag_statuses), session=session)
 
         for ser_dag in ser_dags:
             dag_id = ser_dag.dag_id
             statuses = dag_statuses[dag_id]
@@ -2353,14 +2354,16 @@ def dag_ready(dag_id: str, cond: BaseAsset, statuses: dict[AssetUniqueKey, bool]
                 del adrq_by_dag[dag_id]
                 del dag_statuses[dag_id]
         del dag_statuses
-        # TODO: make it more readable (rename it or make it attrs, dataclass or etc.)
-        asset_triggered_dag_info: dict[str, tuple[datetime, datetime]] = {}
-        for dag_id, records in adrq_by_dag.items():
-            times = sorted(x.created_at for x in records)
-            asset_triggered_dag_info[dag_id] = (times[0], times[-1])
+
+        # triggered dates for asset triggered dags
+        triggered_date_by_dag: dict[str, datetime] = {
+            dag_id: max(adrq.created_at for adrq in adrqs) for dag_id, adrqs in adrq_by_dag.items()
+        }
         del adrq_by_dag
-        asset_triggered_dag_ids = set(asset_triggered_dag_info.keys())
+
+        asset_triggered_dag_ids = set(triggered_date_by_dag.keys())
         if asset_triggered_dag_ids:
+            # exclude as max active runs has been reached
             exclusion_list = set(
                 session.scalars(
                     select(DagModel.dag_id)
@@ -2373,8 +2376,8 @@ def dag_ready(dag_id: str, cond: BaseAsset, statuses: dict[AssetUniqueKey, bool]
                 )
             )
             if exclusion_list:
                 asset_triggered_dag_ids -= exclusion_list
-                asset_triggered_dag_info = {
-                    k: v for k, v in asset_triggered_dag_info.items() if k not in exclusion_list
+                triggered_date_by_dag = {
+                    k: v for k, v in triggered_date_by_dag.items() if k not in exclusion_list
                 }
 
         # We limit so that _one_ scheduler doesn't try to do all the creation of dag runs
@@ -2395,7 +2398,7 @@ def dag_ready(dag_id: str, cond: BaseAsset, statuses: dict[AssetUniqueKey, bool]
 
         return (
             session.scalars(with_row_locks(query, of=cls, session=session, skip_locked=True)),
-            asset_triggered_dag_info,
+            triggered_date_by_dag,
        )
 
     def calculate_dagrun_date_fields(
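For the `dags_needing_dagruns` hunks: the old `(first, last)` tuple was only ever consumed for its last element, so the sorted-tuple bookkeeping collapses to one `max()` per dag. A pure-Python sketch of that equivalence, with a made-up `AdrqStub` standing in for `AssetDagRunQueue` rows:

```python
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timedelta


@dataclass
class AdrqStub:
    """Minimal stand-in for an AssetDagRunQueue row (illustrative only)."""

    target_dag_id: str
    created_at: datetime


rows = [
    AdrqStub("dag_a", datetime(2025, 1, 1)),
    AdrqStub("dag_a", datetime(2025, 1, 1) + timedelta(hours=1)),
    AdrqStub("dag_b", datetime(2025, 1, 2)),
]

# Same grouping step the method performs.
adrq_by_dag: dict[str, list[AdrqStub]] = defaultdict(list)
for r in rows:
    adrq_by_dag[r.target_dag_id].append(r)

# Old shape: (first, last) tuple built from a full sort per dag.
old_info = {
    dag_id: (min(a.created_at for a in adrqs), max(a.created_at for a in adrqs))
    for dag_id, adrqs in adrq_by_dag.items()
}
# New shape: only the latest queued time is ever read, so max() suffices.
triggered_date_by_dag = {
    dag_id: max(adrq.created_at for adrq in adrqs) for dag_id, adrqs in adrq_by_dag.items()
}
assert all(triggered_date_by_dag[d] == old_info[d][1] for d in old_info)
```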
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index aee56bee883e3..a5330b174039b 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -2377,7 +2377,7 @@ def test__processor_dags_folder(self, session, testing_dag_bundle):
         assert sdm.dag._processor_dags_folder == settings.DAGS_FOLDER
 
     @pytest.mark.need_serialized_dag
-    def test_dags_needing_dagruns_asset_triggered_dag_info_queued_times(self, session, dag_maker):
+    def test_dags_needing_dagruns_triggered_date_by_dag_queued_times(self, session, dag_maker):
         asset1 = Asset(uri="test://asset1", group="test-group")
         asset2 = Asset(uri="test://asset2", name="test_asset_2", group="test-group")
@@ -2417,11 +2417,10 @@ def test_dags_needing_dagruns_asset_triggered_dag_info_queued_times(self, sessio
         )
         session.flush()
 
-        query, asset_triggered_dag_info = DagModel.dags_needing_dagruns(session)
-        assert len(asset_triggered_dag_info) == 1
-        assert dag.dag_id in asset_triggered_dag_info
-        first_queued_time, last_queued_time = asset_triggered_dag_info[dag.dag_id]
-        assert first_queued_time == DEFAULT_DATE
+        query, triggered_date_by_dag = DagModel.dags_needing_dagruns(session)
+        assert len(triggered_date_by_dag) == 1
+        assert dag.dag_id in triggered_date_by_dag
+        last_queued_time = triggered_date_by_dag[dag.dag_id]
         assert last_queued_time == DEFAULT_DATE + timedelta(hours=1)
 
     def test_asset_expression(self, testing_dag_bundle, session: Session) -> None:
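The test change tracks the narrowed return type: `dags_needing_dagruns` now maps each dag id straight to its last queued time instead of a `(first, last)` tuple. A plain-dict sketch of what downstream unpacking looks like before and after (the dag id and dates here are illustrative, not taken from the test suite):

```python
from datetime import datetime, timedelta

DEFAULT_DATE = datetime(2016, 1, 1)

# Before: dict[str, tuple[datetime, datetime]] — callers unpacked both ends.
old_style = {"my_dag": (DEFAULT_DATE, DEFAULT_DATE + timedelta(hours=1))}
first_queued_time, last_queued_time = old_style["my_dag"]

# After: dict[str, datetime] — only the last queued time survives.
new_style = {"my_dag": DEFAULT_DATE + timedelta(hours=1)}
assert new_style["my_dag"] == last_queued_time
```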