Skip to content

Commit

Permalink
AIP-82 Handle paused DAGs (#44456)
Browse files Browse the repository at this point in the history
  • Loading branch information
vincbeck authored Dec 6, 2024
1 parent 662f6e2 commit 77c115f
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 13 deletions.
31 changes: 29 additions & 2 deletions airflow/dag_processing/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

from airflow.assets.manager import asset_manager
from airflow.models.asset import (
AssetActive,
AssetAliasModel,
AssetModel,
DagScheduleAssetAliasReference,
Expand Down Expand Up @@ -274,6 +275,26 @@ def _find_all_asset_aliases(dags: Iterable[DAG]) -> Iterator[AssetAlias]:
yield obj


def _find_active_assets(name_uri_assets, session: Session):
active_dags = {
dm.dag_id
for dm in session.scalars(select(DagModel).where(DagModel.is_active).where(~DagModel.is_paused))
}

return {
(asset_model.name, asset_model.uri)
for asset_model in session.scalars(
select(AssetModel)
.join(AssetActive, (AssetActive.name == AssetModel.name) & (AssetActive.uri == AssetModel.uri))
.where(tuple_(AssetActive.name, AssetActive.uri).in_(name_uri_assets))
.where(AssetModel.consuming_dags.any(DagScheduleAssetReference.dag_id.in_(active_dags)))
.options(
joinedload(AssetModel.consuming_dags).joinedload(DagScheduleAssetReference.dag),
)
).unique()
}


class AssetModelOperation(NamedTuple):
"""Collect asset/alias objects from DAGs and perform database operations for them."""

Expand Down Expand Up @@ -434,14 +455,20 @@ def add_asset_trigger_references(
refs_to_add: dict[tuple[str, str], set[str]] = {}
refs_to_remove: dict[tuple[str, str], set[str]] = {}
triggers: dict[str, BaseTrigger] = {}

# Optimization: if no asset collected, skip fetching active assets
active_assets = _find_active_assets(self.assets.keys(), session=session) if self.assets else {}

for name_uri, asset in self.assets.items():
asset_model = assets[name_uri]
# If the asset belong to a DAG not active or paused, consider there is no watcher associated to it
asset_watchers = asset.watchers if name_uri in active_assets else []
trigger_repr_to_trigger_dict: dict[str, BaseTrigger] = {
repr(trigger): trigger for trigger in asset.watchers
repr(trigger): trigger for trigger in asset_watchers
}
triggers.update(trigger_repr_to_trigger_dict)
trigger_repr_from_asset: set[str] = set(trigger_repr_to_trigger_dict.keys())

asset_model = assets[name_uri]
trigger_repr_from_asset_model: set[str] = {
BaseTrigger.repr(trigger.classpath, trigger.kwargs) for trigger in asset_model.triggers
}
Expand Down
69 changes: 68 additions & 1 deletion tests/dag_processing/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,24 @@
from __future__ import annotations

import warnings
from collections.abc import Generator
from datetime import timedelta

import pytest
from sqlalchemy.exc import SAWarning

from airflow.dag_processing.collection import _get_latest_runs_stmt
from airflow.dag_processing.collection import AssetModelOperation, _get_latest_runs_stmt
from airflow.models import DagModel, Trigger
from airflow.models.asset import (
AssetActive,
asset_trigger_association_table,
)
from airflow.operators.empty import EmptyOperator
from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger
from airflow.sdk.definitions.asset import Asset
from airflow.utils.session import create_session

from tests_common.test_utils.db import clear_db_assets, clear_db_dags, clear_db_triggers


def test_statement_latest_runs_one_dag():
Expand Down Expand Up @@ -62,3 +76,56 @@ def test_statement_latest_runs_many_dag():
"WHERE dag_run.dag_id = anon_1.dag_id AND dag_run.logical_date = anon_1.max_logical_date",
]
assert actual == expected, compiled_stmt


@pytest.mark.db_test
class TestAssetModelOperation:
@staticmethod
def clean_db():
clear_db_dags()
clear_db_assets()
clear_db_triggers()

@pytest.fixture(autouse=True)
def per_test(self) -> Generator:
self.clean_db()
yield
self.clean_db()

@pytest.mark.parametrize(
"is_active, is_paused, expected_num_triggers",
[
(True, True, 0),
(True, False, 1),
(False, True, 0),
(False, False, 0),
],
)
def test_add_asset_trigger_references(self, is_active, is_paused, expected_num_triggers, dag_maker):
trigger = TimeDeltaTrigger(timedelta(seconds=0))
asset = Asset("test_add_asset_trigger_references_asset", watchers=[trigger])

with dag_maker(dag_id="test_add_asset_trigger_references_dag", schedule=[asset]) as dag:
EmptyOperator(task_id="mytask")

asset_op = AssetModelOperation.collect({"test_add_asset_trigger_references_dag": dag})

with create_session() as session:
# Update `is_active` and `is_paused` properties from DAG
dags = session.query(DagModel).all()
for dag in dags:
dag.is_active = is_active
dag.is_paused = is_paused

orm_assets = asset_op.add_assets(session=session)
# Create AssetActive objects from assets. It is usually done in the scheduler
for asset in orm_assets.values():
session.add(AssetActive.for_asset(asset))
session.commit()

asset_op.add_asset_trigger_references(orm_assets, session=session)

session.commit()

assert session.query(Trigger).count() == expected_num_triggers
assert session.query(asset_trigger_association_table).count() == expected_num_triggers
20 changes: 10 additions & 10 deletions tests/models/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ def test_bulk_write_to_db(self):
for i in range(4)
]

with assert_queries_count(5):
with assert_queries_count(6):
DAG.bulk_write_to_db(dags)
with create_session() as session:
assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == {
Expand All @@ -647,14 +647,14 @@ def test_bulk_write_to_db(self):
assert row[0] is not None

# Re-sync should do fewer queries
with assert_queries_count(8):
with assert_queries_count(9):
DAG.bulk_write_to_db(dags)
with assert_queries_count(8):
with assert_queries_count(9):
DAG.bulk_write_to_db(dags)
# Adding tags
for dag in dags:
dag.tags.add("test-dag2")
with assert_queries_count(9):
with assert_queries_count(10):
DAG.bulk_write_to_db(dags)
with create_session() as session:
assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == {
Expand All @@ -673,7 +673,7 @@ def test_bulk_write_to_db(self):
# Removing tags
for dag in dags:
dag.tags.remove("test-dag")
with assert_queries_count(9):
with assert_queries_count(10):
DAG.bulk_write_to_db(dags)
with create_session() as session:
assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == {
Expand All @@ -692,7 +692,7 @@ def test_bulk_write_to_db(self):
# Removing all tags
for dag in dags:
dag.tags = set()
with assert_queries_count(9):
with assert_queries_count(10):
DAG.bulk_write_to_db(dags)
with create_session() as session:
assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == {
Expand All @@ -713,7 +713,7 @@ def test_bulk_write_to_db_single_dag(self):
for i in range(1)
]

with assert_queries_count(5):
with assert_queries_count(6):
DAG.bulk_write_to_db(dags)
with create_session() as session:
assert {"dag-bulk-sync-0"} == {row[0] for row in session.query(DagModel.dag_id).all()}
Expand All @@ -740,7 +740,7 @@ def test_bulk_write_to_db_multiple_dags(self):
for i in range(4)
]

with assert_queries_count(5):
with assert_queries_count(6):
DAG.bulk_write_to_db(dags)
with create_session() as session:
assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == {
Expand All @@ -757,9 +757,9 @@ def test_bulk_write_to_db_multiple_dags(self):
assert row[0] is not None

# Re-sync should do fewer queries
with assert_queries_count(8):
with assert_queries_count(9):
DAG.bulk_write_to_db(dags)
with assert_queries_count(8):
with assert_queries_count(9):
DAG.bulk_write_to_db(dags)

@pytest.mark.parametrize("interval", [None, "@daily"])
Expand Down
15 changes: 15 additions & 0 deletions tests_common/test_utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

from tests_common.test_utils.compat import (
AIRFLOW_V_2_10_PLUS,
AIRFLOW_V_3_0_PLUS,
AssetDagRunQueue,
AssetEvent,
AssetModel,
Expand Down Expand Up @@ -120,6 +121,20 @@ def clear_db_assets():
from tests_common.test_utils.compat import AssetAliasModel

session.query(AssetAliasModel).delete()
if AIRFLOW_V_3_0_PLUS:
from airflow.models.asset import AssetActive, asset_trigger_association_table

session.query(asset_trigger_association_table).delete()
session.query(AssetActive).delete()


def clear_db_triggers():
with create_session() as session:
if AIRFLOW_V_3_0_PLUS:
from airflow.models.asset import asset_trigger_association_table

session.query(asset_trigger_association_table).delete()
session.query(Trigger).delete()


def clear_db_dags():
Expand Down

0 comments on commit 77c115f

Please sign in to comment.