diff --git a/src/sagemaker/predictor_async.py b/src/sagemaker/predictor_async.py index ef70b93599..783d034011 100644 --- a/src/sagemaker/predictor_async.py +++ b/src/sagemaker/predictor_async.py @@ -271,6 +271,7 @@ def _check_output_and_failure_paths(self, output_path, failure_path, waiter_conf output_file_found = threading.Event() failure_file_found = threading.Event() + waiter_error_catched = threading.Event() def check_output_file(): try: @@ -282,7 +283,7 @@ def check_output_file(): ) output_file_found.set() except WaiterError: - pass + waiter_error_catched.set() def check_failure_file(): try: @@ -294,7 +295,7 @@ def check_failure_file(): ) failure_file_found.set() except WaiterError: - pass + waiter_error_catched.set() output_thread = threading.Thread(target=check_output_file) failure_thread = threading.Thread(target=check_failure_file) @@ -302,7 +303,11 @@ def check_failure_file(): output_thread.start() failure_thread.start() - while not output_file_found.is_set() and not failure_file_found.is_set(): + while ( + not output_file_found.is_set() + and not failure_file_found.is_set() + and not waiter_error_catched.is_set() + ): time.sleep(1) if output_file_found.is_set(): @@ -310,17 +315,15 @@ def check_failure_file(): result = self.predictor._handle_response(response=s3_object) return result - failure_object = self.s3_client.get_object(Bucket=failure_bucket, Key=failure_key) - failure_response = self.predictor._handle_response(response=failure_object) + if failure_file_found.is_set(): + failure_object = self.s3_client.get_object(Bucket=failure_bucket, Key=failure_key) + failure_response = self.predictor._handle_response(response=failure_object) + raise AsyncInferenceModelError(message=failure_response) - raise ( - AsyncInferenceModelError(message=failure_response) - if failure_file_found.is_set() - else PollingTimeoutError( - message="Inference could still be running", - output_path=output_path, - seconds=waiter_config.delay * waiter_config.max_attempts, - ) + raise PollingTimeoutError( + message="Inference could still be running", + output_path=output_path, + seconds=waiter_config.delay * waiter_config.max_attempts, ) def update_endpoint(