diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index e9574a359266f..22f32d8405199 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1101,6 +1101,7 @@ muldelete
 Multinamespace
 mutex
 mv
+mwaa
 mypy
 Mysql
 mysql
diff --git a/providers/amazon/docs/operators/mwaa.rst b/providers/amazon/docs/operators/mwaa.rst
new file mode 100644
index 0000000000000..021998b0a10ed
--- /dev/null
+++ b/providers/amazon/docs/operators/mwaa.rst
@@ -0,0 +1,64 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements. See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership. The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License. You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied. See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+==================================================
+Amazon Managed Workflows for Apache Airflow (MWAA)
+==================================================
+
+`Amazon Managed Workflows for Apache Airflow (MWAA) <https://aws.amazon.com/managed-workflows-for-apache-airflow/>`__
+is a managed service for Apache Airflow that lets you use your current, familiar Apache Airflow platform to orchestrate
+your workflows. You gain improved scalability, availability, and security without the operational burden of managing
+underlying infrastructure.
+
+Prerequisite Tasks
+------------------
+
+.. include:: ../_partials/prerequisite_tasks.rst
+
+Generic Parameters
+------------------
+
+.. include:: ../_partials/generic_parameters.rst
+
+Operators
+---------
+
+.. _howto/operator:MwaaTriggerDagRunOperator:
+
+Trigger a DAG run in an Amazon MWAA environment
+===============================================
+
+To trigger a DAG run in an Amazon MWAA environment, you can use the
+:class:`~airflow.providers.amazon.aws.operators.mwaa.MwaaTriggerDagRunOperator`.
+
+Note: Unlike :class:`~airflow.providers.standard.operators.trigger_dagrun.TriggerDagRunOperator`, this operator can
+trigger a DAG in a separate Airflow environment, as long as the environment containing the triggered DAG runs on
+Amazon MWAA.
+
+In the following example, the task ``trigger_dag_run`` triggers a DAG run for the DAG with the ID ``hello_world`` in
+the environment ``MyAirflowEnvironment``.
+
+.. exampleinclude:: /../../providers/amazon/tests/system/amazon/aws/example_mwaa.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_mwaa_trigger_dag_run]
+    :end-before: [END howto_operator_mwaa_trigger_dag_run]
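+
+The operator also accepts the optional run attributes it exposes. The following sketch passes an explicit run ID, a
+logical date, and a ``conf`` payload; the run ID, dates, and ``conf`` values are placeholders, not part of the system
+test above:
+
+.. code-block:: python
+
+    MwaaTriggerDagRunOperator(
+        task_id="trigger_dag_run_with_conf",
+        env_name="MyAirflowEnvironment",
+        trigger_dag_id="hello_world",
+        trigger_run_id="manual__2025-01-01T00:00:00Z",
+        logical_date="2025-01-01T00:00:00Z",
+        conf={"key": "value"},
+        note="Triggered from another Airflow environment",
+    )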
+
+References
+----------
+
+* `AWS boto3 library documentation for MWAA <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/mwaa.html>`__
diff --git a/providers/amazon/provider.yaml b/providers/amazon/provider.yaml
index 3310ec331e2de..b96f9ea8ebf6e 100644
--- a/providers/amazon/provider.yaml
+++ b/providers/amazon/provider.yaml
@@ -392,6 +392,9 @@ operators:
   - integration-name: Amazon Managed Service for Apache Flink
     python-modules:
       - airflow.providers.amazon.aws.operators.kinesis_analytics
+  - integration-name: Amazon Managed Workflows for Apache Airflow (MWAA)
+    python-modules:
+      - airflow.providers.amazon.aws.operators.mwaa
   - integration-name: Amazon Simple Storage Service (S3)
     python-modules:
       - airflow.providers.amazon.aws.operators.s3
@@ -600,6 +603,9 @@ hooks:
   - integration-name: Amazon CloudWatch Logs
     python-modules:
       - airflow.providers.amazon.aws.hooks.logs
+  - integration-name: Amazon Managed Workflows for Apache Airflow (MWAA)
+    python-modules:
+      - airflow.providers.amazon.aws.hooks.mwaa
   - integration-name: Amazon OpenSearch Serverless
     python-modules:
       - airflow.providers.amazon.aws.hooks.opensearch_serverless
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py
new file mode 100644
index 0000000000000..d7f01238e6ab8
--- /dev/null
+++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains AWS MWAA hook."""
+
+from __future__ import annotations
+
+from botocore.exceptions import ClientError
+
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+
+
+class MwaaHook(AwsBaseHook):
+    """
+    Interact with Amazon Managed Workflows for Apache Airflow (MWAA).
+
+    Provide a thin wrapper around :external+boto3:py:class:`boto3.client("mwaa") <MWAA.Client>`.
+
+    Additional arguments (such as ``aws_conn_id``) may be specified and
+    are passed down to the underlying AwsBaseHook.
+
+    .. seealso::
+        - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        kwargs["client_type"] = "mwaa"
+        super().__init__(*args, **kwargs)
+
+    def invoke_rest_api(
+        self,
+        env_name: str,
+        path: str,
+        method: str,
+        body: dict | None = None,
+        query_params: dict | None = None,
+    ) -> dict:
+        """
+        Invoke the REST API on the Airflow webserver with the specified inputs.
+
+        .. seealso::
+            - :external+boto3:py:meth:`MWAA.Client.invoke_rest_api`
+
+        :param env_name: name of the MWAA environment
+        :param path: Apache Airflow REST API endpoint path to be called
+        :param method: HTTP method used for making Airflow REST API calls
+        :param body: Request body for the Apache Airflow REST API call
+        :param query_params: Query parameters to be included in the Apache Airflow REST API call
+        """
+        body = body or {}
+        api_kwargs = {
+            "Name": env_name,
+            "Path": path,
+            "Method": method,
+            # Filter out keys with None values because Airflow REST API doesn't accept requests otherwise
+            "Body": {k: v for k, v in body.items() if v is not None},
+            "QueryParameters": query_params if query_params else {},
+        }
+        try:
+            result = self.conn.invoke_rest_api(**api_kwargs)
+            # ResponseMetadata is removed because it contains data that is either very unlikely to be useful
+            # in XComs and logs, or redundant given the data already included in the response
+            result.pop("ResponseMetadata", None)
+            return result
+        except ClientError as e:
+            to_log = e.response
+            # ResponseMetadata and Error are removed because they contain data that is either very unlikely to
+            # be useful in XComs and logs, or redundant given the data already included in the response
+            to_log.pop("ResponseMetadata", None)
+            to_log.pop("Error", None)
+            self.log.error(to_log)
+            raise e
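For orientation, the hook can also be used on its own to reach any Airflow REST API endpoint exposed by MWAA. A
minimal sketch (not part of this diff; the environment name, DAG ID, and connection ID are placeholders) that lists
DAG runs:

    from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook

    hook = MwaaHook(aws_conn_id="aws_default")  # placeholder connection
    response = hook.invoke_rest_api(
        env_name="MyAirflowEnvironment",
        path="/dags/hello_world/dagRuns",
        method="GET",
        query_params={"limit": 30},
    )
    # The hook strips ResponseMetadata, so the result holds only
    # "RestApiStatusCode" and "RestApiResponse" (see the test fixtures below).
    print(response["RestApiStatusCode"], response["RestApiResponse"])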
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/mwaa.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/mwaa.py
new file mode 100644
index 0000000000000..42f1038f2c5cb
--- /dev/null
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/mwaa.py
@@ -0,0 +1,109 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains AWS MWAA operators."""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+from typing import TYPE_CHECKING
+
+from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
+class MwaaTriggerDagRunOperator(AwsBaseOperator[MwaaHook]):
+    """
+    Trigger a DAG run for a DAG in an Amazon MWAA environment.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:MwaaTriggerDagRunOperator`
+
+    :param env_name: The MWAA environment name (templated)
+    :param trigger_dag_id: The ID of the DAG to be triggered (templated)
+    :param trigger_run_id: The run ID. This, together with trigger_dag_id, forms a unique key. (templated)
+    :param logical_date: The logical date (previously called execution date). This is the time or interval
+        covered by this DAG run, according to the DAG definition. This, together with trigger_dag_id, forms
+        a unique key. (templated)
+    :param data_interval_start: The beginning of the interval the DAG run covers
+    :param data_interval_end: The end of the interval the DAG run covers
+    :param conf: Additional configuration parameters. The value of this field can be set only when creating
+        the object. (templated)
+    :param note: Contains manually entered notes by the user about the DAG run. (templated)
+    """
+
+    aws_hook_class = MwaaHook
+    template_fields: Sequence[str] = aws_template_fields(
+        "env_name",
+        "trigger_dag_id",
+        "trigger_run_id",
+        "logical_date",
+        "data_interval_start",
+        "data_interval_end",
+        "conf",
+        "note",
+    )
+    template_fields_renderers = {"conf": "json"}
+
+    def __init__(
+        self,
+        *,
+        env_name: str,
+        trigger_dag_id: str,
+        trigger_run_id: str | None = None,
+        logical_date: str | None = None,
+        data_interval_start: str | None = None,
+        data_interval_end: str | None = None,
+        conf: dict | None = None,
+        note: str | None = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.env_name = env_name
+        self.trigger_dag_id = trigger_dag_id
+        self.trigger_run_id = trigger_run_id
+        self.logical_date = logical_date
+        self.data_interval_start = data_interval_start
+        self.data_interval_end = data_interval_end
+        self.conf = conf if conf else {}
+        self.note = note
+
+    def execute(self, context: Context) -> dict:
+        """
+        Trigger a DAG run for the DAG in the Amazon MWAA environment.
+
+        :param context: the Context object
+        :return: dict with information about the DAG run.
+            For details of the returned dict, see :py:meth:`botocore.client.MWAA.invoke_rest_api`
+        """
+        return self.hook.invoke_rest_api(
+            env_name=self.env_name,
+            path=f"/dags/{self.trigger_dag_id}/dagRuns",
+            method="POST",
+            body={
+                "dag_run_id": self.trigger_run_id,
+                "logical_date": self.logical_date,
+                "data_interval_start": self.data_interval_start,
+                "data_interval_end": self.data_interval_end,
+                "conf": self.conf,
+                "note": self.note,
+            },
+        )
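Because ``execute`` returns the response dict, Airflow pushes it to XCom by default. A sketch (not part of this diff)
of a downstream task reading the triggered run's ID, assuming the response shape used in the test fixtures below:

    from airflow.decorators import task


    @task
    def report_run(trigger_result: dict) -> str:
        # "RestApiResponse" mirrors the Airflow REST API's DAGRun object,
        # e.g. {"dag_run_id": "manual__...", "state": "queued", ...}
        return trigger_result["RestApiResponse"]["dag_run_id"]


    # Inside a DAG, wired off the operator's output:
    # report_run(trigger_dag_run.output)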
diff --git a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py
index 217f618c8667c..e6e8020185633 100644
--- a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py
+++ b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py
@@ -478,6 +478,10 @@ def get_provider_info():
             "integration-name": "Amazon Managed Service for Apache Flink",
             "python-modules": ["airflow.providers.amazon.aws.operators.kinesis_analytics"],
         },
+        {
+            "integration-name": "Amazon Managed Workflows for Apache Airflow (MWAA)",
+            "python-modules": ["airflow.providers.amazon.aws.operators.mwaa"],
+        },
         {
             "integration-name": "Amazon Simple Storage Service (S3)",
             "python-modules": ["airflow.providers.amazon.aws.operators.s3"],
@@ -747,6 +751,10 @@ def get_provider_info():
             "integration-name": "Amazon CloudWatch Logs",
             "python-modules": ["airflow.providers.amazon.aws.hooks.logs"],
         },
+        {
+            "integration-name": "Amazon Managed Workflows for Apache Airflow (MWAA)",
+            "python-modules": ["airflow.providers.amazon.aws.hooks.mwaa"],
+        },
         {
             "integration-name": "Amazon OpenSearch Serverless",
             "python-modules": ["airflow.providers.amazon.aws.hooks.opensearch_serverless"],
diff --git a/providers/amazon/tests/provider_tests/amazon/aws/hooks/test_mwaa.py b/providers/amazon/tests/provider_tests/amazon/aws/hooks/test_mwaa.py
new file mode 100644
index 0000000000000..5d8dc761c3334
--- /dev/null
+++ b/providers/amazon/tests/provider_tests/amazon/aws/hooks/test_mwaa.py
@@ -0,0 +1,130 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+from botocore.exceptions import ClientError
+from moto import mock_aws
+
+from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
+
+ENV_NAME = "test_env"
+PATH = "/dags/test_dag/dagRuns"
+METHOD = "POST"
+QUERY_PARAMS = {"limit": 30}
+
+
+class TestMwaaHook:
+    def setup_method(self):
+        self.hook = MwaaHook()
+
+        # These example responses are defined here instead of as constants because the hook
+        # mutates responses, which would cause subsequent tests to fail
+        self.example_responses = {
+            "success": {
+                "ResponseMetadata": {
+                    "RequestId": "some ID",
+                    "HTTPStatusCode": 200,
+                    "HTTPHeaders": {"header1": "value1"},
+                    "RetryAttempts": 0,
+                },
+                "RestApiStatusCode": 200,
+                "RestApiResponse": {
+                    "conf": {},
+                    "dag_id": "hello_world",
+                    "dag_run_id": "manual__2025-02-08T00:33:09.457198+00:00",
+                    "data_interval_end": "2025-02-08T00:33:09.457198+00:00",
+                    "data_interval_start": "2025-02-08T00:33:09.457198+00:00",
+                    "execution_date": "2025-02-08T00:33:09.457198+00:00",
+                    "external_trigger": True,
+                    "logical_date": "2025-02-08T00:33:09.457198+00:00",
+                    "run_type": "manual",
+                    "state": "queued",
+                },
+            },
+            "failure": {
+                "Error": {"Message": "", "Code": "RestApiClientException"},
+                "ResponseMetadata": {
+                    "RequestId": "some ID",
+                    "HTTPStatusCode": 400,
+                    "HTTPHeaders": {"header1": "value1"},
+                    "RetryAttempts": 0,
+                },
+                "RestApiStatusCode": 404,
+                "RestApiResponse": {
+                    "detail": "DAG with dag_id: 'hello_world1' not found",
+                    "status": 404,
+                    "title": "DAG not found",
+                    "type": "https://airflow.apache.org/docs/apache-airflow/2.10.3/stable-rest-api-ref.html#section/Errors/NotFound",
+                },
+            },
+        }
+
+    def test_init(self):
+        assert self.hook.client_type == "mwaa"
+
+    @mock_aws
+    def test_get_conn(self):
+        assert self.hook.conn is not None
+
+    @pytest.mark.parametrize(
+        "body",
+        [
+            pytest.param(None, id="no_body"),
+            pytest.param({"conf": {}}, id="non_empty_body"),
+        ],
+    )
+    @mock.patch.object(MwaaHook, "conn")
+    def test_invoke_rest_api_success(self, mock_conn, body) -> None:
+        boto_invoke_mock = mock.MagicMock(return_value=self.example_responses["success"])
+        mock_conn.invoke_rest_api = boto_invoke_mock
+
+        retval = self.hook.invoke_rest_api(ENV_NAME, PATH, METHOD, body, QUERY_PARAMS)
+        kwargs_to_assert = {
+            "Name": ENV_NAME,
+            "Path": PATH,
+            "Method": METHOD,
+            "Body": body if body else {},
+            "QueryParameters": QUERY_PARAMS,
+        }
+        boto_invoke_mock.assert_called_once_with(**kwargs_to_assert)
+        assert retval == {
+            k: v for k, v in self.example_responses["success"].items() if k != "ResponseMetadata"
+        }
"conn") + def test_invoke_rest_api_failure(self, mock_conn) -> None: + error = ClientError( + error_response=self.example_responses["failure"], operation_name="invoke_rest_api" + ) + boto_invoke_mock = mock.MagicMock(side_effect=error) + mock_conn.invoke_rest_api = boto_invoke_mock + mock_log = mock.MagicMock() + self.hook.log.error = mock_log + + with pytest.raises(ClientError) as caught_error: + self.hook.invoke_rest_api(ENV_NAME, PATH, METHOD) + + assert caught_error.value == error + expected_log = { + k: v + for k, v in self.example_responses["failure"].items() + if k != "ResponseMetadata" and k != "Error" + } + mock_log.assert_called_once_with(expected_log) diff --git a/providers/amazon/tests/provider_tests/amazon/aws/operators/test_mwaa.py b/providers/amazon/tests/provider_tests/amazon/aws/operators/test_mwaa.py new file mode 100644 index 0000000000000..27c11a95d4230 --- /dev/null +++ b/providers/amazon/tests/provider_tests/amazon/aws/operators/test_mwaa.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations
+
+from unittest import mock
+
+from airflow.providers.amazon.aws.operators.mwaa import MwaaTriggerDagRunOperator
+from provider_tests.amazon.aws.utils.test_template_fields import validate_template_fields
+
+OP_KWARGS = {
+    "task_id": "test_task",
+    "env_name": "test_env",
+    "trigger_dag_id": "test_dag_id",
+    "trigger_run_id": "test_run_id",
+    "logical_date": "2025-01-01T00:00:01Z",
+    "data_interval_start": "2025-01-02T00:00:01Z",
+    "data_interval_end": "2025-01-03T00:00:01Z",
+    "conf": {"key": "value"},
+    "note": "test note",
+}
+HOOK_RETURN_VALUE = {
+    "ResponseMetadata": {},
+    "RestApiStatusCode": 200,
+    "RestApiResponse": {
+        "dag_run_id": "manual__2025-02-08T00:33:09.457198+00:00",
+        "other_key": "value",
+    },
+}
+
+
+class TestMwaaTriggerDagRunOperator:
+    def test_init(self):
+        op = MwaaTriggerDagRunOperator(**OP_KWARGS)
+        assert op.env_name == OP_KWARGS["env_name"]
+        assert op.trigger_dag_id == OP_KWARGS["trigger_dag_id"]
+        assert op.trigger_run_id == OP_KWARGS["trigger_run_id"]
+        assert op.logical_date == OP_KWARGS["logical_date"]
+        assert op.data_interval_start == OP_KWARGS["data_interval_start"]
+        assert op.data_interval_end == OP_KWARGS["data_interval_end"]
+        assert op.conf == OP_KWARGS["conf"]
+        assert op.note == OP_KWARGS["note"]
+
+    @mock.patch.object(MwaaTriggerDagRunOperator, "hook")
+    def test_execute(self, mock_hook):
+        mock_hook.invoke_rest_api.return_value = HOOK_RETURN_VALUE
+        op = MwaaTriggerDagRunOperator(**OP_KWARGS)
+        op_ret_val = op.execute({})
+
+        mock_hook.invoke_rest_api.assert_called_once_with(
+            env_name=OP_KWARGS["env_name"],
+            path=f"/dags/{OP_KWARGS['trigger_dag_id']}/dagRuns",
+            method="POST",
+            body={
+                "dag_run_id": OP_KWARGS["trigger_run_id"],
+                "logical_date": OP_KWARGS["logical_date"],
+                "data_interval_start": OP_KWARGS["data_interval_start"],
+                "data_interval_end": OP_KWARGS["data_interval_end"],
+                "conf": OP_KWARGS["conf"],
+                "note": OP_KWARGS["note"],
+            },
+        )
+        assert op_ret_val == HOOK_RETURN_VALUE
+
+    def test_template_fields(self):
+        operator = MwaaTriggerDagRunOperator(**OP_KWARGS)
+        validate_template_fields(operator)
diff --git a/providers/amazon/tests/system/amazon/aws/example_mwaa.py b/providers/amazon/tests/system/amazon/aws/example_mwaa.py
new file mode 100644
index 0000000000000..5f178019f64e1
--- /dev/null
+++ b/providers/amazon/tests/system/amazon/aws/example_mwaa.py
@@ -0,0 +1,87 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from datetime import datetime
+
+from providers.amazon.tests.system.amazon.aws.utils import SystemTestContextBuilder
+
+from airflow.models.baseoperator import chain
+from airflow.models.dag import DAG
+from airflow.providers.amazon.aws.operators.mwaa import MwaaTriggerDagRunOperator
+
+DAG_ID = "example_mwaa"
+
+# Externally fetched variables:
+EXISTING_ENVIRONMENT_NAME_KEY = "ENVIRONMENT_NAME"
+EXISTING_DAG_ID_KEY = "DAG_ID"
+
+
+sys_test_context_task = (
+    SystemTestContextBuilder()
+    # NOTE: Creating a functional MWAA environment is time-consuming and requires
+    # manually creating and configuring an S3 bucket for DAG storage and a VPC with
+    # private subnets, which is out of scope for this demo. To simplify this demo and
+    # make it run in a reasonable time, an existing MWAA environment already
+    # containing a DAG is required.
+    # Here's a quick start guide to create an MWAA environment using AWS CloudFormation:
+    # https://docs.aws.amazon.com/mwaa/latest/userguide/quick-start.html
+    # If creating the environment using the AWS console, make sure to have a VPC with
+    # at least 1 private subnet to be able to select the VPC while going through the
+    # environment creation steps in the console wizard.
+    # Make sure to set the environment variables with appropriate values
+    .add_variable(EXISTING_ENVIRONMENT_NAME_KEY)
+    .add_variable(EXISTING_DAG_ID_KEY)
+    .build()
+)
+
+with DAG(
+    dag_id=DAG_ID,
+    schedule="@once",
+    start_date=datetime(2021, 1, 1),
+    tags=["example"],
+    catchup=False,
+) as dag:
+    test_context = sys_test_context_task()
+    env_name = test_context[EXISTING_ENVIRONMENT_NAME_KEY]
+    trigger_dag_id = test_context[EXISTING_DAG_ID_KEY]
+
+    # [START howto_operator_mwaa_trigger_dag_run]
+    trigger_dag_run = MwaaTriggerDagRunOperator(
+        task_id="trigger_dag_run",
+        env_name=env_name,
+        trigger_dag_id=trigger_dag_id,
+    )
+    # [END howto_operator_mwaa_trigger_dag_run]
+
+    chain(
+        # TEST SETUP
+        test_context,
+        # TEST BODY
+        trigger_dag_run,
+    )
+
+    from tests_common.test_utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "tearDown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+from tests_common.test_utils.system_tests import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
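To run this system test locally, the two declared variables have to be supplied before invoking pytest. A sketch of
one way to do that, assuming (per the "Make sure to set the environment variables" comment above) that
SystemTestContextBuilder resolves them from the process environment; the environment and DAG names are placeholders
for an existing MWAA environment and a DAG already deployed in it:

    import os
    import subprocess

    # Assumption: the builder reads these keys from the environment.
    os.environ["ENVIRONMENT_NAME"] = "MyAirflowEnvironment"  # existing MWAA environment
    os.environ["DAG_ID"] = "hello_world"  # DAG already deployed in that environment
    subprocess.run(
        ["pytest", "providers/amazon/tests/system/amazon/aws/example_mwaa.py"],
        check=True,
    )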