Skip to content

Commit

Permalink
AIP-72: Improving Operator Links Interface to Prevent User Code Execu…
Browse files Browse the repository at this point in the history
…tion in Webserver (apache#46613)

Operator Links interface changed to not run user code in Airflow Webserver The Operator Extra links, which can be defined either via plugins or custom operators now do not execute any user code in the Airflow Webserver, but instead push the "full" links to XCom backend and the value is again fetched from the XCom backend when viewing task details in grid view.


Example:
```
@attr.s(auto_attribs=True)
class CustomBaseIndexOpLink(BaseOperatorLink):
    """Custom Operator Link for Google BigQuery Console."""
    index: int = attr.ib()
    @Property
    def name(self) -> str:
        return f"BigQuery Console #{self.index + 1}"

    @Property
    def xcom_key(self) -> str:
        return f"bigquery_{self.index + 1}"

    def get_link(self, operator, *, ti_key):
        search_queries = XCom.get_one(
            task_id=ti_key.task_id, dag_id=ti_key.dag_id, run_id=ti_key.run_id, key="search_query"
        )
        if not search_queries:
            return None
        if len(search_queries) < self.index:
            return None
        search_query = search_queries[self.index]
        return f"https://console.cloud.google.com/bigquery?j={search_query}"
```
  • Loading branch information
amoghrajesh authored Feb 13, 2025
1 parent 3b72068 commit fe5a2ea
Show file tree
Hide file tree
Showing 20 changed files with 465 additions and 420 deletions.
62 changes: 0 additions & 62 deletions airflow/models/abstractoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,12 @@
from __future__ import annotations

import datetime
import inspect
from collections.abc import Iterable, Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any, Callable

from sqlalchemy import select

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.sdk.definitions._internal.abstractoperator import (
AbstractOperator as TaskSDKAbstractOperator,
NotMapped as NotMapped, # Re-export this for compat
Expand All @@ -42,7 +39,6 @@
if TYPE_CHECKING:
from sqlalchemy.orm import Session

from airflow.models.baseoperatorlink import BaseOperatorLink
from airflow.models.dag import DAG as SchedulerDAG
from airflow.models.taskinstance import TaskInstance
from airflow.sdk.definitions.baseoperator import BaseOperator
Expand Down Expand Up @@ -157,64 +153,6 @@ def priority_weight_total(self) -> int:
)
)

@cached_property
def operator_extra_link_dict(self) -> dict[str, Any]:
"""Returns dictionary of all extra links for the operator."""
op_extra_links_from_plugin: dict[str, Any] = {}
from airflow import plugins_manager

plugins_manager.initialize_extra_operators_links_plugins()
if plugins_manager.operator_extra_links is None:
raise AirflowException("Can't load operators")
for ope in plugins_manager.operator_extra_links:
if ope.operators and self.operator_class in ope.operators:
op_extra_links_from_plugin.update({ope.name: ope})

operator_extra_links_all = {link.name: link for link in self.operator_extra_links}
# Extra links defined in Plugins overrides operator links defined in operator
operator_extra_links_all.update(op_extra_links_from_plugin)

return operator_extra_links_all

@cached_property
def global_operator_extra_link_dict(self) -> dict[str, Any]:
"""Returns dictionary of all global extra links."""
from airflow import plugins_manager

plugins_manager.initialize_extra_operators_links_plugins()
if plugins_manager.global_operator_extra_links is None:
raise AirflowException("Can't load operators")
return {link.name: link for link in plugins_manager.global_operator_extra_links}

@cached_property
def extra_links(self) -> list[str]:
return sorted(set(self.operator_extra_link_dict).union(self.global_operator_extra_link_dict))

def get_extra_links(self, ti: TaskInstance, link_name: str) -> str | None:
"""
For an operator, gets the URLs that the ``extra_links`` entry points to.
:meta private:
:raise ValueError: The error message of a ValueError will be passed on through to
the fronted to show up as a tooltip on the disabled link.
:param ti: The TaskInstance for the URL being searched for.
:param link_name: The name of the link we're looking for the URL for. Should be
one of the options specified in ``extra_links``.
"""
link: BaseOperatorLink | None = self.operator_extra_link_dict.get(link_name)
if not link:
link = self.global_operator_extra_link_dict.get(link_name)
if not link:
return None

