diff --git a/ersilia/hub/pull/pull.py b/ersilia/hub/pull/pull.py index 4732eed99..442dd9900 100644 --- a/ersilia/hub/pull/pull.py +++ b/ersilia/hub/pull/pull.py @@ -113,41 +113,29 @@ async def async_pull(self): self.logger.debug( "Trying to pull image {0}/{1}".format(DOCKERHUB_ORG, self.model_id) ) - tmp_file = os.path.join( - make_temp_dir(prefix="ersilia-"), "docker_pull.log" - ) - self.logger.debug("Keeping logs of pull in {0}".format(tmp_file)) - # Construct the pull command - pull_command = f"docker pull {DOCKERHUB_ORG}/{self.model_id}:{DOCKERHUB_LATEST_TAG} > {tmp_file} 2>&1" + pull_command = f"docker pull {DOCKERHUB_ORG}/{self.model_id}:{DOCKERHUB_LATEST_TAG}" - # Use asyncio to run the pull command asynchronously process = await asyncio.create_subprocess_shell( pull_command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) - # Wait for the command to complete - stdout, stderr = await process.communicate() + async def log_stream(stream, log_method): + async for line in stream: + log_method(line.decode().strip()) - # Handle output - if process.returncode != 0: - self.logger.error(f"Pull command failed: {stderr.decode()}") - raise subprocess.CalledProcessError(process.returncode, pull_command) - - self.logger.debug(stdout.decode()) + await asyncio.gather( + log_stream(process.stdout, self.logger.info), + log_stream(process.stderr, self.logger.error) + ) - # Reading log asynchronously - async with aiofiles.open(tmp_file, 'r') as f: - pull_log = await f.read() - self.logger.debug(pull_log) + await process.wait() - if re.search(r"no match.*manifest", pull_log): - self.logger.warning( - "No matching manifest for image {0}".format(self.model_id) - ) - raise DockerConventionalPullError(model=self.model_id) + if process.returncode != 0: + self.logger.error(f"Pull command failed with return code {process.returncode}") + raise subprocess.CalledProcessError(process.returncode, pull_command) self.logger.debug("Image pulled successfully!") @@ -155,36 +143,37 @@ async def async_pull(self): self.logger.warning( "Conventional pull did not work, Ersilia is now forcing linux/amd64 architecture" ) - # Force platform specification pull command force_pull_command = f"docker pull {DOCKERHUB_ORG}/{self.model_id}:{DOCKERHUB_LATEST_TAG} --platform linux/amd64" - # Run forced pull asynchronously process = await asyncio.create_subprocess_shell( force_pull_command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) - stdout, stderr = await process.communicate() + await asyncio.gather( + log_stream(process.stdout, self.logger.info), + log_stream(process.stderr, self.logger.error) + ) + + await process.wait() if process.returncode != 0: - self.logger.error(f"Forced pull command failed: {stderr.decode()}") + self.logger.error(f"Forced pull command failed with return code {process.returncode}") raise subprocess.CalledProcessError(process.returncode, force_pull_command) - self.logger.debug(stdout.decode()) + self.logger.debug("Forced pull completed successfully!") + size = self._get_size_of_local_docker_image_in_mb() if size: self.logger.debug("Size of image {0} MB".format(size)) - # path = os.path.join(self._model_path(self.model_id), MODEL_SIZE_FILE) - # with open(path, "w") as f: - # json.dump({"size": size, "units": "MB"}, f, indent=4) - # self.logger.debug("Size written to {}".format(path)) else: self.logger.warning("Could not obtain size of image") return size else: self.logger.info("Image {0} is not available".format(self.image_name)) raise DockerImageNotAvailableError(model=self.model_id) + @throw_ersilia_exception() def pull(self):