Skip to content

Commit

Permalink
AIP-72: Add a basic test for a task run (#44203)
Browse files Browse the repository at this point in the history
This PR adds a very basic test to parse & run. We will start adding more things here and porting things from core.
  • Loading branch information
kaxil authored Nov 20, 2024
1 parent 175b960 commit 8554001
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 5 deletions.
7 changes: 4 additions & 3 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def parse(what: StartupDetails) -> RuntimeTaskInstance:
class CommsDecoder:
"""Handle communication between the task in this process and the supervisor parent process."""

input: TextIO = sys.stdin
input: TextIO

request_socket: FileIO = attrs.field(init=False, default=None)

Expand Down Expand Up @@ -164,7 +164,8 @@ def run(ti: RuntimeTaskInstance, log: Logger):
except SystemExit:
...
except BaseException:
...
# TODO: Handle TI handle failure
raise


def finalize(log: Logger): ...
Expand All @@ -174,7 +175,7 @@ def main():
# TODO: add an exception here, it causes an oof of a stack trace!

global SUPERVISOR_COMMS
SUPERVISOR_COMMS = CommsDecoder()
SUPERVISOR_COMMS = CommsDecoder(input=sys.stdin)
try:
ti, log = startup()
run(ti, log)
Expand Down
35 changes: 35 additions & 0 deletions task_sdk/tests/dags/super_basic_run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.definitions.dag import dag


class CustomOperator(BaseOperator):
def execute(self, context):
task_id = context["task_instance"].task_id
print(f"Hello World {task_id}!")


@dag()
def super_basic_run():
CustomOperator(task_id="hello")


super_basic_run()
36 changes: 35 additions & 1 deletion task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@

from airflow.sdk.api import client as sdk_client
from airflow.sdk.api.datamodels._generated import TaskInstance
from airflow.sdk.execution_time.supervisor import WatchedSubprocess
from airflow.sdk.api.datamodels.activities import ExecuteTaskActivity
from airflow.sdk.execution_time.supervisor import WatchedSubprocess, supervise
from airflow.utils import timezone as tz

if TYPE_CHECKING:
Expand Down Expand Up @@ -191,3 +192,36 @@ def subprocess_main():
assert spy.called_with(id, pid=proc.pid) # noqa: PGH005
# The exact number we get will depend on timing behaviour, so be a little lenient
assert 1 <= len(spy.calls) <= 4

def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine):
"""Test running a simple DAG in a subprocess and capturing the output."""

# Ignore anything lower than INFO for this test.
structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(logging.INFO))

instant = tz.datetime(2024, 11, 7, 12, 34, 56, 78901)
time_machine.move_to(instant, tick=False)

dagfile_path = test_dags_dir / "super_basic_run.py"
task_activity = ExecuteTaskActivity(
ti=TaskInstance(
id=UUID("4d828a62-a417-4936-a7a6-2b3fabacecab"),
task_id="hello",
dag_id="super_basic_run",
run_id="c",
try_number=1,
),
path=dagfile_path,
token="",
)
# Assert Exit Code is 0
assert supervise(activity=task_activity, server="", dry_run=True) == 0

# We should have a log from the task!
assert {
"chan": "stdout",
"event": "Hello World hello!",
"level": "info",
"logger": "task",
"timestamp": "2024-11-07T12:34:56.078901Z",
} in captured_logs
15 changes: 14 additions & 1 deletion task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
import uuid
from pathlib import Path
from socket import socketpair
from unittest import mock

import pytest
from uuid6 import uuid7

from airflow.sdk import DAG, BaseOperator
from airflow.sdk.api.datamodels._generated import TaskInstance
from airflow.sdk.execution_time.comms import StartupDetails
from airflow.sdk.execution_time.task_runner import CommsDecoder, parse
from airflow.sdk.execution_time.task_runner import CommsDecoder, parse, run


class TestCommsDecoder:
Expand Down Expand Up @@ -73,3 +74,15 @@ def test_parse(test_dags_dir: Path):
assert ti.task.dag
assert isinstance(ti.task, BaseOperator)
assert isinstance(ti.task.dag, DAG)


def test_run_basic(test_dags_dir: Path):
"""Test running a basic task."""
what = StartupDetails(
ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1),
file=str(test_dags_dir / "super_basic_run.py"),
requests_fd=0,
)

ti = parse(what)
run(ti, log=mock.MagicMock())

0 comments on commit 8554001

Please sign in to comment.