parameters = inspect.signature(link.get_link).parameters
old_signature = all(name != "ti_key" for name, p in parameters.items() if p.kind != p.VAR_KEYWORD)

if old_signature:
return link.get_link(self.unmap(None), ti.dag_run.logical_date) # type: ignore[misc]
return link.get_link(self.unmap(None), ti_key=ti.key)

def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence[TaskInstance], int]:
"""
Create the mapped task instances for mapped task.
Expand Down
54 changes: 52 additions & 2 deletions airflow/models/baseoperatorlink.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,53 @@
from abc import ABCMeta, abstractmethod
from typing import TYPE_CHECKING, ClassVar

import attr
import attrs

from airflow.models.xcom import BaseXCom
from airflow.utils.log.logging_mixin import LoggingMixin

if TYPE_CHECKING:
from airflow.models.baseoperator import BaseOperator
from airflow.models.taskinstancekey import TaskInstanceKey


@attr.s(auto_attribs=True)
@attrs.define()
class XComOperatorLink(LoggingMixin):
"""A generic operator link class that can retrieve link only using XCOMs. Used while deserializing operators."""

name: str
xcom_key: str

def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str:
"""
Retrieve the link from the XComs.
:param operator: The Airflow operator object this link is associated to.
:param ti_key: TaskInstance ID to return link for.
:return: link to external system, but by pulling it from XComs
"""
self.log.info(
"Attempting to retrieve link from XComs with key: %s for task id: %s", self.xcom_key, ti_key
)
value = BaseXCom.get_one(
key=self.xcom_key,
run_id=ti_key.run_id,
dag_id=ti_key.dag_id,
task_id=ti_key.task_id,
map_index=ti_key.map_index,
)
if not value:
self.log.debug(
"No link with name: %s present in XCom as key: %s, returning empty link",
self.name,
self.xcom_key,
)
return ""
# Stripping is a temporary workaround till https://github.com/apache/airflow/issues/46513 is handled.
return value.strip('"')


@attrs.define()
class BaseOperatorLink(metaclass=ABCMeta):
"""Abstract base class that defines how we get an operator link."""

Expand All @@ -44,6 +83,17 @@ class BaseOperatorLink(metaclass=ABCMeta):
def name(self) -> str:
"""Name of the link. This will be the button name on the task UI."""

@property
def xcom_key(self) -> str:
"""
XCom key with while the whole "link" for this operator link is stored.
On retrieving with this key, the entire link is returned.
Defaults to `_link_<class name>` if not provided.
"""
return f"_link_{self.__class__.__name__}"

@abstractmethod
def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str:
"""
Expand Down
137 changes: 77 additions & 60 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,14 @@
from airflow.callbacks.callback_requests import DagCallbackRequest, TaskCallbackRequest
from airflow.exceptions import AirflowException, SerializationError, TaskDeferred
from airflow.models.baseoperator import BaseOperator
from airflow.models.baseoperatorlink import BaseOperatorLink, XComOperatorLink
from airflow.models.connection import Connection
from airflow.models.dag import DAG, _get_model_data_interval
from airflow.models.expandinput import (
EXPAND_INPUT_EMPTY,
create_expand_input,
)
from airflow.models.taskinstance import SimpleTaskInstance
from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.models.xcom_arg import SchedulerXComArg, deserialize_xcom_arg
from airflow.providers_manager import ProvidersManager
Expand Down Expand Up @@ -96,7 +97,6 @@
from inspect import Parameter

from airflow.models import DagRun
from airflow.models.baseoperatorlink import BaseOperatorLink
from airflow.models.expandinput import ExpandInput
from airflow.sdk.definitions._internal.node import DAGNode
from airflow.sdk.types import Operator
Expand Down Expand Up @@ -1167,6 +1167,58 @@ def __init__(self, *args, **kwargs):
self.template_fields = BaseOperator.template_fields
self.operator_extra_links = BaseOperator.operator_extra_links

