Skip to content

Commit

Permalink
Add support for timeout to BatchOperator (#45660)
Browse files Browse the repository at this point in the history
An execution timeout for the submit_job api call can now be passed through the operator to the boto3 call.
  • Loading branch information
nrobinson-intelycare authored Jan 17, 2025
1 parent 418b701 commit caa401d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
6 changes: 6 additions & 0 deletions providers/src/airflow/providers/amazon/aws/operators/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class BatchOperator(BaseOperator):
If it is an array job, only the logs of the first task will be printed.
:param awslogs_fetch_interval: The interval with which cloudwatch logs are to be fetched, 30 sec.
:param poll_interval: (Deferrable mode only) Time in seconds to wait between polling.
:param submit_job_timeout: Execution timeout in seconds for submitted batch job.
.. note::
Any custom waiters must return a waiter for these calls:
Expand Down Expand Up @@ -184,6 +185,7 @@ def __init__(
poll_interval: int = 30,
awslogs_enabled: bool = False,
awslogs_fetch_interval: timedelta = timedelta(seconds=30),
submit_job_timeout: int | None = None,
**kwargs,
) -> None:
BaseOperator.__init__(self, **kwargs)
Expand All @@ -208,6 +210,7 @@ def __init__(
self.poll_interval = poll_interval
self.awslogs_enabled = awslogs_enabled
self.awslogs_fetch_interval = awslogs_fetch_interval
self.submit_job_timeout = submit_job_timeout

# params for hook
self.max_retries = max_retries
Expand Down Expand Up @@ -315,6 +318,9 @@ def submit_job(self, context: Context):
"schedulingPriorityOverride": self.scheduling_priority_override,
}

if self.submit_job_timeout:
args["timeout"] = {"attemptDurationSeconds": self.submit_job_timeout}

try:
response = self.hook.client.submit_job(**trim_none_values(args))
except Exception as e:
Expand Down
7 changes: 7 additions & 0 deletions providers/tests/amazon/aws/operators/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def setup_method(self, _, get_client_type_mock):
aws_conn_id="airflow_test",
region_name="eu-west-1",
tags={},
submit_job_timeout=3600,
)
self.client_mock = self.get_client_type_mock.return_value
# We're mocking all actual AWS calls and don't need a connection. This
Expand Down Expand Up @@ -109,6 +110,7 @@ def test_init(self):
assert self.batch.hook.client == self.client_mock
assert self.batch.tags == {}
assert self.batch.wait_for_completion is True
assert self.batch.submit_job_timeout == 3600

self.get_client_type_mock.assert_called_once_with(region_name="eu-west-1")

Expand Down Expand Up @@ -141,6 +143,7 @@ def test_init_defaults(self):
assert issubclass(type(batch_job.hook.client), botocore.client.BaseClient)
assert batch_job.tags == {}
assert batch_job.wait_for_completion is True
assert batch_job.submit_job_timeout is None

def test_template_fields_overrides(self):
assert self.batch.template_fields == (
Expand Down Expand Up @@ -181,6 +184,7 @@ def test_execute_without_failures(self, check_mock, wait_mock, job_description_m
parameters={},
retryStrategy={"attempts": 1},
tags={},
timeout={"attemptDurationSeconds": 3600},
)

assert self.batch.job_id == JOB_ID
Expand All @@ -205,6 +209,7 @@ def test_execute_with_failures(self):
parameters={},
retryStrategy={"attempts": 1},
tags={},
timeout={"attemptDurationSeconds": 3600},
)

@mock.patch.object(BatchClientHook, "get_job_description")
Expand Down Expand Up @@ -261,6 +266,7 @@ def test_execute_with_ecs_overrides(self, check_mock, wait_mock, job_description
parameters={},
retryStrategy={"attempts": 1},
tags={},
timeout={"attemptDurationSeconds": 3600},
)

@mock.patch.object(BatchClientHook, "get_job_description")
Expand Down Expand Up @@ -359,6 +365,7 @@ def test_execute_with_eks_overrides(self, check_mock, wait_mock, job_description
parameters={},
retryStrategy={"attempts": 1},
tags={},
timeout={"attemptDurationSeconds": 3600},
)

@mock.patch.object(BatchClientHook, "check_job_success")
Expand Down

0 comments on commit caa401d

Please sign in to comment.