From fe5a2ea7a0b57901bb6239d666b875f6c71cd7e8 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Thu, 13 Feb 2025 22:25:21 +0530 Subject: [PATCH] AIP-72: Improving Operator Links Interface to Prevent User Code Execution in Webserver (#46613) Operator Links interface changed to not run user code in Airflow Webserver The Operator Extra links, which can be defined either via plugins or custom operators now do not execute any user code in the Airflow Webserver, but instead push the "full" links to XCom backend and the value is again fetched from the XCom backend when viewing task details in grid view. Example: ``` @attr.s(auto_attribs=True) class CustomBaseIndexOpLink(BaseOperatorLink): """Custom Operator Link for Google BigQuery Console.""" index: int = attr.ib() @property def name(self) -> str: return f"BigQuery Console #{self.index + 1}" @property def xcom_key(self) -> str: return f"bigquery_{self.index + 1}" def get_link(self, operator, *, ti_key): search_queries = XCom.get_one( task_id=ti_key.task_id, dag_id=ti_key.dag_id, run_id=ti_key.run_id, key="search_query" ) if not search_queries: return None if len(search_queries) < self.index: return None search_query = search_queries[self.index] return f"https://console.cloud.google.com/bigquery?j={search_query}" ``` --- airflow/models/abstractoperator.py | 62 ----- airflow/models/baseoperatorlink.py | 54 ++++- airflow/serialization/serialized_objects.py | 137 ++++++----- newsfragments/46613.feature.rst | 1 + .../amazon/aws/links/test_base_aws.py | 11 +- .../dbt/cloud/operators/test_dbt.py | 2 +- .../google/cloud/operators/test_dataproc.py | 112 +++------ .../azure/operators/test_data_factory.py | 3 +- .../microsoft/azure/operators/test_powerbi.py | 2 +- .../microsoft/azure/operators/test_synapse.py | 2 +- .../airflow/sdk/execution_time/task_runner.py | 6 + .../tests/execution_time/test_task_runner.py | 43 ++++ .../endpoints/test_extra_link_endpoint.py | 5 + .../routes/public/test_extra_links.py | 138 ++++++----- 
tests/operators/test_trigger_dagrun.py | 6 +- tests/sensors/test_external_task_sensor.py | 2 +- tests/serialization/test_dag_serialization.py | 63 ++--- tests/www/views/test_views_extra_links.py | 219 ++++++++++-------- tests_common/pytest_plugin.py | 13 ++ tests_common/test_utils/mock_operators.py | 4 + 20 files changed, 465 insertions(+), 420 deletions(-) create mode 100644 newsfragments/46613.feature.rst diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index 98fd977c59128..e1d909faa87bd 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -18,15 +18,12 @@ from __future__ import annotations import datetime -import inspect from collections.abc import Iterable, Sequence -from functools import cached_property from typing import TYPE_CHECKING, Any, Callable from sqlalchemy import select from airflow.configuration import conf -from airflow.exceptions import AirflowException from airflow.sdk.definitions._internal.abstractoperator import ( AbstractOperator as TaskSDKAbstractOperator, NotMapped as NotMapped, # Re-export this for compat @@ -42,7 +39,6 @@ if TYPE_CHECKING: from sqlalchemy.orm import Session - from airflow.models.baseoperatorlink import BaseOperatorLink from airflow.models.dag import DAG as SchedulerDAG from airflow.models.taskinstance import TaskInstance from airflow.sdk.definitions.baseoperator import BaseOperator @@ -157,64 +153,6 @@ def priority_weight_total(self) -> int: ) ) - @cached_property - def operator_extra_link_dict(self) -> dict[str, Any]: - """Returns dictionary of all extra links for the operator.""" - op_extra_links_from_plugin: dict[str, Any] = {} - from airflow import plugins_manager - - plugins_manager.initialize_extra_operators_links_plugins() - if plugins_manager.operator_extra_links is None: - raise AirflowException("Can't load operators") - for ope in plugins_manager.operator_extra_links: - if ope.operators and self.operator_class in ope.operators: - 
op_extra_links_from_plugin.update({ope.name: ope}) - - operator_extra_links_all = {link.name: link for link in self.operator_extra_links} - # Extra links defined in Plugins overrides operator links defined in operator - operator_extra_links_all.update(op_extra_links_from_plugin) - - return operator_extra_links_all - - @cached_property - def global_operator_extra_link_dict(self) -> dict[str, Any]: - """Returns dictionary of all global extra links.""" - from airflow import plugins_manager - - plugins_manager.initialize_extra_operators_links_plugins() - if plugins_manager.global_operator_extra_links is None: - raise AirflowException("Can't load operators") - return {link.name: link for link in plugins_manager.global_operator_extra_links} - - @cached_property - def extra_links(self) -> list[str]: - return sorted(set(self.operator_extra_link_dict).union(self.global_operator_extra_link_dict)) - - def get_extra_links(self, ti: TaskInstance, link_name: str) -> str | None: - """ - For an operator, gets the URLs that the ``extra_links`` entry points to. - - :meta private: - - :raise ValueError: The error message of a ValueError will be passed on through to - the fronted to show up as a tooltip on the disabled link. - :param ti: The TaskInstance for the URL being searched for. - :param link_name: The name of the link we're looking for the URL for. Should be - one of the options specified in ``extra_links``. 
- """ - link: BaseOperatorLink | None = self.operator_extra_link_dict.get(link_name) - if not link: - link = self.global_operator_extra_link_dict.get(link_name) - if not link: - return None - - parameters = inspect.signature(link.get_link).parameters - old_signature = all(name != "ti_key" for name, p in parameters.items() if p.kind != p.VAR_KEYWORD) - - if old_signature: - return link.get_link(self.unmap(None), ti.dag_run.logical_date) # type: ignore[misc] - return link.get_link(self.unmap(None), ti_key=ti.key) - def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence[TaskInstance], int]: """ Create the mapped task instances for mapped task. diff --git a/airflow/models/baseoperatorlink.py b/airflow/models/baseoperatorlink.py index 952d21bb5592d..a3eee044485f6 100644 --- a/airflow/models/baseoperatorlink.py +++ b/airflow/models/baseoperatorlink.py @@ -20,14 +20,53 @@ from abc import ABCMeta, abstractmethod from typing import TYPE_CHECKING, ClassVar -import attr +import attrs + +from airflow.models.xcom import BaseXCom +from airflow.utils.log.logging_mixin import LoggingMixin if TYPE_CHECKING: from airflow.models.baseoperator import BaseOperator from airflow.models.taskinstancekey import TaskInstanceKey -@attr.s(auto_attribs=True) +@attrs.define() +class XComOperatorLink(LoggingMixin): + """A generic operator link class that can retrieve link only using XCOMs. Used while deserializing operators.""" + + name: str + xcom_key: str + + def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str: + """ + Retrieve the link from the XComs. + + :param operator: The Airflow operator object this link is associated to. + :param ti_key: TaskInstance ID to return link for. 
+        :return: link to external system, but by pulling it from XComs
+        """
+        self.log.info(
+            "Attempting to retrieve link from XComs with key: %s for task id: %s", self.xcom_key, ti_key
+        )
+        value = BaseXCom.get_one(
+            key=self.xcom_key,
+            run_id=ti_key.run_id,
+            dag_id=ti_key.dag_id,
+            task_id=ti_key.task_id,
+            map_index=ti_key.map_index,
+        )
+        if not value:
+            self.log.debug(
+                "No link with name: %s present in XCom as key: %s, returning empty link",
+                self.name,
+                self.xcom_key,
+            )
+            return ""
+        # Stripping is a temporary workaround till https://github.com/apache/airflow/issues/46513 is handled.
+        return value.strip('"')
+
+
+@attrs.define()
 class BaseOperatorLink(metaclass=ABCMeta):
     """Abstract base class that defines how we get an operator link."""
 
@@ -44,6 +83,17 @@ class BaseOperatorLink(metaclass=ABCMeta):
     def name(self) -> str:
         """Name of the link. This will be the button name on the task UI."""
 
+    @property
+    def xcom_key(self) -> str:
+        """
+        XCom key with which the whole "link" for this operator link is stored.
+
+        On retrieving with this key, the entire link is returned.
+
+        Defaults to `_link_<class name>` if not provided.
+ """ + return f"_link_{self.__class__.__name__}" + @abstractmethod def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str: """ diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index edbc6b4e5bdef..686d3c9988f75 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -40,13 +40,14 @@ from airflow.callbacks.callback_requests import DagCallbackRequest, TaskCallbackRequest from airflow.exceptions import AirflowException, SerializationError, TaskDeferred from airflow.models.baseoperator import BaseOperator +from airflow.models.baseoperatorlink import BaseOperatorLink, XComOperatorLink from airflow.models.connection import Connection from airflow.models.dag import DAG, _get_model_data_interval from airflow.models.expandinput import ( EXPAND_INPUT_EMPTY, create_expand_input, ) -from airflow.models.taskinstance import SimpleTaskInstance +from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance from airflow.models.taskinstancekey import TaskInstanceKey from airflow.models.xcom_arg import SchedulerXComArg, deserialize_xcom_arg from airflow.providers_manager import ProvidersManager @@ -96,7 +97,6 @@ from inspect import Parameter from airflow.models import DagRun - from airflow.models.baseoperatorlink import BaseOperatorLink from airflow.models.expandinput import ExpandInput from airflow.sdk.definitions._internal.node import DAGNode from airflow.sdk.types import Operator @@ -1167,6 +1167,58 @@ def __init__(self, *args, **kwargs): self.template_fields = BaseOperator.template_fields self.operator_extra_links = BaseOperator.operator_extra_links + @cached_property + def operator_extra_link_dict(self) -> dict[str, BaseOperatorLink]: + """Returns dictionary of all extra links for the operator.""" + op_extra_links_from_plugin: dict[str, Any] = {} + from airflow import plugins_manager + + 
plugins_manager.initialize_extra_operators_links_plugins()
+        if plugins_manager.operator_extra_links is None:
+            raise AirflowException("Can't load operators")
+        for ope in plugins_manager.operator_extra_links:
+            if ope.operators and self.operator_class in ope.operators:
+                op_extra_links_from_plugin.update({ope.name: ope})
+
+        operator_extra_links_all = {link.name: link for link in self.operator_extra_links}
+        # Extra links defined in Plugins override operator links defined in operator
+        operator_extra_links_all.update(op_extra_links_from_plugin)
+
+        return operator_extra_links_all
+
+    @cached_property
+    def global_operator_extra_link_dict(self) -> dict[str, Any]:
+        """Returns dictionary of all global extra links."""
+        from airflow import plugins_manager
+
+        plugins_manager.initialize_extra_operators_links_plugins()
+        if plugins_manager.global_operator_extra_links is None:
+            raise AirflowException("Can't load operators")
+        return {link.name: link for link in plugins_manager.global_operator_extra_links}
+
+    @cached_property
+    def extra_links(self) -> list[str]:
+        return sorted(set(self.operator_extra_link_dict).union(self.global_operator_extra_link_dict))
+
+    def get_extra_links(self, ti: TaskInstance, link_name: str) -> str | None:
+        """
+        For an operator, gets the URLs that the ``extra_links`` entry points to.
+
+        :meta private:
+
+        :raise ValueError: The error message of a ValueError will be passed on through to
+            the frontend to show up as a tooltip on the disabled link.
+        :param ti: The TaskInstance for the URL being searched for.
+        :param link_name: The name of the link we're looking for the URL for. Should be
+            one of the options specified in ``extra_links``.
+ """ + link = self.operator_extra_link_dict.get(link_name) + if not link: + link = self.global_operator_extra_link_dict.get(link_name) + if not link: + return None + return link.get_link(self.unmap(None), ti_key=ti.key) + @property def task_type(self) -> str: # Overwrites task_type of BaseOperator to use _task_type instead of @@ -1504,7 +1556,9 @@ def _is_excluded(cls, var: Any, attrname: str, op: DAGNode): return super()._is_excluded(var, attrname, op) @classmethod - def _deserialize_operator_extra_links(cls, encoded_op_links: list) -> dict[str, BaseOperatorLink]: + def _deserialize_operator_extra_links( + cls, encoded_op_links: dict[str, str] + ) -> dict[str, XComOperatorLink]: """ Deserialize Operator Links if the Classes are registered in Airflow Plugins. @@ -1521,77 +1575,40 @@ def _deserialize_operator_extra_links(cls, encoded_op_links: list) -> dict[str, raise AirflowException("Can't load plugins") op_predefined_extra_links = {} - for _operator_links_source in encoded_op_links: - # Get the key, value pair as Tuple where key is OperatorLink ClassName - # and value is the dictionary containing the arguments passed to the OperatorLink - # - # Example of a single iteration: - # - # _operator_links_source = - # { - # 'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink': { - # 'index': 0 - # } - # }, - # - # list(_operator_links_source.items()) = - # [ - # ( - # 'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink', - # {'index': 0} - # ) - # ] + for name, xcom_key in encoded_op_links.items(): + # Get the name and xcom_key of the encoded operator and use it to create a XComOperatorLink object + # during deserialization. 
# - # list(_operator_links_source.items())[0] = - # ( - # 'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink', - # { - # 'index': 0 - # } - # ) - - _operator_link_class_path, data = next(iter(_operator_links_source.items())) - if _operator_link_class_path in get_operator_extra_links(): - single_op_link_class = import_string(_operator_link_class_path) - elif _operator_link_class_path in plugins_manager.registered_operator_link_classes: - single_op_link_class = plugins_manager.registered_operator_link_classes[ - _operator_link_class_path - ] - else: - log.error("Operator Link class %r not registered", _operator_link_class_path) - return {} - - op_link_parameters = {param: cls.deserialize(value) for param, value in data.items()} - op_predefined_extra_link: BaseOperatorLink = single_op_link_class(**op_link_parameters) - + # Example: + # enc_operator['_operator_extra_links'] = + # { + # 'airflow': 'airflow_link_key', + # 'foo-bar': 'link-key', + # 'no_response': 'key', + # 'raise_error': 'key' + # } + + op_predefined_extra_link = XComOperatorLink(name=name, xcom_key=xcom_key) op_predefined_extra_links.update({op_predefined_extra_link.name: op_predefined_extra_link}) return op_predefined_extra_links @classmethod - def _serialize_operator_extra_links(cls, operator_extra_links: Iterable[BaseOperatorLink]): + def _serialize_operator_extra_links( + cls, operator_extra_links: Iterable[BaseOperatorLink] + ) -> dict[str, str]: """ Serialize Operator Links. - Store the import path of the OperatorLink and the arguments passed to it. + Store the "name" of the link mapped with the xcom_key which can be later used to retrieve this + operator extra link from XComs. 
For example:
-        ``[{'airflow.providers.google.cloud.links.bigquery.BigQueryDatasetLink': {}}]``
+        ``{'link-name-1': 'xcom-key-1'}``
 
         :param operator_extra_links: Operator Link
         :return: Serialized Operator Link
         """
-        serialize_operator_extra_links = []
-        for operator_extra_link in operator_extra_links:
-            op_link_arguments = {
-                param: cls.serialize(value) for param, value in attrs.asdict(operator_extra_link).items()
-            }
-
-            module_path = (
-                f"{operator_extra_link.__class__.__module__}.{operator_extra_link.__class__.__name__}"
-            )
-            serialize_operator_extra_links.append({module_path: op_link_arguments})
-
-        return serialize_operator_extra_links
+        return {link.name: link.xcom_key for link in operator_extra_links}
 
     @classmethod
     def serialize(cls, var: Any, *, strict: bool = False) -> Any:
diff --git a/newsfragments/46613.feature.rst b/newsfragments/46613.feature.rst
new file mode 100644
index 0000000000000..3f84bb48d8e46
--- /dev/null
+++ b/newsfragments/46613.feature.rst
@@ -0,0 +1 @@
+Operator Links interface changed to not run user code in Airflow Webserver. The Operator Extra links, which can be defined either via plugins or custom operators, now do not execute any user code in the Airflow Webserver, but instead push the "full" links to XCom backend and the value is again fetched from the XCom backend when viewing task details in grid view.
diff --git a/providers/amazon/tests/provider_tests/amazon/aws/links/test_base_aws.py b/providers/amazon/tests/provider_tests/amazon/aws/links/test_base_aws.py index 91240a512ab32..1749d51e36e47 100644 --- a/providers/amazon/tests/provider_tests/amazon/aws/links/test_base_aws.py +++ b/providers/amazon/tests/provider_tests/amazon/aws/links/test_base_aws.py @@ -194,14 +194,7 @@ def assert_extra_link_url( ) error_msg = f"{self.full_qualname!r} should be preserved after execution" - assert ti.task.get_extra_links(ti, self.link_class.name) == expected_url, error_msg - - serialized_dag = self.dag_maker.get_serialized_data() - deserialized_dag = SerializedDAG.from_dict(serialized_dag) - deserialized_task = deserialized_dag.task_dict[self.task_id] - - error_msg = f"{self.full_qualname!r} should be preserved in deserialized tasks after execution" - assert deserialized_task.get_extra_links(ti, self.link_class.name) == expected_url, error_msg + assert task.operator_extra_links[0].get_link(operator=task, ti_key=ti.key) == expected_url, error_msg def test_link_serialize(self): """Test: Operator links should exist for serialized DAG.""" @@ -223,7 +216,7 @@ def test_empty_xcom(self): deserialized_task = deserialized_dag.task_dict[self.task_id] assert ( - ti.task.get_extra_links(ti, self.link_class.name) == "" + ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) == "" ), "Operator link should only be added if job id is available in XCom" assert ( diff --git a/providers/dbt/cloud/tests/provider_tests/dbt/cloud/operators/test_dbt.py b/providers/dbt/cloud/tests/provider_tests/dbt/cloud/operators/test_dbt.py index 8791a09a5ba71..c8d0b06c2863e 100644 --- a/providers/dbt/cloud/tests/provider_tests/dbt/cloud/operators/test_dbt.py +++ b/providers/dbt/cloud/tests/provider_tests/dbt/cloud/operators/test_dbt.py @@ -658,7 +658,7 @@ def test_run_job_operator_link(self, conn_id, account_id, create_task_instance_o ti.xcom_push(key="job_run_url", 
value=_run_response["data"]["href"]) - url = ti.task.get_extra_links(ti, "Monitor Job Run") + url = ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) assert url == ( EXPECTED_JOB_RUN_OP_EXTRA_LINK.format( diff --git a/providers/google/tests/provider_tests/google/cloud/operators/test_dataproc.py b/providers/google/tests/provider_tests/google/cloud/operators/test_dataproc.py index 3acc52b39ea95..22d22bd88485f 100644 --- a/providers/google/tests/provider_tests/google/cloud/operators/test_dataproc.py +++ b/providers/google/tests/provider_tests/google/cloud/operators/test_dataproc.py @@ -40,9 +40,6 @@ from airflow.providers.google.cloud.links.dataproc import ( DATAPROC_CLUSTER_LINK_DEPRECATED, DATAPROC_JOB_LINK_DEPRECATED, - DataprocClusterLink, - DataprocJobLink, - DataprocWorkflowLink, ) from airflow.providers.google.cloud.operators.dataproc import ( ClusterGenerator, @@ -55,7 +52,6 @@ DataprocGetBatchOperator, DataprocInstantiateInlineWorkflowTemplateOperator, DataprocInstantiateWorkflowTemplateOperator, - DataprocLink, DataprocListBatchesOperator, DataprocScaleClusterOperator, DataprocStartClusterOperator, @@ -1126,29 +1122,21 @@ def test_create_cluster_operator_extra_links(dag_maker, create_task_instance_of_ ) serialized_dag = dag_maker.get_serialized_data() - deserialized_dag = SerializedDAG.from_dict(serialized_dag) - deserialized_task = deserialized_dag.task_dict[TASK_ID] # Assert operator links for serialized DAG deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag["dag"]) operator_extra_link = deserialized_dag.tasks[0].operator_extra_links[0] assert operator_extra_link.name == "Dataproc Cluster" - # Assert operator link types are preserved during deserialization - assert isinstance(deserialized_task.operator_extra_links[0], DataprocClusterLink) - # Assert operator link is empty when no XCom push occurred - assert ti.task.get_extra_links(ti, DataprocClusterLink.name) == "" - - # Assert operator link is empty for deserialized 
task when no XCom push occurred - assert deserialized_task.get_extra_links(ti, DataprocClusterLink.name) == "" + assert ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) == "" ti.xcom_push(key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED) - # Assert operator links are preserved in deserialized tasks after execution - assert deserialized_task.get_extra_links(ti, DataprocClusterLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED - # Assert operator links after execution - assert ti.task.get_extra_links(ti, DataprocClusterLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED + assert ( + ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) + == DATAPROC_CLUSTER_LINK_EXPECTED + ) class TestDataprocClusterScaleOperator(DataprocClusterTestBase): @@ -1237,33 +1225,25 @@ def test_scale_cluster_operator_extra_links(dag_maker, create_task_instance_of_o ) serialized_dag = dag_maker.get_serialized_data() - deserialized_dag = SerializedDAG.from_dict(serialized_dag) - deserialized_task = deserialized_dag.task_dict[TASK_ID] # Assert operator links for serialized DAG deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag["dag"]) operator_extra_link = deserialized_dag.tasks[0].operator_extra_links[0] assert operator_extra_link.name == "Dataproc resource" - # Assert operator link types are preserved during deserialization - assert isinstance(deserialized_task.operator_extra_links[0], DataprocLink) - # Assert operator link is empty when no XCom push occurred - assert ti.task.get_extra_links(ti, DataprocLink.name) == "" - - # Assert operator link is empty for deserialized task when no XCom push occurred - assert deserialized_task.get_extra_links(ti, DataprocLink.name) == "" + assert ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) == "" ti.xcom_push( key="conf", value=DATAPROC_CLUSTER_CONF_EXPECTED, ) - # Assert operator links are preserved in deserialized tasks after execution - assert 
deserialized_task.get_extra_links(ti, DataprocLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED - # Assert operator links after execution - assert ti.task.get_extra_links(ti, DataprocLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED + assert ( + ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) + == DATAPROC_CLUSTER_LINK_EXPECTED + ) class TestDataprocClusterDeleteOperator: @@ -2108,30 +2088,22 @@ def test_submit_job_operator_extra_links(mock_hook, dag_maker, create_task_insta ) serialized_dag = dag_maker.get_serialized_data() - deserialized_dag = SerializedDAG.from_dict(serialized_dag) - deserialized_task = deserialized_dag.task_dict[TASK_ID] # Assert operator links for serialized DAG deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag["dag"]) operator_extra_link = deserialized_dag.tasks[0].operator_extra_links[0] assert operator_extra_link.name == "Dataproc Job" - # Assert operator link types are preserved during deserialization - assert isinstance(deserialized_task.operator_extra_links[0], DataprocJobLink) - # Assert operator link is empty when no XCom push occurred - assert ti.task.get_extra_links(ti, DataprocJobLink.name) == "" - - # Assert operator link is empty for deserialized task when no XCom push occurred - assert deserialized_task.get_extra_links(ti, DataprocJobLink.name) == "" + assert ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) == "" ti.xcom_push(key="dataproc_job", value=DATAPROC_JOB_EXPECTED) - # Assert operator links are preserved in deserialized tasks - assert deserialized_task.get_extra_links(ti, DataprocJobLink.name) == DATAPROC_JOB_LINK_EXPECTED - # Assert operator links after execution - assert ti.task.get_extra_links(ti, DataprocJobLink.name) == DATAPROC_JOB_LINK_EXPECTED + assert ( + ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) + == DATAPROC_JOB_LINK_EXPECTED + ) class TestDataprocUpdateClusterOperator(DataprocClusterTestBase): @@ -2318,30 +2290,22 @@ def 
test_update_cluster_operator_extra_links(dag_maker, create_task_instance_of_ ) serialized_dag = dag_maker.get_serialized_data() - deserialized_dag = SerializedDAG.from_dict(serialized_dag) - deserialized_task = deserialized_dag.task_dict[TASK_ID] # Assert operator links for serialized DAG deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag["dag"]) operator_extra_link = deserialized_dag.tasks[0].operator_extra_links[0] assert operator_extra_link.name == "Dataproc Cluster" - # Assert operator link types are preserved during deserialization - assert isinstance(deserialized_task.operator_extra_links[0], DataprocClusterLink) - # Assert operator link is empty when no XCom push occurred - assert ti.task.get_extra_links(ti, DataprocClusterLink.name) == "" - - # Assert operator link is empty for deserialized task when no XCom push occurred - assert deserialized_task.get_extra_links(ti, DataprocClusterLink.name) == "" + assert ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) == "" ti.xcom_push(key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED) - # Assert operator links are preserved in deserialized tasks - assert deserialized_task.get_extra_links(ti, DataprocClusterLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED - # Assert operator links after execution - assert ti.task.get_extra_links(ti, DataprocClusterLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED + assert ( + ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) + == DATAPROC_CLUSTER_LINK_EXPECTED + ) class TestDataprocStartClusterOperator(DataprocClusterTestBase): @@ -2539,30 +2503,22 @@ def test_instantiate_workflow_operator_extra_links(mock_hook, dag_maker, create_ gcp_conn_id=GCP_CONN_ID, ) serialized_dag = dag_maker.get_serialized_data() - deserialized_dag = SerializedDAG.from_dict(serialized_dag) - deserialized_task = deserialized_dag.task_dict[TASK_ID] # Assert operator links for serialized DAG deserialized_dag = 
SerializedDAG.deserialize_dag(serialized_dag["dag"]) operator_extra_link = deserialized_dag.tasks[0].operator_extra_links[0] assert operator_extra_link.name == "Dataproc Workflow" - # Assert operator link types are preserved during deserialization - assert isinstance(deserialized_task.operator_extra_links[0], DataprocWorkflowLink) - # Assert operator link is empty when no XCom push occurred - assert ti.task.get_extra_links(ti, DataprocWorkflowLink.name) == "" - - # Assert operator link is empty for deserialized task when no XCom push occurred - assert deserialized_task.get_extra_links(ti, DataprocWorkflowLink.name) == "" + assert ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) == "" ti.xcom_push(key="dataproc_workflow", value=DATAPROC_WORKFLOW_EXPECTED) - # Assert operator links are preserved in deserialized tasks - assert deserialized_task.get_extra_links(ti, DataprocWorkflowLink.name) == DATAPROC_WORKFLOW_LINK_EXPECTED - # Assert operator links after execution - assert ti.task.get_extra_links(ti, DataprocWorkflowLink.name) == DATAPROC_WORKFLOW_LINK_EXPECTED + assert ( + ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) + == DATAPROC_WORKFLOW_LINK_EXPECTED + ) class TestDataprocWorkflowTemplateInstantiateInlineOperator: @@ -3211,30 +3167,22 @@ def test_instantiate_inline_workflow_operator_extra_links( gcp_conn_id=GCP_CONN_ID, ) serialized_dag = dag_maker.get_serialized_data() - deserialized_dag = SerializedDAG.from_dict(serialized_dag) - deserialized_task = deserialized_dag.task_dict[TASK_ID] # Assert operator links for serialized DAG deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag["dag"]) operator_extra_link = deserialized_dag.tasks[0].operator_extra_links[0] assert operator_extra_link.name == "Dataproc Workflow" - # Assert operator link types are preserved during deserialization - assert isinstance(deserialized_task.operator_extra_links[0], DataprocWorkflowLink) - # Assert operator link is empty 
when no XCom push occurred - assert ti.task.get_extra_links(ti, DataprocWorkflowLink.name) == "" - - # Assert operator link is empty for deserialized task when no XCom push occurred - assert deserialized_task.get_extra_links(ti, DataprocWorkflowLink.name) == "" + assert ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) == "" ti.xcom_push(key="dataproc_workflow", value=DATAPROC_WORKFLOW_EXPECTED) - # Assert operator links are preserved in deserialized tasks - assert deserialized_task.get_extra_links(ti, DataprocWorkflowLink.name) == DATAPROC_WORKFLOW_LINK_EXPECTED - # Assert operator links after execution - assert ti.task.get_extra_links(ti, DataprocWorkflowLink.name) == DATAPROC_WORKFLOW_LINK_EXPECTED + assert ( + ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) + == DATAPROC_WORKFLOW_LINK_EXPECTED + ) class TestDataprocCreateWorkflowTemplateOperator: diff --git a/providers/microsoft/azure/tests/provider_tests/microsoft/azure/operators/test_data_factory.py b/providers/microsoft/azure/tests/provider_tests/microsoft/azure/operators/test_data_factory.py index 1a5721e9d08c6..f1002f80ffcc8 100644 --- a/providers/microsoft/azure/tests/provider_tests/microsoft/azure/operators/test_data_factory.py +++ b/providers/microsoft/azure/tests/provider_tests/microsoft/azure/operators/test_data_factory.py @@ -246,8 +246,7 @@ def test_run_pipeline_operator_link(self, resource_group, factory, create_task_i factory_name=factory, ) ti.xcom_push(key="run_id", value=PIPELINE_RUN_RESPONSE["run_id"]) - - url = ti.task.get_extra_links(ti, "Monitor Pipeline Run") + url = ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) EXPECTED_PIPELINE_RUN_OP_EXTRA_LINK = ( "https://adf.azure.com/en-us/monitoring/pipelineruns/{run_id}" "?factory=/subscriptions/{subscription_id}/" diff --git a/providers/microsoft/azure/tests/provider_tests/microsoft/azure/operators/test_powerbi.py 
b/providers/microsoft/azure/tests/provider_tests/microsoft/azure/operators/test_powerbi.py index 3decae2718d10..2c4ae5fccfd8a 100644 --- a/providers/microsoft/azure/tests/provider_tests/microsoft/azure/operators/test_powerbi.py +++ b/providers/microsoft/azure/tests/provider_tests/microsoft/azure/operators/test_powerbi.py @@ -194,7 +194,7 @@ def test_powerbi_link(self, create_task_instance_of_operator): ) ti.xcom_push(key="powerbi_dataset_refresh_id", value=NEW_REFRESH_REQUEST_ID) - url = ti.task.get_extra_links(ti, "Monitor PowerBI Dataset") + url = ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) EXPECTED_ITEM_RUN_OP_EXTRA_LINK = ( "https://app.powerbi.com" # type: ignore[attr-defined] f"/groups/{GROUP_ID}/datasets/{DATASET_ID}" # type: ignore[attr-defined] diff --git a/providers/microsoft/azure/tests/provider_tests/microsoft/azure/operators/test_synapse.py b/providers/microsoft/azure/tests/provider_tests/microsoft/azure/operators/test_synapse.py index 738e7b3675bba..9c327e98b0fad 100644 --- a/providers/microsoft/azure/tests/provider_tests/microsoft/azure/operators/test_synapse.py +++ b/providers/microsoft/azure/tests/provider_tests/microsoft/azure/operators/test_synapse.py @@ -288,7 +288,7 @@ def test_run_pipeline_operator_link(self, create_task_instance_of_operator): ti.xcom_push(key="run_id", value=PIPELINE_RUN_RESPONSE["run_id"]) - url = ti.task.get_extra_links(ti, "Monitor Pipeline Run") + url = ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) EXPECTED_PIPELINE_RUN_OP_EXTRA_LINK = ( "https://ms.web.azuresynapse.net/en/monitoring/pipelineruns/{run_id}" diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index 3b8801ff9fcf1..d11f6bec994df 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -748,6 +748,12 @@ def _push_xcom_if_needed(result: Any, ti: 
RuntimeTaskInstance, log: Logger): def finalize(ti: RuntimeTaskInstance, state: TerminalTIState, log: Logger): + # Pushing xcom for each operator extra links defined on the operator only. + for oe in ti.task.operator_extra_links: + link, xcom_key = oe.get_link(operator=ti.task, ti_key=ti.id), oe.xcom_key # type: ignore[arg-type] + log.debug("Setting xcom for operator extra link", link=link, xcom_key=xcom_key) + _xcom_push(ti, key=xcom_key, value=link) + log.debug("Running finalizers", ti=ti) if state in [TerminalTIState.SUCCESS]: get_listener_manager().hook.on_task_instance_success( diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index ef16e4e9c2872..930025cb42720 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -57,6 +57,7 @@ PrevSuccessfulDagRunResult, RuntimeCheckOnTask, SetRenderedFields, + SetXCom, StartupDetails, SucceedTask, TaskState, @@ -82,6 +83,8 @@ from airflow.utils import timezone from airflow.utils.state import TaskInstanceState +from tests_common.test_utils.mock_operators import AirflowLink + FAKE_BUNDLE = BundleInfo(name="anything", version="any") @@ -1146,6 +1149,46 @@ def execute(self, context): _, msg = run(runtime_ti, log=mock.MagicMock()) assert isinstance(msg, SucceedTask) + def test_task_run_with_operator_extra_links(self, create_runtime_ti, mock_supervisor_comms, time_machine): + """Test that a task can run with operator extra links defined and can set an xcom.""" + instant = timezone.datetime(2024, 12, 3, 10, 0) + time_machine.move_to(instant, tick=False) + + class DummyTestOperator(BaseOperator): + operator_extra_links = (AirflowLink(),) + + def execute(self, context): + print("Hello from custom operator", self.operator_extra_links) + + task = DummyTestOperator(task_id="task_with_operator_extra_links") + + runtime_ti = create_runtime_ti(task=task) + + run(runtime_ti, log=mock.MagicMock()) + + 
mock_supervisor_comms.send_request.assert_called_once_with( + msg=SucceedTask( + state=TerminalTIState.SUCCESS, end_date=instant, task_outlets=[], outlet_events=[] + ), + log=mock.ANY, + ) + + finalize(runtime_ti, log=mock.MagicMock(), state=TerminalTIState.SUCCESS) + + mock_supervisor_comms.send_request.assert_any_call( + msg=SetXCom( + key="_link_AirflowLink", + value='"https://airflow.apache.org"', + dag_id="test_dag", + run_id="test_run", + task_id="task_with_operator_extra_links", + map_index=-1, + mapped_length=None, + type="SetXCom", + ), + log=mock.ANY, + ) + class TestXComAfterTaskExecution: @pytest.mark.parametrize( diff --git a/tests/api_connexion/endpoints/test_extra_link_endpoint.py b/tests/api_connexion/endpoints/test_extra_link_endpoint.py index d9b6e1d45ec0c..dbd9c4981f7f3 100644 --- a/tests/api_connexion/endpoints/test_extra_link_endpoint.py +++ b/tests/api_connexion/endpoints/test_extra_link_endpoint.py @@ -143,6 +143,7 @@ def test_should_raise_403_forbidden(self): ) assert response.status_code == 403 + @pytest.mark.skip(reason="Legacy API tests.") @mock_plugin_manager(plugins=[]) def test_should_respond_200(self): XCom.set( @@ -160,6 +161,7 @@ def test_should_respond_200(self): assert response.status_code == 200, response.data assert response.json == {"Google Custom": "http://google.com/custom_base_link?search=TEST_LINK_VALUE"} + @pytest.mark.skip(reason="Legacy API tests.") @mock_plugin_manager(plugins=[]) def test_should_respond_200_missing_xcom(self): response = self.client.get( @@ -170,6 +172,7 @@ def test_should_respond_200_missing_xcom(self): assert response.status_code == 200, response.data assert response.json == {"Google Custom": None} + @pytest.mark.skip(reason="Legacy API tests.") @mock_plugin_manager(plugins=[]) def test_should_respond_200_multiple_links(self): XCom.set( @@ -190,6 +193,7 @@ def test_should_respond_200_multiple_links(self): "BigQuery Console #2": "https://console.cloud.google.com/bigquery?j=TEST_LINK_VALUE_2", } + 
@pytest.mark.skip(reason="Legacy API tests.") @mock_plugin_manager(plugins=[]) def test_should_respond_200_multiple_links_missing_xcom(self): response = self.client.get( @@ -200,6 +204,7 @@ def test_should_respond_200_multiple_links_missing_xcom(self): assert response.status_code == 200, response.data assert response.json == {"BigQuery Console #1": None, "BigQuery Console #2": None} + @pytest.mark.skip(reason="Legacy API tests.") def test_should_respond_200_support_plugins(self): class GoogleLink(BaseOperatorLink): name = "Google" diff --git a/tests/api_fastapi/core_api/routes/public/test_extra_links.py b/tests/api_fastapi/core_api/routes/public/test_extra_links.py index 278e42ed09841..0907a2877565a 100644 --- a/tests/api_fastapi/core_api/routes/public/test_extra_links.py +++ b/tests/api_fastapi/core_api/routes/public/test_extra_links.py @@ -17,24 +17,20 @@ from __future__ import annotations import os -from urllib.parse import quote_plus import pytest from airflow.dag_processing.bundles.manager import DagBundlesManager -from airflow.models.dag import DAG from airflow.models.dagbag import DagBag from airflow.models.xcom import XCom from airflow.plugins_manager import AirflowPlugin from airflow.utils import timezone -from airflow.utils.session import provide_session from airflow.utils.state import DagRunState from airflow.utils.types import DagRunType from tests_common.test_utils.compat import BaseOperatorLink from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_db_xcom from tests_common.test_utils.mock_operators import CustomOperator -from tests_common.test_utils.mock_plugins import mock_plugin_manager from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS if AIRFLOW_V_3_0_PLUS: @@ -43,7 +39,33 @@ pytestmark = pytest.mark.db_test -class TestExtraLinks: +class GoogleLink(BaseOperatorLink): + name = "Google" + + def get_link(self, operator, ti_key): + return "https://www.google.com" + + +class S3LogLink(BaseOperatorLink): + name = 
"S3" + operators = [CustomOperator] + + def get_link(self, operator, ti_key): + return f"https://s3.amazonaws.com/airflow-logs/{operator.dag_id}/{operator.task_id}/" + + +class AirflowPluginWithOperatorLinks(AirflowPlugin): + name = "test_plugin" + global_operator_extra_links = [ + GoogleLink(), + ] + operator_extra_links = [ + S3LogLink(), + ] + + +@pytest.mark.mock_plugin_manager(plugins=[]) +class TestGetExtraLinks: dag_id = "TEST_DAG_ID" dag_run_id = "TEST_DAG_RUN_ID" task_single_link = "TEST_SINGLE_LINK" @@ -58,16 +80,15 @@ def _clear_db(): clear_db_runs() clear_db_xcom() - @provide_session @pytest.fixture(autouse=True) - def setup(self, test_client, session=None) -> None: + def setup(self, test_client, dag_maker, request, session) -> None: """ Setup extra links for testing. :return: Dictionary with event extra link names with their corresponding link as the links. """ self._clear_db() - self.dag = self._create_dag() + self.dag = self._create_dag(dag_maker) DagBundlesManager().sync_bundles_to_db() dag_bag = DagBag(os.devnull, include_examples=False) @@ -88,20 +109,16 @@ def setup(self, test_client, session=None) -> None: def teardown_method(self) -> None: self._clear_db() - def _create_dag(self): - with DAG(dag_id=self.dag_id, schedule=None, default_args={"start_date": self.default_time}) as dag: + def _create_dag(self, dag_maker): + with dag_maker( + dag_id=self.dag_id, schedule=None, default_args={"start_date": self.default_time}, serialized=True + ) as dag: CustomOperator(task_id=self.task_single_link, bash_command="TEST_LINK_VALUE") CustomOperator( task_id=self.task_multiple_links, bash_command=["TEST_LINK_VALUE_1", "TEST_LINK_VALUE_2"] ) - # Mapped task expanded over a list of bash_commands - CustomOperator.partial(task_id=self.task_mapped).expand( - bash_command=["TEST_LINK_VALUE_1", "TEST_LINK_VALUE_2"] - ) return dag - -class TestGetExtraLinks(TestExtraLinks): @pytest.mark.parametrize( "url, expected_status_code, expected_response", [ @@ -131,8 +148,7 
@@ def test_should_respond_404(self, test_client, url, expected_status_code, expect assert response.status_code == expected_status_code assert response.json() == expected_response - @mock_plugin_manager(plugins=[]) - def test_should_respond_200(self, test_client): + def test_should_respond_200(self, dag_maker, test_client): XCom.set( key="search_query", value="TEST_LINK_VALUE", @@ -140,6 +156,14 @@ def test_should_respond_200(self, test_client): dag_id=self.dag_id, run_id=self.dag_run_id, ) + XCom.set( + key="_link_CustomOpLink", + value="http://google.com/custom_base_link?search=TEST_LINK_VALUE", + task_id=self.task_single_link, + dag_id=self.dag_id, + run_id=self.dag_run_id, + ) + response = test_client.get( f"/public/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/taskInstances/{self.task_single_link}/links", ) @@ -149,7 +173,6 @@ def test_should_respond_200(self, test_client): "Google Custom": "http://google.com/custom_base_link?search=TEST_LINK_VALUE" } - @mock_plugin_manager(plugins=[]) def test_should_respond_200_missing_xcom(self, test_client): response = test_client.get( f"/public/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/taskInstances/{self.task_single_link}/links", @@ -158,15 +181,33 @@ def test_should_respond_200_missing_xcom(self, test_client): assert response.status_code == 200 assert response.json() == {"Google Custom": None} - @mock_plugin_manager(plugins=[]) - def test_should_respond_200_multiple_links(self, test_client): + def test_should_respond_200_multiple_links(self, test_client, session): XCom.set( key="search_query", value=["TEST_LINK_VALUE_1", "TEST_LINK_VALUE_2"], task_id=self.task_multiple_links, dag_id=self.dag.dag_id, run_id=self.dag_run_id, + session=session, + ) + XCom.set( + key="bigquery_1", + value="https://console.cloud.google.com/bigquery?j=TEST_LINK_VALUE_1", + task_id=self.task_multiple_links, + dag_id=self.dag_id, + run_id=self.dag_run_id, + session=session, + ) + XCom.set( + key="bigquery_2", + 
value="https://console.cloud.google.com/bigquery?j=TEST_LINK_VALUE_2", + task_id=self.task_multiple_links, + dag_id=self.dag_id, + run_id=self.dag_run_id, + session=session, ) + session.commit() + response = test_client.get( f"/public/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/taskInstances/{self.task_multiple_links}/links", ) @@ -177,7 +218,6 @@ def test_should_respond_200_multiple_links(self, test_client): "BigQuery Console #2": "https://console.cloud.google.com/bigquery?j=TEST_LINK_VALUE_2", } - @mock_plugin_manager(plugins=[]) def test_should_respond_200_multiple_links_missing_xcom(self, test_client): response = test_client.get( f"/public/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/taskInstances/{self.task_multiple_links}/links", @@ -186,48 +226,19 @@ def test_should_respond_200_multiple_links_missing_xcom(self, test_client): assert response.status_code == 200 assert response.json() == {"BigQuery Console #1": None, "BigQuery Console #2": None} + @pytest.mark.mock_plugin_manager(plugins=[AirflowPluginWithOperatorLinks]) def test_should_respond_200_support_plugins(self, test_client): - class GoogleLink(BaseOperatorLink): - name = "Google" - - def get_link(self, operator, dttm): - return "https://www.google.com" - - class S3LogLink(BaseOperatorLink): - name = "S3" - operators = [CustomOperator] - - def get_link(self, operator, dttm): - return ( - f"https://s3.amazonaws.com/airflow-logs/{operator.dag_id}/" - f"{operator.task_id}/{quote_plus(dttm.isoformat())}" - ) - - class AirflowTestPlugin(AirflowPlugin): - name = "test_plugin" - global_operator_extra_links = [ - GoogleLink(), - ] - operator_extra_links = [ - S3LogLink(), - ] - - with mock_plugin_manager(plugins=[AirflowTestPlugin]): - response = test_client.get( - f"/public/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/taskInstances/{self.task_single_link}/links", - ) + response = test_client.get( + f"/public/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/taskInstances/{self.task_single_link}/links", + ) + + 
assert response, response.status_code == 200 + assert response.json() == { + "Google Custom": None, + "Google": "https://www.google.com", + "S3": ("https://s3.amazonaws.com/airflow-logs/TEST_DAG_ID/TEST_SINGLE_LINK/"), + } - assert response, response.status_code == 200 - assert response.json() == { - "Google Custom": None, - "Google": "https://www.google.com", - "S3": ( - "https://s3.amazonaws.com/airflow-logs/" - "TEST_DAG_ID/TEST_SINGLE_LINK/2020-01-01T00%3A00%3A00%2B00%3A00" - ), - } - - @mock_plugin_manager(plugins=[]) @pytest.mark.xfail(reason="TODO: TaskSDK need to fix this, Extra links should work for mapped operator") def test_should_respond_200_mapped_task_instance(self, test_client): map_index = 0 @@ -248,11 +259,10 @@ def test_should_respond_200_mapped_task_instance(self, test_client): "Google Custom": "http://google.com/custom_base_link?search=TEST_LINK_VALUE_1" } - @mock_plugin_manager(plugins=[]) def test_should_respond_404_invalid_map_index(self, test_client): response = test_client.get( f"/public/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/taskInstances/{self.task_mapped}/links", params={"map_index": 4}, ) assert response.status_code == 404 - assert response.json() == {"detail": "DAG Run with ID = TEST_DAG_RUN_ID not found"} + assert response.json() == {"detail": "Task with ID = TEST_MAPPED_TASK not found"} diff --git a/tests/operators/test_trigger_dagrun.py b/tests/operators/test_trigger_dagrun.py index 1c72dae332a0e..44d7ee1282800 100644 --- a/tests/operators/test_trigger_dagrun.py +++ b/tests/operators/test_trigger_dagrun.py @@ -100,10 +100,14 @@ def assert_extra_link(self, triggered_dag_run, triggering_task, session): ) .one() ) + with mock.patch( "airflow.providers.standard.operators.trigger_dagrun.build_airflow_url_with_query" ) as mock_build_url: - triggering_task.get_extra_links(triggering_ti, "Triggered DAG") + # This is equivalent of a task run calling this and pushing to xcom + triggering_task.operator_extra_links[0].get_link( + 
operator=triggering_task, ti_key=triggering_ti.key + ) assert mock_build_url.called args, _ = mock_build_url.call_args expected_args = { diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py index 239eca8cdc165..eef7585f0c6a7 100644 --- a/tests/sensors/test_external_task_sensor.py +++ b/tests/sensors/test_external_task_sensor.py @@ -1086,7 +1086,7 @@ def test_external_task_sensor_extra_link( app.config["SERVER_NAME"] = "" with app.app_context(): - url = ti.task.get_extra_links(ti, "External DAG") + url = ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) assert f"/dags/{expected_external_dag_id}/grid" in url diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 23cf5033f42a6..a2acc21526127 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -58,6 +58,7 @@ ) from airflow.hooks.base import BaseHook from airflow.models.baseoperator import BaseOperator +from airflow.models.baseoperatorlink import XComOperatorLink from airflow.models.connection import Connection from airflow.models.dag import DAG from airflow.models.dagbag import DagBag @@ -86,7 +87,6 @@ from airflow.utils.task_group import TaskGroup from airflow.utils.xcom import XCOM_RETURN_KEY -from tests_common.test_utils.compat import BaseOperatorLink from tests_common.test_utils.mock_operators import ( AirflowLink2, CustomOperator, @@ -208,7 +208,7 @@ "max_retry_delay": 600.0, "downstream_task_ids": [], "_is_empty": False, - "_operator_extra_links": [{"tests_common.test_utils.mock_operators.CustomOpLink": {}}], + "_operator_extra_links": {"Google Custom": "_link_CustomOpLink"}, "ui_color": "#fff", "ui_fgcolor": "#000", "template_ext": [], @@ -1101,16 +1101,13 @@ def test_task_params_roundtrip(self, val, expected_val): [ pytest.param( "true", - [{"tests_common.test_utils.mock_operators.CustomOpLink": {}}], + {"Google Custom": 
"_link_CustomOpLink"}, {"Google Custom": "http://google.com/custom_base_link?search=true"}, id="non-indexed-link", ), pytest.param( ["echo", "true"], - [ - {"tests_common.test_utils.mock_operators.CustomBaseIndexOpLink": {"index": 0}}, - {"tests_common.test_utils.mock_operators.CustomBaseIndexOpLink": {"index": 1}}, - ], + {"BigQuery Console #1": "bigquery_1", "BigQuery Console #2": "bigquery_2"}, { "BigQuery Console #1": "https://console.cloud.google.com/bigquery?j=echo", "BigQuery Console #2": "https://console.cloud.google.com/bigquery?j=true", @@ -1168,49 +1165,26 @@ def test_extra_serialized_field_and_operator_links( run_id=dr.run_id, ) + c = 0 # Test Deserialized inbuilt link for name, expected in links.items(): + # staging the part where a task at runtime pushes xcom for extra links + XCom.set( + key=simple_task.operator_extra_links[c].xcom_key, + value=expected, + task_id=simple_task.task_id, + dag_id=simple_task.dag_id, + run_id=dr.run_id, + ) + link = simple_task.get_extra_links(ti, name) assert link == expected + c += 1 # Test Deserialized link registered via Airflow Plugin link = simple_task.get_extra_links(ti, GoogleLink.name) assert link == "https://www.google.com" - @pytest.mark.usefixtures("clear_all_logger_handlers") - def test_extra_operator_links_logs_error_for_non_registered_extra_links(self): - """ - Assert OperatorLinks not registered via Plugins and if it is not an inbuilt Operator Link, - it can still deserialize the DAG (does not error) but just logs an error. - - We test NOT using caplog as this is flaky, we check that the task after deserialize - is missing the extra links. 
- """ - - class TaskStateLink(BaseOperatorLink): - """OperatorLink not registered via Plugins nor a built-in OperatorLink""" - - name = "My Link" - - def get_link(self, operator, *, ti_key): - return "https://www.google.com" - - class MyOperator(BaseOperator): - """Just a EmptyOperator using above defined Extra Operator Link""" - - operator_extra_links = [TaskStateLink()] - - def execute(self, context: Context): - pass - - with DAG(dag_id="simple_dag", schedule=None, start_date=datetime(2019, 8, 1)) as dag: - MyOperator(task_id="blah") - - serialized_dag = SerializedDAG.to_dict(dag) - - sdag = SerializedDAG.from_dict(serialized_dag) - assert sdag.task_dict["blah"].operator_extra_links == [] - class ClassWithCustomAttributes: """ Class for testing purpose: allows to create objects with custom attributes in one single statement. @@ -2980,7 +2954,7 @@ def operator_extra_links(self): "_disallow_kwargs_override": False, "_expand_input_attr": "expand_input", "downstream_task_ids": [], - "_operator_extra_links": [{"tests_common.test_utils.mock_operators.AirflowLink2": {}}], + "_operator_extra_links": {"airflow": "_link_AirflowLink2"}, "ui_color": "#fff", "ui_fgcolor": "#000", "template_ext": [], @@ -2994,4 +2968,7 @@ def operator_extra_links(self): "start_from_trigger": False, } deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag[Encoding.VAR]) - assert deserialized_dag.task_dict["task"].operator_extra_links == [AirflowLink2()] + # operator defined links have to be instances of XComOperatorLink + assert deserialized_dag.task_dict["task"].operator_extra_links == [ + XComOperatorLink(name="airflow", xcom_key="_link_AirflowLink2") + ] diff --git a/tests/www/views/test_views_extra_links.py b/tests/www/views/test_views_extra_links.py index fcf119ba207a5..9e49eb7cd622c 100644 --- a/tests/www/views/test_views_extra_links.py +++ b/tests/www/views/test_views_extra_links.py @@ -20,11 +20,13 @@ import json import urllib.parse from unittest import mock +from unittest.mock 
import patch import pytest from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG +from airflow.serialization.serialized_objects import SerializedBaseOperator from airflow.utils import timezone from airflow.utils.state import DagRunState from airflow.utils.types import DagRunTriggeredByType, DagRunType @@ -141,73 +143,83 @@ def _reset_task_instances(): def test_extra_links_works(dag_run, task_1, viewer_client, session): - response = viewer_client.get( - f"{ENDPOINT}?dag_id={task_1.dag_id}&task_id={task_1.task_id}" - f"&logical_date={STR_DEFAULT_DATE}&link_name=foo-bar", - follow_redirects=True, - ) + expected_url = "http://www.example.com/some_dummy_task/foo-bar/manual__2017-01-01T00:00:00+00:00" + + with patch.object(task_1, "get_extra_links", return_value=expected_url, create=True): + response = viewer_client.get( + f"{ENDPOINT}?dag_id={task_1.dag_id}&task_id={task_1.task_id}" + f"&logical_date={STR_DEFAULT_DATE}&link_name=foo-bar", + follow_redirects=True, + ) - assert response.status_code == 200 - assert json.loads(response.data.decode()) == { - "url": "http://www.example.com/some_dummy_task/foo-bar/manual__2017-01-01T00:00:00+00:00", - "error": None, - } + assert response.status_code == 200 + assert json.loads(response.data.decode()) == { + "url": expected_url, + "error": None, + } def test_global_extra_links_works(dag_run, task_1, viewer_client, session): - response = viewer_client.get( - f"{ENDPOINT}?dag_id={dag_run.dag_id}&task_id={task_1.task_id}" - f"&logical_date={STR_DEFAULT_DATE}&link_name=github", - follow_redirects=True, - ) + expected_url = "https://github.com/apache/airflow" - assert response.status_code == 200 - assert json.loads(response.data.decode()) == { - "url": "https://github.com/apache/airflow", - "error": None, - } + with patch.object(task_1, "get_extra_links", return_value=expected_url, create=True): + response = viewer_client.get( + f"{ENDPOINT}?dag_id={dag_run.dag_id}&task_id={task_1.task_id}" + 
f"&logical_date={STR_DEFAULT_DATE}&link_name=github", + follow_redirects=True, + ) + + assert response.status_code == 200 + assert json.loads(response.data.decode()) == { + "url": "https://github.com/apache/airflow", + "error": None, + } def test_operator_extra_link_override_global_extra_link(dag_run, task_1, viewer_client): - response = viewer_client.get( - f"{ENDPOINT}?dag_id={task_1.dag_id}&task_id={task_1.task_id}" - f"&logical_date={STR_DEFAULT_DATE}&link_name=airflow", - follow_redirects=True, - ) + expected_url = "https://github.com/apache/airflow" + with patch.object(task_1, "get_extra_links", return_value=expected_url, create=True): + response = viewer_client.get( + f"{ENDPOINT}?dag_id={task_1.dag_id}&task_id={task_1.task_id}" + f"&logical_date={STR_DEFAULT_DATE}&link_name=airflow", + follow_redirects=True, + ) - assert response.status_code == 200 - response_str = response.data - if isinstance(response.data, bytes): - response_str = response_str.decode() - assert json.loads(response_str) == {"url": "https://airflow.apache.org", "error": None} + assert response.status_code == 200 + response_str = response.data + if isinstance(response.data, bytes): + response_str = response_str.decode() + assert json.loads(response_str) == {"url": expected_url, "error": None} def test_extra_links_error_raised(dag_run, task_1, viewer_client): - response = viewer_client.get( - f"{ENDPOINT}?dag_id={task_1.dag_id}&task_id={task_1.task_id}" - f"&logical_date={STR_DEFAULT_DATE}&link_name=raise_error", - follow_redirects=True, - ) + with patch.object(task_1, "get_extra_links", side_effect=ValueError("This is an error"), create=True): + response = viewer_client.get( + f"{ENDPOINT}?dag_id={task_1.dag_id}&task_id={task_1.task_id}" + f"&logical_date={STR_DEFAULT_DATE}&link_name=raise_error", + follow_redirects=True, + ) - assert response.status_code == 404 - response_str = response.data - if isinstance(response.data, bytes): - response_str = response_str.decode() - assert 
json.loads(response_str) == {"url": None, "error": "This is an error"} + assert response.status_code == 404 + response_str = response.data + if isinstance(response.data, bytes): + response_str = response_str.decode() + assert json.loads(response_str) == {"url": None, "error": "This is an error"} def test_extra_links_no_response(dag_run, task_1, viewer_client): - response = viewer_client.get( - f"{ENDPOINT}?dag_id={task_1.dag_id}&task_id={task_1.task_id}" - f"&logical_date={STR_DEFAULT_DATE}&link_name=no_response", - follow_redirects=True, - ) + with patch.object(task_1, "get_extra_links", return_value=None, create=True): + response = viewer_client.get( + f"{ENDPOINT}?dag_id={task_1.dag_id}&task_id={task_1.task_id}" + f"&logical_date={STR_DEFAULT_DATE}&link_name=no_response", + follow_redirects=True, + ) - assert response.status_code == 404 - response_str = response.data - if isinstance(response.data, bytes): - response_str = response_str.decode() - assert json.loads(response_str) == {"url": None, "error": "No URL found for no_response"} + assert response.status_code == 404 + response_str = response.data + if isinstance(response.data, bytes): + response_str = response_str.decode() + assert json.loads(response_str) == {"url": None, "error": "No URL found for no_response"} def test_operator_extra_link_override_plugin(dag_run, task_2, viewer_client): @@ -218,17 +230,18 @@ def test_operator_extra_link_override_plugin(dag_run, task_2, viewer_client): AirflowLink returns 'https://airflow.apache.org/' link AirflowLink2 returns 'https://airflow.apache.org/1.10.5/' link """ - response = viewer_client.get( - f"{ENDPOINT}?dag_id={task_2.dag_id}&task_id={task_2.task_id}" - f"&logical_date={STR_DEFAULT_DATE}&link_name=airflow", - follow_redirects=True, - ) - - assert response.status_code == 200 - response_str = response.data - if isinstance(response.data, bytes): - response_str = response_str.decode() - assert json.loads(response_str) == {"url": 
"https://airflow.apache.org/1.10.5/", "error": None} + endpoint = f"{ENDPOINT}?dag_id={task_2.dag_id}&task_id={task_2.task_id}&logical_date={STR_DEFAULT_DATE}&link_name=airflow" + expected_url = get_extra_links_for_task_from_endpoint(task_2, endpoint) + with patch.object(task_2, "get_extra_links", return_value=expected_url, create=True): + response = viewer_client.get( + endpoint, + follow_redirects=True, + ) + assert response.status_code == 200 + response_str = response.data + if isinstance(response.data, bytes): + response_str = response_str.decode() + assert json.loads(response_str) == {"url": expected_url, "error": None} def test_operator_extra_link_multiple_operators(dag_run, task_2, task_3, viewer_client): @@ -240,39 +253,63 @@ def test_operator_extra_link_multiple_operators(dag_run, task_2, task_3, viewer_ AirflowLink2 returns 'https://airflow.apache.org/1.10.5/' link GoogleLink returns 'https://www.google.com' """ - response = viewer_client.get( - f"{ENDPOINT}?dag_id={task_2.dag_id}&task_id={task_2.task_id}" - f"&logical_date={STR_DEFAULT_DATE}&link_name=airflow", - follow_redirects=True, - ) - assert response.status_code == 200 - response_str = response.data - if isinstance(response.data, bytes): - response_str = response_str.decode() - assert json.loads(response_str) == {"url": "https://airflow.apache.org/1.10.5/", "error": None} + endpoint = f"{ENDPOINT}?dag_id={task_2.dag_id}&task_id={task_2.task_id}&logical_date={STR_DEFAULT_DATE}&link_name=airflow" + expected_url = get_extra_links_for_task_from_endpoint(task_2, endpoint) + with patch.object(task_2, "get_extra_links", return_value=expected_url, create=True): + response = viewer_client.get( + endpoint, + follow_redirects=True, + ) - response = viewer_client.get( - f"{ENDPOINT}?dag_id={task_3.dag_id}&task_id={task_3.task_id}" - f"&logical_date={STR_DEFAULT_DATE}&link_name=airflow", - follow_redirects=True, - ) + assert response.status_code == 200 + response_str = response.data + if 
isinstance(response.data, bytes): + response_str = response_str.decode() + assert json.loads(response_str) == {"url": "https://airflow.apache.org/1.10.5/", "error": None} + + endpoint = f"{ENDPOINT}?dag_id={task_3.dag_id}&task_id={task_3.task_id}&logical_date={STR_DEFAULT_DATE}&link_name=airflow" + expected_url = get_extra_links_for_task_from_endpoint(task_3, endpoint) + with patch.object(task_3, "get_extra_links", return_value=expected_url, create=True): + response = viewer_client.get( + endpoint, + follow_redirects=True, + ) - assert response.status_code == 200 - response_str = response.data - if isinstance(response.data, bytes): - response_str = response_str.decode() - assert json.loads(response_str) == {"url": "https://airflow.apache.org/1.10.5/", "error": None} + assert response.status_code == 200 + response_str = response.data + if isinstance(response.data, bytes): + response_str = response_str.decode() + assert json.loads(response_str) == {"url": "https://airflow.apache.org/1.10.5/", "error": None} # Also check that the other Operator Link defined for this operator exists - response = viewer_client.get( - f"{ENDPOINT}?dag_id={task_3.dag_id}&task_id={task_3.task_id}" - f"&logical_date={STR_DEFAULT_DATE}&link_name=google", - follow_redirects=True, - ) + endpoint = f"{ENDPOINT}?dag_id={task_3.dag_id}&task_id={task_3.task_id}&logical_date={STR_DEFAULT_DATE}&link_name=google" + expected_url = get_extra_links_for_task_from_endpoint(task_3, endpoint) + with patch.object(task_3, "get_extra_links", return_value=expected_url, create=True): + response = viewer_client.get( + endpoint, + follow_redirects=True, + ) + + assert response.status_code == 200 + response_str = response.data + if isinstance(response.data, bytes): + response_str = response_str.decode() + assert json.loads(response_str) == {"url": "https://www.google.com", "error": None} + + +def convert_task_to_deser_task(task): + de = 
SerializedBaseOperator.deserialize_operator(SerializedBaseOperator.serialize_operator(task)) + return de + + +def get_extra_links_for_task_from_endpoint(task, endpoint): + import re + + match = re.search(r"[?&]link_name=([^&]+)", endpoint) + link_name = match.group(1) + de_task = convert_task_to_deser_task(task) - assert response.status_code == 200 - response_str = response.data - if isinstance(response.data, bytes): - response_str = response_str.decode() - assert json.loads(response_str) == {"url": "https://www.google.com", "error": None} + for oe in de_task.operator_extra_links: + if oe.name == link_name: + return oe.get_link(operator=task, ti_key=None) diff --git a/tests_common/pytest_plugin.py b/tests_common/pytest_plugin.py index d4ed234725f00..a7bb4f6a0a68d 100644 --- a/tests_common/pytest_plugin.py +++ b/tests_common/pytest_plugin.py @@ -413,6 +413,7 @@ def pytest_configure(config: pytest.Config) -> None: "external_python_operator: external python operator tests are 'long', we should run them separately", ) config.addinivalue_line("markers", "enable_redact: do not mock redact secret masker") + config.addinivalue_line("markers", "mock_plugin_manager: mark a test to use mock_plugin_manager") os.environ["_AIRFLOW__SKIP_DATABASE_EXECUTOR_COMPATIBILITY_CHECK"] = "1" @@ -1606,6 +1607,18 @@ def _disable_redact(request: pytest.FixtureRequest, mocker): return +@pytest.fixture(autouse=True) +def _mock_plugins(request: pytest.FixtureRequest): + """Mock the plugin manager for tests marked with ``mock_plugin_manager``; otherwise do nothing.""" + if mark := next(request.node.iter_markers("mock_plugin_manager"), None): + from tests_common.test_utils.mock_plugins import mock_plugin_manager + + with mock_plugin_manager(**mark.kwargs): + yield + return + yield + + @pytest.fixture def providers_src_folder() -> Path: import airflow.providers diff --git a/tests_common/test_utils/mock_operators.py b/tests_common/test_utils/mock_operators.py index 81f53abf648ce..bff61f33a27e7 100644 --- 
a/tests_common/test_utils/mock_operators.py +++ b/tests_common/test_utils/mock_operators.py @@ -85,6 +85,10 @@ class CustomBaseIndexOpLink(BaseOperatorLink): def name(self) -> str: return f"BigQuery Console #{self.index + 1}" + @property + def xcom_key(self) -> str: + return f"bigquery_{self.index + 1}" + def get_link(self, operator, *, ti_key): search_queries = XCom.get_one( task_id=ti_key.task_id, dag_id=ti_key.dag_id, run_id=ti_key.run_id, key="search_query"