Skip to content

Commit

Permalink
Support versioning of Dockerized models (#1481)
Browse files Browse the repository at this point in the history
* Add the ability to specify an image tag when fetching; the default latest tag is used otherwise

* format
  • Loading branch information
DhanshreeA authored Jan 3, 2025
1 parent c41fb34 commit 3e6028e
Show file tree
Hide file tree
Showing 12 changed files with 106 additions and 49 deletions.
8 changes: 8 additions & 0 deletions ersilia/cli/commands/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ def _fetch(mf, model_id):
default=False,
help="Force fetch from DockerHub",
)
@click.option(
"--version",
default=None,
type=click.STRING,
help="Version of the model to fetch, when fetching a model from DockerHub",
)
@click.option(
"--from_s3", is_flag=True, default=False, help="Force fetch from AWS S3 bucket"
)
Expand Down Expand Up @@ -101,6 +107,7 @@ def fetch(
from_dir,
from_github,
from_dockerhub,
version,
from_s3,
from_hosted,
hosted_url,
Expand All @@ -125,6 +132,7 @@ def fetch(
force_from_github=from_github,
force_from_s3=from_s3,
force_from_dockerhub=from_dockerhub,
img_version=version,
force_from_hosted=from_hosted,
force_with_bentoml=with_bentoml,
force_with_fastapi=with_fastapi,
Expand Down
6 changes: 3 additions & 3 deletions ersilia/core/modelbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os

from .. import ErsiliaBase, throw_ersilia_exception
from ..default import IS_FETCHED_FROM_DOCKERHUB_FILE
from ..default import DOCKER_INFO_FILE
from ..hub.content.slug import Slug
from ..hub.fetch import DONE_TAG, STATUS_FILE
from ..utils.exceptions_utils.exceptions import InvalidModelIdentifierError
Expand Down Expand Up @@ -104,7 +104,7 @@ def _is_available_locally_from_status(self):

def _is_available_locally_from_dockerhub(self):
from_dockerhub_file = os.path.join(
self._dest_dir, self.model_id, IS_FETCHED_FROM_DOCKERHUB_FILE
self._dest_dir, self.model_id, DOCKER_INFO_FILE
)
if not os.path.exists(from_dockerhub_file):
return False
Expand Down Expand Up @@ -138,7 +138,7 @@ def was_fetched_from_dockerhub(self):
True if the model was fetched from DockerHub, False otherwise.
"""
from_dockerhub_file = os.path.join(
self._dest_dir, self.model_id, IS_FETCHED_FROM_DOCKERHUB_FILE
self._dest_dir, self.model_id, DOCKER_INFO_FILE
)
if not os.path.exists(from_dockerhub_file):
return False
Expand Down
14 changes: 9 additions & 5 deletions ersilia/db/environments/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from ...core.base import ErsiliaBase
from ...default import DOCKERHUB_LATEST_TAG, DOCKERHUB_ORG
from ...setup.requirements.docker import DockerRequirement
from ...utils.docker import SimpleDocker, resolve_platform
from ...utils.docker import SimpleDocker, model_image_version_reader, resolve_platform
from ...utils.identifiers.short import ShortIdentifier
from ...utils.logging import make_temp_dir
from ...utils.paths import Paths
Expand Down Expand Up @@ -235,7 +235,7 @@ def containers_of_model(self, model_id, only_run, only_latest=True):
cnt_dict[k] = v
return cnt_dict

def build_with_bentoml(self, model_id, use_cache=True):
def build_with_bentoml(self, model_id, use_cache=True): # Ignore for versioning
"""
Builds a Docker image for the model using BentoML.
Expand Down Expand Up @@ -295,7 +295,9 @@ def _build_ersilia_base(self):
)
run_command(cmd)

def build_with_ersilia(self, model_id, docker_user, docker_pwd):
def build_with_ersilia(
self, model_id, docker_user, docker_pwd
): # Ignore for versioning
"""
Builds a Docker image for the model using Ersilia's base image.
Expand Down Expand Up @@ -383,10 +385,12 @@ def remove(self, model_id):
model_id : str
Identifier of the model.
"""
self.docker.delete(org=DOCKERHUB_ORG, img=model_id, tag=DOCKERHUB_LATEST_TAG)
bundle_path = self._get_bundle_location(model_id)
docker_tag = model_image_version_reader(bundle_path)
self.docker.delete(org=DOCKERHUB_ORG, img=model_id, tag=docker_tag)
self.db.delete(
model_id=model_id,
env="{0}/{1}:{2}".format(DOCKERHUB_ORG, model_id, DOCKERHUB_LATEST_TAG),
env="{0}/{1}:{2}".format(DOCKERHUB_ORG, model_id, docker_tag),
)

def run(self, model_id, workers=1, enable_microbatch=True, memory=None):
Expand Down
2 changes: 1 addition & 1 deletion ersilia/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
MODEL_SOURCE_FILE = "model_source.txt"
APIS_LIST_FILE = "apis_list.txt"
INFORMATION_FILE = "information.json"
IS_FETCHED_FROM_DOCKERHUB_FILE = "from_dockerhub.json"
DOCKER_INFO_FILE = "from_dockerhub.json"
IS_FETCHED_FROM_HOSTED_FILE = "from_hosted.json"
DEFAULT_UDOCKER_USERNAME = "udockerusername"
DEFAULT_UDOCKER_PASSWORD = "udockerpassword"
Expand Down
4 changes: 2 additions & 2 deletions ersilia/hub/bundle/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ... import ErsiliaBase
from ...db.environments.localdb import EnvironmentDb
from ...default import IS_FETCHED_FROM_DOCKERHUB_FILE
from ...default import DOCKER_INFO_FILE
from ...utils.conda import SimpleConda
from ...utils.docker import SimpleDocker

Expand Down Expand Up @@ -89,7 +89,7 @@ def is_pulled_docker(self, model_id: str) -> bool:
True if the Docker image has been pulled, False otherwise.
"""
model_dir = os.path.join(self._model_path(model_id=model_id))
json_file = os.path.join(model_dir, IS_FETCHED_FROM_DOCKERHUB_FILE)
json_file = os.path.join(model_dir, DOCKER_INFO_FILE)
if not os.path.exists(json_file):
return False
with open(json_file, "r") as f:
Expand Down
5 changes: 4 additions & 1 deletion ersilia/hub/fetch/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ class ModelFetcher(ErsiliaBase):
Whether to force fetching from S3.
force_from_dockerhub : bool, optional
Whether to force fetching from DockerHub.
img_version : str, optional
Tag (version) of the model's Docker image to use when fetching from DockerHub. When not given, the default latest tag is used.
force_from_hosted : bool, optional
Whether to force fetching from hosted services.
force_with_bentoml : bool, optional
Expand Down Expand Up @@ -81,6 +83,7 @@ def __init__(
force_from_github: bool = False,
force_from_s3: bool = False,
force_from_dockerhub: bool = False,
img_version: str = None,
force_from_hosted: bool = False,
force_with_bentoml: bool = False,
force_with_fastapi: bool = False,
Expand All @@ -100,7 +103,7 @@ def __init__(
dockerize = True
self.do_docker = dockerize
self.model_dockerhub_fetcher = ModelDockerHubFetcher(
overwrite=self.overwrite, config_json=self.config_json
overwrite=self.overwrite, config_json=self.config_json, img_tag=img_version
)
self.is_docker_installed = self.model_dockerhub_fetcher.is_docker_installed()
self.is_docker_active = self.model_dockerhub_fetcher.is_docker_active()
Expand Down
31 changes: 22 additions & 9 deletions ersilia/hub/fetch/lazy_fetchers/dockerhub.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,12 @@ class ModelDockerHubFetcher(ErsiliaBase):
Fetch the model from DockerHub.
"""

def __init__(self, overwrite=None, config_json=None):
def __init__(self, overwrite=None, config_json=None, img_tag=None):
super().__init__(config_json=config_json, credentials_json=None)
self.simple_docker = SimpleDocker()
self.overwrite = overwrite
self.img_tag = img_tag or DOCKERHUB_LATEST_TAG
self.pack_method = None

def is_docker_installed(self) -> bool:
"""
Expand Down Expand Up @@ -101,7 +103,10 @@ def is_available(self, model_id: str) -> bool:
True if the model is available, False otherwise.
"""
mp = ModelPuller(
model_id=model_id, overwrite=self.overwrite, config_json=self.config_json
model_id=model_id,
overwrite=self.overwrite,
config_json=self.config_json,
docker_tag=self.img_tag,
)
if mp.is_available_locally():
return True
Expand Down Expand Up @@ -144,7 +149,7 @@ async def _copy_from_bentoml_image(self, model_id: str, file: str):
local_path=to_file,
org=DOCKERHUB_ORG,
img=model_id,
tag=DOCKERHUB_LATEST_TAG,
tag=self.img_tag,
)
except Exception as e:
self.logger.error(f"Exception when copying: {e}")
Expand All @@ -167,7 +172,7 @@ async def _copy_from_ersiliapack_image(self, model_id: str, file: str):
local_path=to_file,
org=DOCKERHUB_ORG,
img=model_id,
tag=DOCKERHUB_LATEST_TAG,
tag=self.img_tag,
)

async def _copy_from_image_to_local(self, model_id: str, file: str):
Expand All @@ -181,8 +186,12 @@ async def _copy_from_image_to_local(self, model_id: str, file: str):
file : str
Name of the file to copy.
"""
pack_method = resolve_pack_method_docker(model_id)
if pack_method == PACK_METHOD_BENTOML:
if not self.pack_method:
self.logger.debug("Resolving pack method")
self.pack_method = resolve_pack_method_docker(model_id)
self.logger.debug(f"Resolved pack method: {self.pack_method}")

if self.pack_method == PACK_METHOD_BENTOML:
await self._copy_from_bentoml_image(model_id, file)
else:
await self._copy_from_ersiliapack_image(model_id, file)
Expand Down Expand Up @@ -245,7 +254,9 @@ async def modify_information(self, model_id: str):
ID of the model.
"""
information_file = os.path.join(self._model_path(model_id), INFORMATION_FILE)
mp = ModelPuller(model_id=model_id, config_json=self.config_json)
mp = ModelPuller(
model_id=model_id, config_json=self.config_json, docker_tag=self.img_tag
)
try:
with open(information_file, "r") as infile:
data = json.load(infile)
Expand All @@ -268,13 +279,15 @@ async def fetch(self, model_id: str):
model_id : str
ID of the model.
"""
mp = ModelPuller(model_id=model_id, config_json=self.config_json)
mp = ModelPuller(
model_id=model_id, config_json=self.config_json, docker_tag=self.img_tag
)
self.logger.debug("Pulling model image from DockerHub")
await mp.async_pull()
mr = ModelRegisterer(model_id=model_id, config_json=self.config_json)
self.logger.debug("Asynchronous and concurrent execution started!")
await asyncio.gather(
mr.register(is_from_dockerhub=True),
mr.register(is_from_dockerhub=True, img_tag=self.img_tag),
self.write_apis(model_id),
self.copy_information(model_id),
self.modify_information(model_id),
Expand Down
22 changes: 13 additions & 9 deletions ersilia/hub/fetch/register/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from .... import EOS, ErsiliaBase, throw_ersilia_exception
from ....default import (
IS_FETCHED_FROM_DOCKERHUB_FILE,
DOCKER_INFO_FILE,
IS_FETCHED_FROM_HOSTED_FILE,
SERVICE_CLASS_FILE,
)
Expand Down Expand Up @@ -40,13 +40,17 @@ def __init__(self, model_id: str, config_json: dict):
ErsiliaBase.__init__(self, config_json=config_json, credentials_json=None)
self.model_id = model_id

def register_from_dockerhub(self):
def register_from_dockerhub(self, **kwargs):
"""
Register the model from DockerHub.
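This method registers the model in the file system, indicating that it was fetched from DockerHub. If an `img_tag` keyword argument is passed, the tag is recorded in the same file under the "tag" key.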
"""
data = {"docker_hub": True}
if 'img_tag' in kwargs:
img_tag = kwargs['img_tag']
data = {"docker_hub": True, "tag": img_tag}
else:
data = {"docker_hub": True}
self.logger.debug(
"Registering model {0} in the file system".format(self.model_id)
)
Expand All @@ -55,7 +59,7 @@ def register_from_dockerhub(self):
if os.path.exists(path):
shutil.rmtree(path)
os.mkdir(path)
file_name = os.path.join(path, IS_FETCHED_FROM_DOCKERHUB_FILE)
file_name = os.path.join(path, DOCKER_INFO_FILE)
self.logger.debug(file_name)
with open(file_name, "w") as f:
json.dump(data, f)
Expand All @@ -66,7 +70,7 @@ def register_from_dockerhub(self):
shutil.rmtree(path)
path = os.path.join(path, folder_name)
os.makedirs(path)
file_name = os.path.join(path, IS_FETCHED_FROM_DOCKERHUB_FILE)
file_name = os.path.join(path, DOCKER_INFO_FILE)
with open(file_name, "w") as f:
json.dump(data, f)
file_name = os.path.join(path, SERVICE_CLASS_FILE)
Expand All @@ -82,11 +86,11 @@ def register_not_from_dockerhub(self):
"""
data = {"docker_hub": False}
path = self._model_path(self.model_id)
file_name = os.path.join(path, IS_FETCHED_FROM_DOCKERHUB_FILE)
file_name = os.path.join(path, DOCKER_INFO_FILE)
with open(file_name, "w") as f:
json.dump(data, f)
path = self._get_bundle_location(model_id=self.model_id)
file_name = os.path.join(path, IS_FETCHED_FROM_DOCKERHUB_FILE)
file_name = os.path.join(path, DOCKER_INFO_FILE)
with open(file_name, "w") as f:
json.dump(data, f)

Expand Down Expand Up @@ -174,7 +178,7 @@ def register_not_from_hosted(self):
json.dump(data, f)

async def register(
self, is_from_dockerhub: bool = False, is_from_hosted: bool = False
self, is_from_dockerhub: bool = False, is_from_hosted: bool = False, **kwargs
):
"""
Register the model based on its source.
Expand Down Expand Up @@ -205,7 +209,7 @@ async def register(
if is_from_dockerhub and is_from_hosted:
raise ValueError("Model cannot be from both DockerHub and hosted")
elif is_from_dockerhub and not is_from_hosted:
self.register_from_dockerhub()
self.register_from_dockerhub(**kwargs)
self.register_not_from_hosted()
elif not is_from_dockerhub and is_from_hosted:
self.register_from_hosted()
Expand Down
Loading

0 comments on commit 3e6028e

Please sign in to comment.