Skip to content

Commit

Permalink
feature: Enable retries when retrieving results from AwsQuantumTaskBatch. (#177)
Browse files Browse the repository at this point in the history

* feature: Enable retries when retrieving results from AwsQuantumTaskBatch.

* fix: Allow AwsQuantumTaskBatch.result() to fail even when use_cached_value=True

* Minor changes
  • Loading branch information
licedric authored Nov 26, 2020
1 parent d5140f9 commit e2895a8
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 22 deletions.
99 changes: 78 additions & 21 deletions src/braket/aws/aws_quantum_task_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class AwsQuantumTaskBatch:

MAX_PARALLEL_DEFAULT = 10
MAX_CONNECTIONS_DEFAULT = 100
MAX_RETRIES = 3

def __init__(
self,
Expand Down Expand Up @@ -96,7 +97,17 @@ def __init__(
self._results = None
self._unsuccessful = set()

# Cache execution inputs for retries.
self._device_arn = device_arn
self._task_specifications = task_specifications
self._s3_destination_folder = s3_destination_folder
self._shots = shots
self._max_parallel = max_parallel
self._max_workers = max_workers
self._poll_timeout_seconds = poll_timeout_seconds
self._poll_interval_seconds = poll_interval_seconds
self._aws_quantum_task_args = aws_quantum_task_args
self._aws_quantum_task_kwargs = aws_quantum_task_kwargs

@staticmethod
def _execute(
Expand Down Expand Up @@ -173,40 +184,88 @@ def _create_task(
time.sleep(poll_interval_seconds)
return task

def results(self, fail_unsuccessful=False, use_cached_value=True):
def results(self, fail_unsuccessful=False, max_retries=MAX_RETRIES, use_cached_value=True):
"""Retrieves the result of every task in the batch.
Polling for results happens in parallel; this method returns when all tasks
have reached a terminal state. The result of this method is cached.
Args:
fail_unsuccessful (bool): If set to True, this method will fail
if any task in the batch is in the FAILED or CANCELLED state.
if any task in the batch fails to return a result even after
`max_retries` retries.
max_retries (int): Maximum number of times to retry any failed tasks,
i.e. any tasks in the FAILED or CANCELLED state or that didn't
complete within the timeout. Default: 3
use_cached_value (bool): If False, will refetch the results from S3,
even when results have already been cached. Default: True
Returns:
List[AwsQuantumTask]: The results of all of the tasks in the batch.
FAILED or CANCELLED tasks will have a result of None
FAILED, CANCELLED, or timed out tasks will have a result of None
"""
if self._results and use_cached_value:
return list(self._results)
with ThreadPoolExecutor(max_workers=self._max_workers) as executor:
result_futures = [
executor.submit(AwsQuantumTaskBatch._get_task_result, task, self._unsuccessful)
for task in self._tasks
]
self._results = [future.result() for future in result_futures]
if not self._results or not use_cached_value:
self._results = AwsQuantumTaskBatch._retrieve_results(self._tasks, self._max_workers)
self._unsuccessful = {
task.id for task, result in zip(self._tasks, self._results) if not result
}

retries = 0
while self._unsuccessful and retries < max_retries:
self.retry_unsuccessful_tasks()
retries = retries + 1

if fail_unsuccessful and self._unsuccessful:
raise RuntimeError(f"{len(self._unsuccessful)} tasks failed to complete")
raise RuntimeError(
f"{len(self._unsuccessful)} tasks failed to complete after {max_retries} retries"
)
return self._results

@staticmethod
def _get_task_result(task, unsuccessful):
result = task.result()
if not result:
unsuccessful.add(task.id)
return result
def _retrieve_results(tasks, max_workers):
    """Fetch the result of every task concurrently, preserving input order.

    Args:
        tasks: Iterable of objects exposing a zero-argument ``result()`` method.
        max_workers (int): Size of the thread pool used to call ``result()``.

    Returns:
        list: One entry per task, in the same order as ``tasks``.
    """
    pool = ThreadPoolExecutor(max_workers=max_workers)
    with pool:
        pending = []
        for task in tasks:
            pending.append(pool.submit(task.result))
        # Collect in submission order so outcomes line up with ``tasks``.
        outcomes = [job.result() for job in pending]
    return outcomes

def retry_unsuccessful_tasks(self):
    """Retries any tasks in the batch without valid results.

    This method should only be called after results() has been called at least once.
    The method will generate new tasks for any failed tasks, so `self.tasks` and
    `self.results()` may return different values after a call to this method.

    Returns:
        bool: Whether or not all retried tasks completed successfully. Also True
        when there was nothing to retry.

    Raises:
        RuntimeError: If no results have been fetched yet.
    """
    # None means results were never fetched. An empty list is a fetched (empty)
    # batch and simply has nothing to retry, so it must not raise here.
    if self._results is None:
        raise RuntimeError("results() should be called before attempting to retry")
    unsuccessful_indices = [index for index, result in enumerate(self._results) if not result]
    # Return early if there's nothing to retry
    if not unsuccessful_indices:
        return True
    # Re-submit only the specifications whose results are missing, reusing the
    # execution inputs cached on the instance.
    retried_tasks = AwsQuantumTaskBatch._execute(
        self._aws_session,
        self._device_arn,
        [self._task_specifications[i] for i in unsuccessful_indices],
        self._s3_destination_folder,
        self._shots,
        self._max_parallel,
        self._max_workers,
        self._poll_timeout_seconds,
        self._poll_interval_seconds,
        *self._aws_quantum_task_args,
        **self._aws_quantum_task_kwargs,
    )
    # Replace the old tasks in place so positions keep matching the specifications.
    for index, task in zip(unsuccessful_indices, retried_tasks):
        self._tasks[index] = task

    retried_results = AwsQuantumTaskBatch._retrieve_results(retried_tasks, self._max_workers)
    for index, result in zip(unsuccessful_indices, retried_results):
        self._results[index] = result
    # Only the retried tasks can still be unsuccessful; every other slot already
    # held a truthy result.
    self._unsuccessful = {
        task.id for task, result in zip(retried_tasks, retried_results) if not result
    }
    return not self._unsuccessful

@property
def tasks(self) -> List[AwsQuantumTask]:
Expand All @@ -221,9 +280,7 @@ def size(self) -> int:
@property
def unfinished(self) -> Set[str]:
"""Set[str]: The IDs of all the tasks in the batch that have yet to complete"""
with ThreadPoolExecutor(
max_workers=AwsQuantumTaskBatch.MAX_CONNECTIONS_DEFAULT
) as executor:
with ThreadPoolExecutor(max_workers=self._max_workers) as executor:
status_futures = {task.id: executor.submit(task.state) for task in self._tasks}
unfinished = set()
for task_id in status_futures:
Expand All @@ -236,5 +293,5 @@ def unfinished(self) -> Set[str]:

@property
def unsuccessful(self) -> Set[str]:
"""Set[str]: The IDs of all the FAILED and CANCELLED tasks in the batch"""
"""Set[str]: The IDs of all the FAILED, CANCELLED, or timed out tasks in the batch."""
return set(self._unsuccessful)
32 changes: 31 additions & 1 deletion test/unit_tests/braket/aws/test_aws_quantum_task_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,42 @@ def test_unsuccessful(mock_create):
assert not batch.unfinished
assert batch.unsuccessful == {task_id}
assert batch.results() == [None]
assert batch.results(fail_unsuccessful=True) == [None] # Result is cached
with pytest.raises(RuntimeError):
assert batch.results(fail_unsuccessful=True) == [None]
batch._unsuccessful = set()
with pytest.raises(RuntimeError):
batch.results(fail_unsuccessful=True, use_cached_value=False)
assert batch.unsuccessful == {task_id}


@patch("braket.aws.aws_quantum_task.AwsQuantumTask.create")
def test_retry(mock_create):
    """results() with retries replaces a missing result; with max_retries=0 it does not."""
    failing = Mock()
    # A PropertyMock has to be attached to the mock's type, not the instance.
    type(failing).id = PropertyMock(side_effect=uuid.uuid4)
    failing.state.return_value = random.choice(["CANCELLED", "FAILED"])
    failing.result.return_value = None

    expected = GateModelQuantumTaskResult.from_string(MockS3.MOCK_S3_RESULT_GATE_MODEL)
    succeeding = Mock()
    # No id is configured explicitly on this mock.
    succeeding.state.return_value = "COMPLETED"
    succeeding.result.return_value = expected

    # Creation order: two initial tasks, then the tasks handed out on retries.
    mock_create.side_effect = [failing, succeeding, failing, succeeding]

    specifications = [Circuit().h(0).cnot(0, 1), Circuit().h(1).cnot(0, 1)]
    batch = AwsQuantumTaskBatch(Mock(), "foo", specifications, S3_TARGET, 1000)
    assert not batch.unfinished

    # Without retries the first task's missing result is reported as None.
    first_pass = batch.results(max_retries=0)
    assert first_pass == [None, expected]

    # Retrying should get rid of the failures
    second_pass = batch.results(fail_unsuccessful=True, max_retries=3, use_cached_value=False)
    assert second_pass == [expected, expected]
    assert batch.unsuccessful == set()


def _circuits(batch_size):
    """Return a list of ``batch_size`` circuits, each built as ``Circuit().h(0).cnot(0, 1)``."""
    batch = []
    for _ in range(batch_size):
        # A fresh Circuit is built on every iteration; entries are never shared.
        batch.append(Circuit().h(0).cnot(0, 1))
    return batch

0 comments on commit e2895a8

Please sign in to comment.