diff --git a/cromwell-drs-localizer/src/main/scala/drs/localizer/downloaders/GcsUriDownloader.scala b/cromwell-drs-localizer/src/main/scala/drs/localizer/downloaders/GcsUriDownloader.scala index 5f2c49a9284..b1783b3ad2d 100644 --- a/cromwell-drs-localizer/src/main/scala/drs/localizer/downloaders/GcsUriDownloader.scala +++ b/cromwell-drs-localizer/src/main/scala/drs/localizer/downloaders/GcsUriDownloader.scala @@ -57,36 +57,30 @@ case class GcsUriDownloader(gcsUrl: String, downloadAttempt: Int = 0 ): IO[DownloadResult] = { + // Necessary function to handle the throwable when trying to recover a failed download + def handleDownloadFailure(t: Throwable): IO[DownloadResult] = + downloadWithRetries(downloadRetries, backoff, downloadAttempt + 1) + + logger.info(s"Attempting download attempt $downloadAttempt of $downloadRetries for a GCS url") + if (downloadAttempt < downloadRetries) { backoff foreach { b => Thread.sleep(b.backoffMillis) } - logger.warn(s"Attempting download retry $downloadAttempt of $downloadRetries for a GCS url") - downloadWithRetries(downloadRetries, - backoff map { - _.next - }, - downloadAttempt + 1 + runDownloadCommand.redeemWith( + recover = handleDownloadFailure, + bind = { + case s: DownloadSuccess.type => + IO.pure(s) + case _: RecognizedRetryableDownloadFailure => + downloadWithRetries(downloadRetries, backoff, downloadAttempt + 1) + case _: UnrecognizedRetryableDownloadFailure => + downloadWithRetries(downloadRetries, backoff, downloadAttempt + 1) + case _ => + downloadWithRetries(downloadRetries, backoff, downloadAttempt + 1) + } ) } else { IO.raiseError(new RuntimeException(s"Exhausted $downloadRetries resolution retries to download GCS file")) } - - // Necessary function to handle the throwable when trying to recover a failed download - def handleDownloadFailure(t: Throwable): IO[DownloadResult] = - downloadWithRetries(downloadRetries, backoff, downloadAttempt + 1) - - runDownloadCommand.redeemWith( - recover = handleDownloadFailure, - bind = { - case s: DownloadSuccess.type => - IO.pure(s) - case _: RecognizedRetryableDownloadFailure => - downloadWithRetries(downloadRetries, backoff, downloadAttempt + 1) - case _: UnrecognizedRetryableDownloadFailure => - downloadWithRetries(downloadRetries, backoff, downloadAttempt + 1) - case _ => - downloadWithRetries(downloadRetries, backoff, downloadAttempt + 1) - } - ) } /** diff --git a/cromwell-drs-localizer/src/test/scala/drs/localizer/downloaders/GcsUriDownloaderSpec.scala b/cromwell-drs-localizer/src/test/scala/drs/localizer/downloaders/GcsUriDownloaderSpec.scala index 07eb5ede181..4bd2ad0e787 100644 --- a/cromwell-drs-localizer/src/test/scala/drs/localizer/downloaders/GcsUriDownloaderSpec.scala +++ b/cromwell-drs-localizer/src/test/scala/drs/localizer/downloaders/GcsUriDownloaderSpec.scala @@ -1,6 +1,7 @@ package drs.localizer.downloaders import common.assertion.CromwellTimeoutSpec +import org.mockito.Mockito.{spy, times, verify} import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers @@ -96,4 +97,25 @@ class GcsUriDownloaderSpec extends AnyFlatSpec with CromwellTimeoutSpec with Mat downloader.generateDownloadScript(gcsUrl, Option(fakeSAJsonPath)) shouldBe expectedDownloadScript } + + it should "fail to download GCS URL after 5 attempts" in { + val gcsUrl = "gs://foo/bar.bam" + val downloader = spy(new GcsUriDownloader( + gcsUrl = gcsUrl, + downloadLoc = fakeDownloadLocation, + requesterPaysProjectIdOption = Option(fakeRequesterPaysId), + serviceAccountJson = None + )) + + val result = downloader.downloadWithRetries(5, None).attempt.unsafeRunSync() + + result.isLeft shouldBe true + // attempts to download the 1st time and the 5th time, but doesn't attempt a 6th + verify(downloader, times(1)).downloadWithRetries(5, None, 1) + verify(downloader, times(1)).downloadWithRetries(5, None, 5) + verify(downloader, times(0)).downloadWithRetries(5, None, 6) + // attempts the actual download command 5 times + verify(downloader, times(5)).runDownloadCommand + + } }