@cached_property
def operator_extra_link_dict(self) -> dict[str, BaseOperatorLink]:
"""Returns dictionary of all extra links for the operator."""
op_extra_links_from_plugin: dict[str, Any] = {}
from airflow import plugins_manager

plugins_manager.initialize_extra_operators_links_plugins()
if plugins_manager.operator_extra_links is None:
raise AirflowException("Can't load operators")
for ope in plugins_manager.operator_extra_links:
if ope.operators and self.operator_class in ope.operators:
op_extra_links_from_plugin.update({ope.name: ope})

operator_extra_links_all = {link.name: link for link in self.operator_extra_links}
# Extra links defined in Plugins overrides operator links defined in operator
operator_extra_links_all.update(op_extra_links_from_plugin)

return operator_extra_links_all

@cached_property
def global_operator_extra_link_dict(self) -> dict[str, Any]:
"""Returns dictionary of all global extra links."""
from airflow import plugins_manager

plugins_manager.initialize_extra_operators_links_plugins()
if plugins_manager.global_operator_extra_links is None:
raise AirflowException("Can't load operators")
return {link.name: link for link in plugins_manager.global_operator_extra_links}

@cached_property
def extra_links(self) -> list[str]:
return sorted(set(self.operator_extra_link_dict).union(self.global_operator_extra_link_dict))

def get_extra_links(self, ti: TaskInstance, link_name: str) -> str | None:
"""
For an operator, gets the URLs that the ``extra_links`` entry points to.
:meta private:
:raise ValueError: The error message of a ValueError will be passed on through to
the fronted to show up as a tooltip on the disabled link.
:param ti: The TaskInstance for the URL being searched for.
:param link_name: The name of the link we're looking for the URL for. Should be
one of the options specified in ``extra_links``.
"""
link = self.operator_extra_link_dict.get(link_name)
if not link:
link = self.global_operator_extra_link_dict.get(link_name)
if not link:
return None
return link.get_link(self.unmap(None), ti_key=ti.key)

@property
def task_type(self) -> str:
# Overwrites task_type of BaseOperator to use _task_type instead of
Expand Down Expand Up @@ -1504,7 +1556,9 @@ def _is_excluded(cls, var: Any, attrname: str, op: DAGNode):
return super()._is_excluded(var, attrname, op)

