From 9689cf5bb9d53be2238456fc138e7bf7f5e62e33 Mon Sep 17 00:00:00 2001 From: Josix Date: Fri, 7 Feb 2025 11:23:32 +0800 Subject: [PATCH] refactor(utils/decorators): rewrite remove task decorator to use cst (#43383) * refactor(utils/decorators): rewrite remove task decorator to use ast * Update airflow/utils/decorators.py Co-authored-by: Wei Lee * Update airflow/utils/decorators.py Co-authored-by: Wei Lee * Update airflow/utils/decorators.py Co-authored-by: Wei Lee * fixup! refactor(utils/decorators): rewrite remove task decorator to use ast * fixup! refactor(utils/decorators): rewrite remove task decorator to use ast * Update providers/standard/tests/provider_tests/standard/utils/test_python_virtualenv.py Co-authored-by: Wei Lee * Update airflow/utils/decorators.py Co-authored-by: Wei Lee * fixup! refactor(utils/decorators): rewrite remove task decorator to use ast * fixup! refactor(utils/decorators): rewrite remove task decorator to use ast --------- Co-authored-by: Wei Lee --- airflow/utils/decorators.py | 70 +++++++++---------- hatch_build.py | 1 + .../standard/utils/test_python_virtualenv.py | 46 ++++++------ ...preexisting_python_virtualenv_decorator.py | 43 +++++++----- 4 files changed, 87 insertions(+), 73 deletions(-) diff --git a/airflow/utils/decorators.py b/airflow/utils/decorators.py index 78044e4e35761..69475dda84349 100644 --- a/airflow/utils/decorators.py +++ b/airflow/utils/decorators.py @@ -18,54 +18,52 @@ from __future__ import annotations import sys -from collections import deque from typing import Callable, TypeVar +import libcst as cst + T = TypeVar("T", bound=Callable) +class _TaskDecoratorRemover(cst.CSTTransformer): + def __init__(self, task_decorator_name: str) -> None: + self.decorators_to_remove: set[str] = { + "setup", + "teardown", + "task.skip_if", + "task.run_if", + task_decorator_name.strip("@"), + } + + def _is_task_decorator(self, decorator_node: cst.Decorator) -> bool: + decorator_expr = decorator_node.decorator + if isinstance(decorator_expr, cst.Name): + return decorator_expr.value in self.decorators_to_remove + elif isinstance(decorator_expr, cst.Attribute) and isinstance(decorator_expr.value, cst.Name): + return f"{decorator_expr.value.value}.{decorator_expr.attr.value}" in self.decorators_to_remove + elif isinstance(decorator_expr, cst.Call): + return self._is_task_decorator(cst.Decorator(decorator=decorator_expr.func)) + return False + + def leave_FunctionDef( + self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef + ) -> cst.FunctionDef: + new_decorators = [dec for dec in updated_node.decorators if not self._is_task_decorator(dec)] + if len(new_decorators) == len(updated_node.decorators): + return updated_node + return updated_node.with_changes(decorators=new_decorators) + + def remove_task_decorator(python_source: str, task_decorator_name: str) -> str: """ Remove @task or similar decorators as well as @setup and @teardown. :param python_source: The python source code :param task_decorator_name: the decorator name - - TODO: Python 3.9+: Rewrite this to use ast.parse and ast.unparse """ - - def _remove_task_decorator(py_source, decorator_name): - # if no line starts with @decorator_name, we can early exit - for line in py_source.split("\n"): - if line.startswith(decorator_name): - break - else: - return python_source - split = python_source.split(decorator_name, 1) - before_decorator, after_decorator = split[0], split[1] - if after_decorator[0] == "(": - after_decorator = _balance_parens(after_decorator) - if after_decorator[0] == "\n": - after_decorator = after_decorator[1:] - return before_decorator + after_decorator - - decorators = ["@setup", "@teardown", "@task.skip_if", "@task.run_if", task_decorator_name] - for decorator in decorators: - python_source = _remove_task_decorator(python_source, decorator) - return python_source - - -def _balance_parens(after_decorator): - num_paren = 1 - after_decorator = deque(after_decorator) - after_decorator.popleft() - while num_paren: - current = after_decorator.popleft() - if current == "(": - num_paren = num_paren + 1 - elif current == ")": - num_paren = num_paren - 1 - return "".join(after_decorator) + source_tree = cst.parse_module(python_source) + modified_tree = source_tree.visit(_TaskDecoratorRemover(task_decorator_name)) + return modified_tree.code class _autostacklevel_warn: diff --git a/hatch_build.py b/hatch_build.py index 938ffd001c4d0..1a938e5f8b31b 100644 --- a/hatch_build.py +++ b/hatch_build.py @@ -391,6 +391,7 @@ "jinja2>=3.0.0", "jsonschema>=4.18.0", "lazy-object-proxy>=1.2.0", + "libcst >=1.1.0", "linkify-it-py>=2.0.0", "lockfile>=0.12.2", "markdown-it-py>=2.1.0", diff --git a/providers/standard/tests/provider_tests/standard/utils/test_python_virtualenv.py b/providers/standard/tests/provider_tests/standard/utils/test_python_virtualenv.py index d1a7eef94ec26..6e571962befbc 100644 --- a/providers/standard/tests/provider_tests/standard/utils/test_python_virtualenv.py +++ b/providers/standard/tests/provider_tests/standard/utils/test_python_virtualenv.py @@ -18,6 +18,7 @@ from __future__ import annotations from pathlib import Path +from textwrap import dedent from unittest import mock import pytest @@ -191,26 +192,29 @@ def test_should_create_virtualenv_with_extra_packages_uv(self, mock_execute_in_s ["uv", "pip", "install", "--python", "/VENV/bin/python", "apache-beam[gcp]"] ) - def test_remove_task_decorator(self): - py_source = '@task.virtualenv(serializer="dill")\ndef f():\nimport funcsigs' - res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv") - assert res == "def f():\nimport funcsigs" - - def test_remove_decorator_no_parens(self): - py_source = "@task.virtualenv\ndef f():\nimport funcsigs" - res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv") - assert res == "def f():\nimport funcsigs" - - def test_remove_decorator_including_comment(self): - py_source = "@task.virtualenv\ndef f():\n# @task.virtualenv\nimport funcsigs" - res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv") - assert res == "def f():\n# @task.virtualenv\nimport funcsigs" - - def test_remove_decorator_nested(self): - py_source = "@foo\n@task.virtualenv\n@bar\ndef f():\nimport funcsigs" - res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv") - assert res == "@foo\n@bar\ndef f():\nimport funcsigs" + @pytest.mark.parametrize( + "decorators, expected_decorators", + [ + (["@task.virtualenv"], []), + (["@task.virtualenv()"], []), + (['@task.virtualenv(serializer="dill")'], []), + (["@foo", "@task.virtualenv", "@bar"], ["@foo", "@bar"]), + (["@foo", "@task.virtualenv()", "@bar"], ["@foo", "@bar"]), + ], + ids=["without_parens", "parens", "with_args", "nested_without_parens", "nested_with_parens"], + ) + def test_remove_task_decorator(self, decorators: list[str], expected_decorators: list[str]): + concated_decorators = "\n".join(decorators) + expected_decorator = "\n".join(expected_decorators) + SCRIPT = dedent( + """ + def f(): + # @task.virtualenv + import funcsigs + """ + ) + py_source = concated_decorators + SCRIPT + expected_source = expected_decorator + SCRIPT if expected_decorator else SCRIPT.lstrip() - py_source = "@foo\n@task.virtualenv()\n@bar\ndef f():\nimport funcsigs" res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv") - assert res == "@foo\n@bar\ndef f():\nimport funcsigs" + assert res == expected_source diff --git a/tests/utils/test_preexisting_python_virtualenv_decorator.py b/tests/utils/test_preexisting_python_virtualenv_decorator.py index 11d80e348ea81..a87398ee745b2 100644 --- a/tests/utils/test_preexisting_python_virtualenv_decorator.py +++ b/tests/utils/test_preexisting_python_virtualenv_decorator.py @@ -17,25 +17,36 @@ # under the License. from __future__ import annotations -from airflow.utils.decorators import remove_task_decorator +from textwrap import dedent +import pytest -class TestExternalPythonDecorator: - def test_remove_task_decorator(self): - py_source = '@task.external_python(serializer="dill")\ndef f():\nimport funcsigs' - res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python") - assert res == "def f():\nimport funcsigs" +from airflow.utils.decorators import remove_task_decorator - def test_remove_decorator_no_parens(self): - py_source = "@task.external_python\ndef f():\nimport funcsigs" - res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python") - assert res == "def f():\nimport funcsigs" - def test_remove_decorator_nested(self): - py_source = "@foo\n@task.external_python\n@bar\ndef f():\nimport funcsigs" - res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python") - assert res == "@foo\n@bar\ndef f():\nimport funcsigs" +class TestExternalPythonDecorator: + @pytest.mark.parametrize( + "decorators, expected_decorators", + [ + (["@task.external_python"], []), + (["@task.external_python()"], []), + (['@task.external_python(serializer="dill")'], []), + (["@foo", "@task.external_python", "@bar"], ["@foo", "@bar"]), + (["@foo", "@task.external_python()", "@bar"], ["@foo", "@bar"]), + ], + ids=["without_parens", "parens", "with_args", "nested_without_parens", "nested_with_parens"], + ) + def test_remove_task_decorator(self, decorators: list[str], expected_decorators: list[str]): + concated_decorators = "\n".join(decorators) + expected_decorator = "\n".join(expected_decorators) + SCRIPT = dedent( + """ + def f(): + import funcsigs + """ + ) + py_source = concated_decorators + SCRIPT + expected_source = expected_decorator + SCRIPT if expected_decorator else SCRIPT.lstrip() - py_source = "@foo\n@task.external_python()\n@bar\ndef f():\nimport funcsigs" res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python") - assert res == "@foo\n@bar\ndef f():\nimport funcsigs" + assert res == expected_source