Skip to content

Commit

Permalink
Refactoring Test Command (#1515)
Browse files Browse the repository at this point in the history
* refactor: test command demo

* refactor: test command demo

* refactor: test command demo

* refactor: eos-temp folder error fix

* refactor: eos-temp folder error fix

* refactor: test command demo

* refactor: test command final

* refactor: test command final

* refactor: test command final

* refactor: test command final

* refactor: test command final

* refactor: test command final

* refactor: test command final

* fix for version pinning and version flag from dockerhub

* fix for playground test failing

---------

Co-authored-by: Dhanshree Arora <[email protected]>
  • Loading branch information
Abellegese and DhanshreeA authored Jan 22, 2025
1 parent dfb49b8 commit 4befac5
Show file tree
Hide file tree
Showing 5 changed files with 1,496 additions and 824 deletions.
88 changes: 61 additions & 27 deletions ersilia/cli/commands/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@ def test_cmd():
--------
.. code-block:: console
With default settings:
$ ersilia test my_model -d /path/to/model
With basic testing:/
$ ersilia test eosxxxx --from_dir /path/to/model
With deep testing level and inspect:
$ ersilia test my_model -d /path/to/model --level deep --inspect --remote
With different sources to fetch the model:
$ ersilia test eosxxxx --from_github/--from_dockerhub/--from_s3
With different levels of testing:
$ ersilia test eosxxxx --shallow --from_github/--from_dockerhub/--from_s3
"""

@ersilia_cli.command(
Expand All @@ -38,48 +41,79 @@ def test_cmd():
"-l",
"--level",
"level",
help="Level of testing, None: for default, deep: for deep testing",
help="Level of testing, None: for default, deep: for deep testing, shallow: for shallow testing",
required=False,
default=None,
type=click.STRING,
)
@click.option(
"-d",
"--dir",
"dir",
help="Model directory",
required=False,
"--from_dir",
default=None,
type=click.STRING,
help="Local path where the model is stored",
)
@click.option(
"--from_github",
is_flag=True,
default=False,
help="Fetch fetch directly from GitHub",
)
@click.option(
"--from_dockerhub",
is_flag=True,
default=False,
help="Force fetch from DockerHub",
)
@click.option(
"--from_s3", is_flag=True, default=False, help="Force fetch from AWS S3 bucket"
)
@click.option(
"--version",
default=None,
type=click.STRING,
help="Version of the model to fetch, when fetching a model from DockerHub",
)
@click.option(
"--inspect",
help="Inspect the model: More on the docs",
"--shallow",
is_flag=True,
default=False,
help="This flag is used to check shallow checks (such as container size, output consistency..)",
)
@click.option(
"--remote",
help="Test the model from remote git repository",
"--deep",
is_flag=True,
default=False,
help="This flag is used to check deep checks (such as computational performance checks)",
)
@click.option(
"--remove",
help="Remove the model directory after testing",
"--as_json",
is_flag=True,
default=False,
help="This flag is used to save the report as json file)",
)
def test(model, level, dir, inspect, remote, remove):
def test(
model,
level,
from_dir,
from_github,
from_dockerhub,
from_s3,
version,
shallow,
deep,
as_json,
):
mt = ModelTester(
model_id=model,
level=level,
dir=dir,
inspect=inspect,
remote=remote,
remove=remove,
model,
level,
from_dir,
from_github,
from_dockerhub,
from_s3,
version,
shallow,
deep,
as_json,
)
echo("Setting up model tester...")
mt.setup()
echo("Testing model...")
mt.run(output_file=None)
echo(f"Model testing started for: {model}")
mt.run()
1 change: 1 addition & 0 deletions ersilia/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

# EOS environmental variables
EOS = os.path.join(str(Path.home()), "eos")
EOS_TMP = os.path.join(EOS, "temp")
if not os.path.exists(EOS):
os.makedirs(EOS)
ROOT = os.path.dirname(os.path.realpath(__file__))
Expand Down
90 changes: 70 additions & 20 deletions ersilia/publish/inspect.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import re
import subprocess
import time
from collections import namedtuple
Expand Down Expand Up @@ -387,23 +388,42 @@ def validate_repo_structure(self):

def _validate_dockerfile(self, dockerfile_content):
lines, errors = dockerfile_content.splitlines(), []
for line in lines:
if line.startswith("RUN pip install"):
cmd = line.split("RUN ")[-1]
result = subprocess.run(
cmd,
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
if result.returncode != 0:
errors.append(f"Failed to run {cmd}: {result.stderr.strip()}")

if "WORKDIR /repo" not in dockerfile_content:
errors.append("Missing 'WORKDIR /repo'.")
if "COPY . /repo" not in dockerfile_content:
errors.append("Missing 'COPY . /repo'.")

pip_install_pattern = re.compile(r"pip install (.+)")
version_pin_pattern = re.compile(r"^[a-zA-Z0-9_\-\.]+==[a-zA-Z0-9_\-\.]+$")

for line in lines:
line = line.strip()

match = pip_install_pattern.search(line)
if match:
packages_and_options = match.group(1).split()
skip_next = False

for item in packages_and_options:
if skip_next:
skip_next = False
continue

if item.startswith("--index-url") or item.startswith(
"--extra-index-url"
):
skip_next = True
continue

if item.startswith("git+"):
continue

if not version_pin_pattern.match(item):
errors.append(
f"Package '{item}' in line '{line}' is not version-pinned (e.g., 'package==1.0.0')."
)

return errors

def _validate_yml(self, yml_content):
Expand All @@ -417,18 +437,48 @@ def _validate_yml(self, yml_content):
if not python_version:
errors.append("Missing Python version in install.yml.")

version_pin_pattern = re.compile(r"^[a-zA-Z0-9_\-\.]+==[a-zA-Z0-9_\-\.]+$")

commands = yml_data.get("commands", [])
for command in commands:
if not isinstance(command, list) or command[0] != "pip":
if not isinstance(command, list) or len(command) < 2:
errors.append(f"Invalid command format: {command}")
continue
# package: name & version
name = command[1] if len(command) > 1 else None

tool = command[0]
_ = command[1]
version = command[2] if len(command) > 2 else None
if not name:
errors.append(f"Missing package name in command: {command}")
if name and version:
pass

if tool in ("pip", "conda"):
if tool == "pip":
pip_args = command[1:]
skip_next = False

for item in pip_args:
if skip_next:
skip_next = False
continue

if item.startswith("--index-url") or item.startswith(
"--extra-index-url"
):
skip_next = True
continue

if item.startswith("git+"):
continue

if not version_pin_pattern.match(item):
errors.append(
f"Package '{item}' in command '{command}' is not version-pinned (e.g., 'package==1.0.0')."
)

elif tool == "conda" and not version:
errors.append(
f"Package in command '{command}' does not have a valid pinned version "
f"(should be in the format ['conda', 'package_name', 'x.y.z'])."
)

return errors

def _run_performance_check(self, n):
Expand All @@ -445,5 +495,5 @@ def _run_performance_check(self, n):
return Result(False, f"Error serving model: {process.stderr.strip()}")
execution_time = time.time() - start_time
return Result(
True, f"{n} predictions executed in {execution_time:.2f} seconds."
True, f"{n} predictions executed in {execution_time:.2f} seconds. \n"
)
Loading

0 comments on commit 4befac5

Please sign in to comment.