@classmethod
def _deserialize_operator_extra_links(cls, encoded_op_links: list) -> dict[str, BaseOperatorLink]:
def _deserialize_operator_extra_links(
cls, encoded_op_links: dict[str, str]
) -> dict[str, XComOperatorLink]:
"""
Deserialize Operator Links if the Classes are registered in Airflow Plugins.
Expand All @@ -1521,77 +1575,40 @@ def _deserialize_operator_extra_links(cls, encoded_op_links: list) -> dict[str,
raise AirflowException("Can't load plugins")
op_predefined_extra_links = {}

for _operator_links_source in encoded_op_links:
# Get the key, value pair as Tuple where key is OperatorLink ClassName
# and value is the dictionary containing the arguments passed to the OperatorLink
#
# Example of a single iteration:
#
# _operator_links_source =
# {
# 'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink': {
# 'index': 0
# }
# },
#
# list(_operator_links_source.items()) =
# [
# (
# 'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink',
# {'index': 0}
# )
# ]
for name, xcom_key in encoded_op_links.items():
# Get the name and xcom_key of the encoded operator and use it to create a XComOperatorLink object
# during deserialization.
#
# list(_operator_links_source.items())[0] =
# (
# 'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink',
# {
# 'index': 0
# }
# )

_operator_link_class_path, data = next(iter(_operator_links_source.items()))
if _operator_link_class_path in get_operator_extra_links():
single_op_link_class = import_string(_operator_link_class_path)
elif _operator_link_class_path in plugins_manager.registered_operator_link_classes:
single_op_link_class = plugins_manager.registered_operator_link_classes[
_operator_link_class_path
]
else:
log.error("Operator Link class %r not registered", _operator_link_class_path)
return {}

op_link_parameters = {param: cls.deserialize(value) for param, value in data.items()}
op_predefined_extra_link: BaseOperatorLink = single_op_link_class(**op_link_parameters)

# Example:
# enc_operator['_operator_extra_links'] =
# {
# 'airflow': 'airflow_link_key',
# 'foo-bar': 'link-key',
# 'no_response': 'key',
# 'raise_error': 'key'
# }

op_predefined_extra_link = XComOperatorLink(name=name, xcom_key=xcom_key)
op_predefined_extra_links.update({op_predefined_extra_link.name: op_predefined_extra_link})

return op_predefined_extra_links

@classmethod
def _serialize_operator_extra_links(cls, operator_extra_links: Iterable[BaseOperatorLink]):
def _serialize_operator_extra_links(
cls, operator_extra_links: Iterable[BaseOperatorLink]
) -> dict[str, str]:
"""
Serialize Operator Links.
Store the import path of the OperatorLink and the arguments passed to it.
Store the "name" of the link mapped with the xcom_key which can be later used to retrieve this
operator extra link from XComs.
For example:
``[{'airflow.providers.google.cloud.links.bigquery.BigQueryDatasetLink': {}}]``
``{'link-name-1': 'xcom-key-1'}``
:param operator_extra_links: Operator Link
:return: Serialized Operator Link
"""
serialize_operator_extra_links = []
for operator_extra_link in operator_extra_links:
op_link_arguments = {
param: cls.serialize(value) for param, value in attrs.asdict(operator_extra_link).items()
}

module_path = (
f"{operator_extra_link.__class__.__module__}.{operator_extra_link.__class__.__name__}"
)
serialize_operator_extra_links.append({module_path: op_link_arguments})

return serialize_operator_extra_links
return {link.name: link.xcom_key for link in operator_extra_links}

@classmethod
def serialize(cls, var: Any, *, strict: bool = False) -> Any:
Expand Down
1 change: 1 addition & 0 deletions newsfragments/46613.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Operator Links interface changed to not run user code in Airflow Webserver The Operator Extra links, which can be defined either via plugins or custom operators now do not execute any user code in the Airflow Webserver, but instead push the "full" links to XCom backend and the value is again fetched from the XCom backend when viewing task details in grid view.
Original file line number Diff line number Diff line change
Expand Up @@ -194,14 +194,7 @@ def assert_extra_link_url(
)

error_msg = f"{self.full_qualname!r} should be preserved after execution"
assert ti.task.get_extra_links(ti, self.link_class.name) == expected_url, error_msg

serialized_dag = self.dag_maker.get_serialized_data()
deserialized_dag = SerializedDAG.from_dict(serialized_dag)
deserialized_task = deserialized_dag.task_dict[self.task_id]

error_msg = f"{self.full_qualname!r} should be preserved in deserialized tasks after execution"
assert deserialized_task.get_extra_links(ti, self.link_class.name) == expected_url, error_msg
assert task.operator_extra_links[0].get_link(operator=task, ti_key=ti.key) == expected_url, error_msg

def test_link_serialize(self):
"""Test: Operator links should exist for serialized DAG."""
Expand All @@ -223,7 +216,7 @@ def test_empty_xcom(self):
deserialized_task = deserialized_dag.task_dict[self.task_id]

assert (
ti.task.get_extra_links(ti, self.link_class.name) == ""
ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key) == ""
), "Operator link should only be added if job id is available in XCom"

assert (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ def test_run_job_operator_link(self, conn_id, account_id, create_task_instance_o

ti.xcom_push(key="job_run_url", value=_run_response["data"]["href"])

url = ti.task.get_extra_links(ti, "Monitor Job Run")
url = ti.task.operator_extra_links[0].get_link(operator=ti.task, ti_key=ti.key)

assert url == (
EXPECTED_JOB_RUN_OP_EXTRA_LINK.format(
Expand Down
Loading

0 comments on commit fe5a2ea

Please sign in to comment.