Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Logging at realtime from the async subprocess while pulling #1392

Merged
merged 6 commits into from
Nov 25, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 22 additions & 33 deletions ersilia/hub/pull/pull.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,78 +113,67 @@ async def async_pull(self):
self.logger.debug(
"Trying to pull image {0}/{1}".format(DOCKERHUB_ORG, self.model_id)
)
tmp_file = os.path.join(
make_temp_dir(prefix="ersilia-"), "docker_pull.log"
)
self.logger.debug("Keeping logs of pull in {0}".format(tmp_file))

# Construct the pull command
pull_command = f"docker pull {DOCKERHUB_ORG}/{self.model_id}:{DOCKERHUB_LATEST_TAG} > {tmp_file} 2>&1"
pull_command = f"docker pull {DOCKERHUB_ORG}/{self.model_id}:{DOCKERHUB_LATEST_TAG}"

# Use asyncio to run the pull command asynchronously
process = await asyncio.create_subprocess_shell(
pull_command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)

# Wait for the command to complete
stdout, stderr = await process.communicate()
async def log_stream(stream, log_method):
async for line in stream:
log_method(line.decode().strip())

# Handle output
if process.returncode != 0:
self.logger.error(f"Pull command failed: {stderr.decode()}")
raise subprocess.CalledProcessError(process.returncode, pull_command)

self.logger.debug(stdout.decode())
await asyncio.gather(
log_stream(process.stdout, self.logger.info),
log_stream(process.stderr, self.logger.error)
)

# Reading log asynchronously
async with aiofiles.open(tmp_file, 'r') as f:
pull_log = await f.read()
self.logger.debug(pull_log)
await process.wait()

if re.search(r"no match.*manifest", pull_log):
self.logger.warning(
"No matching manifest for image {0}".format(self.model_id)
)
raise DockerConventionalPullError(model=self.model_id)
if process.returncode != 0:
self.logger.error(f"Pull command failed with return code {process.returncode}")
raise subprocess.CalledProcessError(process.returncode, pull_command)

self.logger.debug("Image pulled successfully!")

except DockerConventionalPullError:
self.logger.warning(
"Conventional pull did not work, Ersilia is now forcing linux/amd64 architecture"
)
# Force platform specification pull command
force_pull_command = f"docker pull {DOCKERHUB_ORG}/{self.model_id}:{DOCKERHUB_LATEST_TAG} --platform linux/amd64"

# Run forced pull asynchronously
process = await asyncio.create_subprocess_shell(
force_pull_command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)

stdout, stderr = await process.communicate()
await asyncio.gather(
log_stream(process.stdout, self.logger.info),
log_stream(process.stderr, self.logger.error)
)

await process.wait()

if process.returncode != 0:
self.logger.error(f"Forced pull command failed: {stderr.decode()}")
self.logger.error(f"Forced pull command failed with return code {process.returncode}")
raise subprocess.CalledProcessError(process.returncode, force_pull_command)

self.logger.debug(stdout.decode())
self.logger.debug("Forced pull completed successfully!")

size = self._get_size_of_local_docker_image_in_mb()
if size:
self.logger.debug("Size of image {0} MB".format(size))
# path = os.path.join(self._model_path(self.model_id), MODEL_SIZE_FILE)
# with open(path, "w") as f:
# json.dump({"size": size, "units": "MB"}, f, indent=4)
# self.logger.debug("Size written to {}".format(path))
else:
self.logger.warning("Could not obtain size of image")
return size
else:
self.logger.info("Image {0} is not available".format(self.image_name))
raise DockerImageNotAvailableError(model=self.model_id)


@throw_ersilia_exception()
def pull(self):
Expand Down
Loading