diff --git a/ersilia/cli/commands/test.py b/ersilia/cli/commands/test.py index 9bc58e966..66a3d0ea4 100644 --- a/ersilia/cli/commands/test.py +++ b/ersilia/cli/commands/test.py @@ -20,11 +20,14 @@ def test_cmd(): -------- .. code-block:: console - With default settings: - $ ersilia test my_model -d /path/to/model + With basic testing: + $ ersilia test eosxxxx --from_dir /path/to/model - With deep testing level and inspect: - $ ersilia test my_model -d /path/to/model --level deep --inspect --remote + With different sources to fetch the model: + $ ersilia test eosxxxx --from_github/--from_dockerhub/--from_s3 + + With different levels of testing: + $ ersilia test eosxxxx --shallow --from_github/--from_dockerhub/--from_s3 """ @ersilia_cli.command( @@ -38,48 +41,79 @@ def test_cmd(): "-l", "--level", "level", - help="Level of testing, None: for default, deep: for deep testing", + help="Level of testing: None for default, shallow for shallow testing, deep for deep testing", required=False, default=None, type=click.STRING, ) @click.option( - "-d", - "--dir", - "dir", - help="Model directory", - required=False, + "--from_dir", + default=None, + type=click.STRING, + help="Local path where the model is stored", + ) + @click.option( + "--from_github", + is_flag=True, + default=False, + help="Fetch directly from GitHub", + ) + @click.option( + "--from_dockerhub", + is_flag=True, + default=False, + help="Force fetch from DockerHub", + ) + @click.option( + "--from_s3", is_flag=True, default=False, help="Force fetch from AWS S3 bucket" + ) + @click.option( + "--version", default=None, type=click.STRING, + help="Version of the model to fetch when fetching a model from DockerHub", ) @click.option( - "--inspect", - help="Inspect the model: More on the docs", + "--shallow", is_flag=True, default=False, + help="Run shallow checks (such as container size and output consistency)", ) @click.option( - "--remote", - help="Test the model from remote git repository", + "--deep", is_flag=True, default=False, + help="Run deep checks (such as computational performance)", ) @click.option( - "--remove", - help="Remove the model directory after testing", + "--as_json", is_flag=True, default=False, + help="Save the report as a JSON file", ) - def test(model, level, dir, inspect, remote, remove): + def test( + model, + level, + from_dir, + from_github, + from_dockerhub, + from_s3, + version, + shallow, + deep, + as_json, + ): mt = ModelTester( - model_id=model, - level=level, - dir=dir, - inspect=inspect, - remote=remote, - remove=remove, + model, + level, + from_dir, + from_github, + from_dockerhub, + from_s3, + version, + shallow, + deep, + as_json, ) - echo("Setting up model tester...") - mt.setup() - echo("Testing model...") - mt.run(output_file=None) + echo(f"Model testing started for: {model}") + mt.run() diff --git a/ersilia/default.py b/ersilia/default.py index 0bd916975..3fafa049a 100644 --- a/ersilia/default.py +++ b/ersilia/default.py @@ -7,6 +7,7 @@ # EOS environmental variables EOS = os.path.join(str(Path.home()), "eos") +EOS_TMP = os.path.join(EOS, "temp") if not os.path.exists(EOS): os.makedirs(EOS) ROOT = os.path.dirname(os.path.realpath(__file__)) diff --git a/ersilia/publish/inspect.py b/ersilia/publish/inspect.py index 9bb1f4024..1e81e7f93 100644 --- a/ersilia/publish/inspect.py +++ b/ersilia/publish/inspect.py @@ -1,4 +1,5 @@ import os +import re import subprocess import time from collections import namedtuple @@ -387,23 
+388,42 @@ def validate_repo_structure(self): def _validate_dockerfile(self, dockerfile_content): lines, errors = dockerfile_content.splitlines(), [] - for line in lines: - if line.startswith("RUN pip install"): - cmd = line.split("RUN ")[-1] - result = subprocess.run( - cmd, - shell=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - ) - if result.returncode != 0: - errors.append(f"Failed to run {cmd}: {result.stderr.strip()}") if "WORKDIR /repo" not in dockerfile_content: errors.append("Missing 'WORKDIR /repo'.") if "COPY . /repo" not in dockerfile_content: errors.append("Missing 'COPY . /repo'.") + + pip_install_pattern = re.compile(r"pip install (.+)") + version_pin_pattern = re.compile(r"^[a-zA-Z0-9_\-\.]+==[a-zA-Z0-9_\-\.]+$") + + for line in lines: + line = line.strip() + + match = pip_install_pattern.search(line) + if match: + packages_and_options = match.group(1).split() + skip_next = False + + for item in packages_and_options: + if skip_next: + skip_next = False + continue + + if item.startswith("--index-url") or item.startswith( + "--extra-index-url" + ): + skip_next = True + continue + + if item.startswith("git+"): + continue + + if not version_pin_pattern.match(item): + errors.append( + f"Package '{item}' in line '{line}' is not version-pinned (e.g., 'package==1.0.0')." + ) + return errors def _validate_yml(self, yml_content): @@ -417,18 +437,48 @@ def _validate_yml(self, yml_content): if not python_version: errors.append("Missing Python version in install.yml.") + version_pin_pattern = re.compile(r"^[a-zA-Z0-9_\-\.]+==[a-zA-Z0-9_\-\.]+$") + commands = yml_data.get("commands", []) for command in commands: - if not isinstance(command, list) or command[0] != "pip": + if not isinstance(command, list) or len(command) < 2: errors.append(f"Invalid command format: {command}") continue - # package: name & version - name = command[1] if len(command) > 1 else None + + tool = command[0] + _ = command[1] version = command[2] if len(command) > 2 else None - if not name: - errors.append(f"Missing package name in command: {command}") - if name and version: - pass + + if tool in ("pip", "conda"): + if tool == "pip": + pip_args = command[1:] + skip_next = False + + for item in pip_args: + if skip_next: + skip_next = False + continue + + if item.startswith("--index-url") or item.startswith( + "--extra-index-url" + ): + skip_next = True + continue + + if item.startswith("git+"): + continue + + if not version_pin_pattern.match(item): + errors.append( + f"Package '{item}' in command '{command}' is not version-pinned (e.g., 'package==1.0.0')." + ) + + elif tool == "conda" and not version: + errors.append( + f"Package in command '{command}' does not have a valid pinned version " + f"(should be in the format ['conda', 'package_name', 'x.y.z'])." + ) + return errors def _run_performance_check(self, n): @@ -445,5 +495,5 @@ def _run_performance_check(self, n): return Result(False, f"Error serving model: {process.stderr.strip()}") execution_time = time.time() - start_time return Result( - True, f"{n} predictions executed in {execution_time:.2f} seconds." + True, f"{n} predictions executed in {execution_time:.2f} seconds. 
\n" ) diff --git a/ersilia/publish/test.py b/ersilia/publish/test.py index 255c4b30b..e288a8cf1 100644 --- a/ersilia/publish/test.py +++ b/ersilia/publish/test.py @@ -1,15 +1,22 @@ import csv +import docker import json import os +import re +import numpy as np import subprocess +import zipfile import sys +import warnings import tempfile -import time -import types +import traceback +from pathlib import Path +import yaml +import sys from dataclasses import dataclass -from datetime import datetime from enum import Enum -from typing import List +from typing import Callable +from typing import List, Any # ruff: noqa MISSING_PACKAGES = False @@ -22,12 +29,10 @@ except ImportError: MISSING_PACKAGES = True # ruff: enable -import click from .. import ErsiliaBase, throw_ersilia_exception from ..default import ( DOCKERFILE_FILE, - INFORMATION_FILE, INSTALL_YAML_FILE, METADATA_JSON_FILE, METADATA_YAML_FILE, @@ -35,13 +40,26 @@ PACK_METHOD_FASTAPI, PREDEFINED_EXAMPLE_FILES, RUN_FILE, + EOS_TMP, + GITHUB_ORG, + S3_BUCKET_URL_ZIP, + DOCKERHUB_ORG, ) from ..hub.fetch.actions.template_resolver import TemplateResolver from ..io.input import ExampleGenerator +from ..hub.content.card import ModelCard +from ..utils.download import GitHubDownloader, S3Downloader from ..utils.conda import SimpleConda from ..utils.exceptions_utils import test_exceptions as texc -from ..utils.terminal import run_command_check_output +from ..utils.docker import SimpleDocker +from ..utils.logging import make_temp_dir +from ..utils.spinner import show_loader +from ..utils.hdf5 import Hdf5DataLoader +from ..utils.terminal import run_command_check_output, yes_no_input from .inspect import ModelInspector +from ..cli import echo + +warnings.filterwarnings("ignore", message="Using slow pure-python SequenceMatcher.*") class Options(Enum): @@ -52,9 +70,33 @@ class Options(Enum): NUM_SAMPLES = 5 BASE = "base" OUTPUT_CSV = "result.csv" + INPUT_CSV = "input.csv" OUTPUT1_CSV = "output1.csv" OUTPUT2_CSV = "output2.csv" - LEVEL_DEEP = "deep" + OUTPUT_FILES = [ + "file.csv", + "file.h5", + "file.json", + ] + INPUT_TYPES = ["str", "list", "csv"] + + +class Checks(Enum): + """ + Enum for different check types. + """ + + MODEL_CONSISTENCY = "Check Consistency of Model Output" + IMAGE_SIZE = "Image Size Mb" + PREDEFINED_EXAMPLE = "Check Predefined Example Input" + ENV_SIZE = "Environment Size Mb" + DIR_SIZE = "Directory Size Mb" + # messages + SIZE_CACL_SUCCESS = "Size Successfully Calculated" + SIZE_CACL_FAILED = "Size Calculation Failed" + INCONSISTENCY = "Inconsistent Output Detected" + CONSISTENCY = "Model Output Was Consistent" + RUN_BASH = "RMSE-MEAN" class TableType(Enum): @@ -62,11 +104,17 @@ class TableType(Enum): Enum for different table types. 
""" - MODEL_INFORMATION_CHECKS = "Model Information Checks" + MODEL_INFORMATION_CHECKS = "Model Metadata Checks" MODEL_FILE_CHECKS = "Model File Checks" MODEL_DIRECTORY_SIZES = "Model Directory Sizes" + MODEL_ENV_SIZES = "Model Environment Sizes" RUNNER_CHECKUP_STATUS = "Runner Checkup Status" FINAL_RUN_SUMMARY = "Test Run Summary" + DEPENDECY_CHECK = "Dependency Check" + COMPUTATIONAL_PERFORMANCE = "Computational Performance Summary" + SHALLOW_CHECK_SUMMARY = "Validation and Size Check Summary" + CONSISTENCY_BASH = "Consistency Summary Between Ersilia and Bash Execution Outputs" + MODEL_OUTPUT = "Model Output Content Validation Summary" INSPECT_SUMMARY = "Inspect Summary" @@ -88,7 +136,7 @@ class TableConfig: title="Model File Checks", headers=["Check", "Status"] ), TableType.MODEL_DIRECTORY_SIZES: TableConfig( - title="Model Directory Sizes", headers=["Dest dir", "Env Dir"] + title="Model Directory Sizes", headers=["Check", "Size"] ), TableType.RUNNER_CHECKUP_STATUS: TableConfig( title="Runner Checkup Status", @@ -97,6 +145,24 @@ class TableConfig: TableType.FINAL_RUN_SUMMARY: TableConfig( title="Test Run Summary", headers=["Check", "Status"] ), + TableType.DEPENDECY_CHECK: TableConfig( + title="Dependency Check", headers=["Check", "Status"] + ), + TableType.COMPUTATIONAL_PERFORMANCE: TableConfig( + title="Computational Performance Summary", headers=["Check", "Status"] + ), + TableType.SHALLOW_CHECK_SUMMARY: TableConfig( + title="Validation and Size Check Results", + headers=["Check", "Details", "Status"], + ), + TableType.MODEL_OUTPUT: TableConfig( + title="Model Output Content Validation Summary", + headers=["Check", "Detail", "Status"], + ), + TableType.CONSISTENCY_BASH: TableConfig( + title="Consistency Summary Between Ersilia and Bash Execution Outputs", + headers=["Check", "Result", "Status"], + ), TableType.INSPECT_SUMMARY: TableConfig( title="Inspect Summary", headers=["Check", "Status"] ), @@ -123,77 +189,6 @@ def __str__(self): return f"[{self.color}]{self.icon} {self.label}[/{self.color}]" -# fmt: off -class TestResult(Enum): - """ - Enum for test results. - """ - DATE_TIME_RUN = ( - "Date and Time Run", - lambda: datetime.now().strftime("%Y-%m-%d %H:%M:%S") - ) - TIME_ELAPSED = ( - "Time to Run Tests (seconds)", - lambda elapsed: elapsed - ) - BASIC_CHECKS = ( - "Basic Checks Passed", - lambda svc: svc.information_check - ) - SINGLE_INPUT = ( - "Single Input Run Without Error", - lambda svc: svc.single_input - ) - EXAMPLE_INPUT = ( - "Example Input Run Without Error", - lambda svc: svc.example_input - ) - CONSISTENT_OUTPUT = ( - "Outputs Consistent", - lambda svc: svc.consistent_output - ) - BASH_RUN = ( - "Bash Run Without Error", - lambda run_bash: run_bash, - ) -# fmt: on - def __init__(self, key, value_function): - self.key = key - self.value_function = value_function - - @classmethod - def generate_results(cls, checkup_service, elapsed_time, run_using_bash): - """ - Generate test results. - - Parameters - ---------- - checkup_service : object - The checkup service. - elapsed_time : float - The elapsed time. - run_using_bash : bool - Whether to run using bash. - - Returns - ------- - dict - The generated results. 
- """ - results = {} - for test in cls: - func_args = {} - if "svc" in test.value_function.__code__.co_varnames: - func_args["svc"] = checkup_service - if "elapsed" in test.value_function.__code__.co_varnames: - func_args["elapsed"] = elapsed_time - if "run_bash" in test.value_function.__code__.co_varnames: - func_args["run_bash"] = run_using_bash - - value = test.value_function(**func_args) - results[test.key] = value - return results - class CheckStrategy: """ Execuetd a strategy for checking inspect commands. @@ -207,6 +202,7 @@ class CheckStrategy: details_key : str The key for details. """ + def __init__(self, check_function, success_key, details_key): self.check_function = check_function self.success_key = success_key @@ -242,6 +238,8 @@ class InspectService(ErsiliaBase): Directory where the model is located. model : str, optional Model identifier. + remote : bool, optional + Flag indicating whether the model is remote. config_json : str, optional Path to the configuration JSON file. credentials_json : str, optional @@ -251,21 +249,36 @@ class InspectService(ErsiliaBase): -------- .. code-block:: python - inspector = InspectService(dir="/path/to/model", model="model_id") + inspector = InspectService( + dir="/path/to/model", model="model_id" + ) results = inspector.run() """ - def __init__(self, dir: str = None, model: str = None, remote: bool = False, config_json: str = None, credentials_json: str = None): + def __init__( + self, + dir: str, + model: str, + remote: bool = False, + config_json: str = None, + credentials_json: str = None, + ): super().__init__(config_json, credentials_json) self.dir = dir self.model = model self.remote = remote + self.resolver = TemplateResolver(model_id=model, repo_path=self.dir) - def run(self) -> dict: + def run(self, check_keys: list = None) -> dict: """ Run the inspection checks on the specified model. + Parameters + ---------- + check_keys : list, optional + A list of check keys to execute. If None, all checks will be executed. + Returns ------- dict @@ -275,53 +288,100 @@ def run(self) -> dict: ------ ValueError If the model is not specified. + KeyError + If any of the specified keys do not exist. 
""" - if not self.model: - raise ValueError("Model must be specified.") + + def _transform_key(value): + if value is True: + return str(STATUS_CONFIGS.PASSED) + elif value is False: + return str(STATUS_CONFIGS.FAILED) + return value inspector = ModelInspector(self.model, self.dir) checks = self._get_checks(inspector) output = {} - for strategy in checks: - if strategy.check_function: - output.update(strategy.execute()) + if check_keys: + for key in check_keys: + try: + strategy = checks.get(key) + if strategy.check_function: + output.update(strategy.execute()) + except KeyError: + raise KeyError(f"Check '{key}' does not exist.") + else: + for key, strategy in checks.all(): + if strategy.check_function: + output.update(strategy.execute()) + output = { + " ".join(word.capitalize() for word in k.split("_")): _transform_key(v) + for k, v in output.items() + } + + output = [(key, value) for key, value in output.items()] return output - def _get_checks(self, inspector: ModelInspector) -> list: - return [ - CheckStrategy( + def _get_checks(self, inspector: ModelInspector) -> dict: + def create_check(check_fn, key, details): + return lambda: CheckStrategy(check_fn, key, details) + + dependency_check = ( + "Dockerfile" if self.resolver.is_bentoml() else "Install_YAML" + ) + checks = { + "is_github_url_available": create_check( inspector.check_repo_exists if self.remote else lambda: None, "is_github_url_available", "is_github_url_available_details", ), - CheckStrategy( + "complete_metadata": create_check( inspector.check_complete_metadata if self.remote else lambda: None, "complete_metadata", "complete_metadata_details", ), - CheckStrategy( + "complete_folder_structure": create_check( inspector.check_complete_folder_structure, "complete_folder_structure", "complete_folder_structure_details", ), - CheckStrategy( + "docker_check": create_check( inspector.check_dependencies_are_valid, - "docker_check", - "docker_check_details", + f"{dependency_check}_check", + "check_details", ), - CheckStrategy( + "computational_performance_tracking": create_check( inspector.check_computational_performance, "computational_performance_tracking", "computational_performance_tracking_details", ), - CheckStrategy( + "extra_files_check": create_check( inspector.check_no_extra_files if self.remote else lambda: None, "extra_files_check", "extra_files_check_details", ), - ] + } + + class LazyChecks: + def __init__(self, checks): + self._checks = checks + self._loaded = {} + + def get(self, key): + if key not in self._loaded: + if key not in self._checks: + raise KeyError(f"Check '{key}' does not exist.") + self._loaded[key] = self._checks[key]() + return self._loaded[key] + + def all(self): + for key in self._checks.keys(): + yield key, self.get(key) + + return LazyChecks(checks) + class SetupService: """ @@ -333,24 +393,94 @@ class SetupService: Identifier of the model. dir : str Directory where the model repository will be cloned. - logger : logging.Logger + from_github : bool + Flag indicating whether to fetch the repository from GitHub. + from_s3 : bool + Flag indicating whether to fetch the repository from S3. + logger : Any Logger for logging messages. - remote : bool - Flag indicating whether to fetch the repository from a remote source. 
""" BASE_URL = "https://github.com/ersilia-os/" - def __init__(self, model_id: str, dir: str, logger, remote: bool): + def __init__( + self, + model_id: str, + dir: str, + from_github: bool, + from_s3: bool, + logger: Any, + ): self.model_id = model_id self.dir = dir self.logger = logger - self.remote = remote + self.from_github = from_github + self.from_s3 = from_s3 + + self.mc = ModelCard() + self.metadata = self.mc.get(model_id) + self.s3 = self.metadata.get("card", {}).get("S3") or self.metadata.get( + "metadata", {} + ).get("S3") self.repo_url = f"{self.BASE_URL}{self.model_id}" + self.overwrite = self._handle_overwrite() + self.github_down = GitHubDownloader(overwrite=self.overwrite) + self.s3_down = S3Downloader() self.conda = SimpleConda() + def _handle_overwrite(self) -> bool: + if os.path.exists(self.dir): + self.logger.info(f"Directory {self.dir} already exists.") + return yes_no_input( + f"Directory {self.dir} already exists. Do you want to overwrite it? [Y/n]", + default_answer="n", + ) + return False + + def _download_s3(self): + if not self.overwrite and os.path.exists(self.dir): + self.logger.info("Skipping S3 download as user chose not to overwrite.") + return + + tmp_file = os.path.join(make_temp_dir("ersilia-"), f"{self.model_id}.zip") + + self.logger.info(f"Downloading model from S3 to temporary file: {tmp_file}") + self.s3_down.download_from_s3( + bucket_url=S3_BUCKET_URL_ZIP, + file_name=f"{self.model_id}.zip", + destination=tmp_file, + ) + + self.logger.info(f"Extracting model to: {self.dir}") + with zipfile.ZipFile(tmp_file, "r") as zip_ref: + zip_ref.extractall(EOS_TMP) + + def _download_github(self): + try: + if not os.path.exists(EOS_TMP): + self.logger.info(f"Path does not exist. Creating: {EOS_TMP}") + os.makedirs(EOS_TMP, exist_ok=True) + except OSError as e: + self.logger.error(f"Failed to create directory {EOS_TMP}: {e}") + + self.logger.info(f"Cloning repository from GitHub to: {EOS_TMP}") + self.github_down.clone( + org=GITHUB_ORG, + repo=self.model_id, + destination=self.dir, + ) + + def get_model(self): + if self.from_s3: + self._download_s3() + + if self.from_github: + self._download_github() + @staticmethod - def run_command(command: str, logger, capture_output: bool = False, shell: bool = True) -> str: + def run_command( + command: str, logger, capture_output: bool = False, shell: bool = True + ) -> str: """ Run a shell command. 
@@ -383,7 +513,7 @@ def run_command(command: str, logger, capture_output: bool = False, shell: bool stderr=subprocess.PIPE, text=True, check=True, - shell=shell + shell=shell, ) return result.stdout else: @@ -392,20 +522,20 @@ def run_command(command: str, logger, capture_output: bool = False, shell: bool stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, - shell=shell + shell=shell, ) stdout_lines, stderr_lines = [], [] - for line in iter(process.stdout.readline, ''): + for line in iter(process.stdout.readline, ""): if line.strip(): stdout_lines.append(line.strip()) logger.info(line.strip()) - for line in iter(process.stderr.readline, ''): + for line in iter(process.stderr.readline, ""): if line.strip(): stderr_lines.append(line.strip()) - logger.error(line.strip()) + logger.info(line.strip()) process.wait() if process.returncode != 0: @@ -423,41 +553,12 @@ def run_command(command: str, logger, capture_output: bool = False, shell: bool if e.output: logger.debug(f"Output: {e.output.strip()}") if e.stderr: - logger.debug(f"Error: {e.stderr.strip()}") + logger.error(f"Error: {e.stderr.strip()}") sys.exit(1) except Exception as e: logger.debug(f"Unexpected error: {e}") sys.exit(1) - def fetch_repo(self): - """ - Fetch the model repository from the remote source if the remote flag is set. - """ - if self.remote and not os.path.exists(self.dir): - out = SetupService.run_command( - f"git clone {self.repo_url}", - self.logger, - ) - self.logger.info(out) - - def check_conda_env(self): - """ - Check if the Conda environment for the model exists. - - Raises - ------ - Exception - If the Conda environment does not exist. - """ - if self.conda.exists(self.model_id): - self.logger.debug( - f"Conda environment '{self.model_id}' already exists." - ) - else: - raise Exception( - f"Conda virtual environment not found for {self.model_id}" - ) - @staticmethod def get_conda_env_location(model_id: str, logger) -> str: """ @@ -482,9 +583,7 @@ def get_conda_env_location(model_id: str, logger) -> str: """ try: result = SetupService.run_command( - "conda env list", - logger=logger, - capture_output=True + "conda env list", logger=logger, capture_output=True ) for line in result.splitlines(): if line.startswith("#") or not line.strip(): @@ -499,6 +598,7 @@ def get_conda_env_location(model_id: str, logger) -> str: return None + class IOService: """ Service for handling input/output operations related to model testing. @@ -507,14 +607,6 @@ class IOService: ---------- logger : logging.Logger Logger for logging messages. - dest_dir : str - Destination directory for storing model-related files. - model_path : str - Path to the model. - bundle_path : str - Path to the model bundle. - bentoml_path : str - Path to the BentoML files. model_id : str Identifier of the model. dir : str @@ -524,13 +616,14 @@ class IOService: -------- .. 
code-block:: python - ios = IOService(logger=logger, dest_dir="/path/to/dest", model_path="/path/to/model", - bundle_path="/path/to/bundle", bentoml_path="/path/to/bentoml", - model_id="model_id", dir="/path/to/dir") + ios = IOService( + logger=logger, + model_id="model_id", + dir="/path/to/dir", + ) ios.read_information() """ - # Required files RUN_FILE = f"model/framework/{RUN_FILE}" BENTOML_FILES = [ DOCKERFILE_FILE, @@ -539,7 +632,7 @@ class IOService: "src/service.py", "pack.py", "README.md", - "LICENSE" + "LICENSE", ] ERSILIAPACK_FILES = [ @@ -552,72 +645,109 @@ class IOService: "LICENSE", ] - def __init__(self, logger, dest_dir: str, model_path: str, bundle_path: str, bentoml_path: str, model_id: str, dir: str): + def __init__(self, logger, model_id: str, dir: str): self.logger = logger self.model_id = model_id self.dir = dir self.model_size = 0 self.console = Console() self.check_results = [] - self._model_path = model_path - self._bundle_path = bundle_path - self._bentoml_path = bentoml_path - self._dest_dir = dest_dir + self.simple_docker = SimpleDocker() + self.resolver = TemplateResolver(model_id=model_id, repo_path=self.dir) - def _run_check(self, check_function, data, check_name: str, additional_info=None) -> bool: + def _run_check( + self, check_function, data, check_name: str, additional_info=None + ) -> bool: try: if additional_info is not None: check_function(additional_info) else: check_function(data) - self.check_results.append(( - check_name, - str(STATUS_CONFIGS.PASSED) - )) + self.check_results.append((check_name, str(STATUS_CONFIGS.PASSED))) return True except Exception as e: - self.logger.error( - f"Check '{check_name}' failed: {e}" - ) - self.check_results.append(( - check_name, - str(STATUS_CONFIGS.FAILED) - )) + self.logger.error(f"Check '{check_name}' failed: {e}") + self.check_results.append((check_name, str(STATUS_CONFIGS.FAILED))) return False - def _generate_table(self, title: str, headers: List[str], rows: List[List[str]], large_table: bool = False, merge: bool = False): + def _get_metadata(self): + path = METADATA_JSON_FILE if self.resolver.is_bentoml() else METADATA_YAML_FILE + path = os.path.join(self.dir, path) + + with open(path, "r") as file: + if path.endswith(".json"): + data = json.load(file) + elif path.endswith((".yml", ".yaml")): + data = yaml.safe_load(file) + else: + raise ValueError(f"Unsupported file format: {path}") + return data + + def collect_and_save_json(self, results, output_file): + """ + Helper function to collect JSON results and save them to a file. + """ + aggregated_json = {} + for result in results: + aggregated_json.update(result) + + with open(output_file, "w") as f: + json.dump(aggregated_json, f, indent=4) + + def _create_json_data(self, rows, key): + def sanitize_name(name): + return re.sub(r"[ \-./]", "_", str(name).lower()) + + def parse_status(status): + if isinstance(status, str): + status = re.sub( + r"[-+]?\d*\.\d+|\d+", lambda m: str(float(m.group())), status + ) + if "[green]✔" in status: + return True + elif "[red]✘" in status: + return False + else: + return re.sub(r"\[.*?\]", "", status).strip() + return status + + def parse_performance(status): + return { + f"pred_{match[0]}": float(match[1]) + for match in re.findall( + r"(\d+) predictions executed in (\d+\.\d{2}) seconds. 
\n", status + ) + } + + key = re.sub(r" ", "_", key.lower()) + json_data = {} + + for row in rows: + check_name = sanitize_name(row[0]) + check_status = row[-1] + + if check_name == "computational_performance_tracking_details": + json_data[check_name] = parse_performance(check_status) + else: + json_data[check_name] = parse_status(check_status) + + return {key: json_data} + + def _generate_table(self, title, headers, rows, large_table=False, merge=False): f_col_width = 30 if large_table else 30 l_col_width = 50 if large_table else 10 d_col_width = 30 if not large_table else 20 table = Table( - title=Text( - title, - style="bold light_green" - ), + title=Text(title, style="bold light_green"), border_style="light_green", show_lines=True, ) - table.add_column( - headers[0], - justify="left", - width=f_col_width, - style="bold" - ) + table.add_column(headers[0], justify="left", width=f_col_width, style="bold") for header in headers[1:-1]: - table.add_column( - header, - justify="center", - width=d_col_width, - style="bold" - ) - table.add_column( - headers[-1], - justify="right", - width=l_col_width, - style="bold" - ) + table.add_column(header, justify="center", width=d_col_width, style="bold") + table.add_column(headers[-1], justify="right", width=l_col_width, style="bold") prev_value = None for row in rows: @@ -630,12 +760,16 @@ def _generate_table(self, title: str, headers: List[str], rows: List[List[str]], styled_row = [ Text(first_col, style="bold"), *[str(cell) for cell in row[1:-1]], - row[-1] + row[-1], ] table.add_row(*styled_row) + json_data = self._create_json_data(rows, title) + self.console.print(table) + return json_data + @staticmethod def get_model_type(model_id: str, repo_path: str) -> str: """ @@ -653,10 +787,7 @@ def get_model_type(model_id: str, repo_path: str) -> str: str The type of the model (e.g., PACK_METHOD_BENTOML, PACK_METHOD_FASTAPI). """ - resolver = TemplateResolver( - model_id=model_id, - repo_path=repo_path - ) + resolver = TemplateResolver(model_id=model_id, repo_path=repo_path) if resolver.is_bentoml(): return PACK_METHOD_BENTOML elif resolver.is_fastapi(): @@ -678,18 +809,13 @@ def get_file_requirements(self) -> List[str]: ValueError If the model type is unsupported. """ - type = IOService.get_model_type( - model_id=self.model_id, - repo_path=self.dir - ) + type = IOService.get_model_type(model_id=self.model_id, repo_path=self.dir) if type == PACK_METHOD_BENTOML: return self.BENTOML_FILES elif type == PACK_METHOD_FASTAPI: return self.ERSILIAPACK_FILES else: - raise ValueError( - f"Unsupported model type: {type}" - ) + raise ValueError(f"Unsupported model type: {type}") def read_information(self) -> dict: """ @@ -705,43 +831,15 @@ def read_information(self) -> dict: FileNotFoundError If the information file does not exist. """ - file = os.path.join( - self._dest_dir, - self.model_id, - INFORMATION_FILE - ) + file = os.path.join(EOS_TMP, self.model_id, METADATA_JSON_FILE) if not os.path.exists(file): raise FileNotFoundError( f"Information file does not exist for model {self.model_id}" - ) + ) with open(file, "r") as f: return json.load(f) - def print_output(self, result, output): - """ - Print the output of a result. - - Parameters - ---------- - result : any - The result to print. - output : file-like object - The output file to write to. 
- """ - def write_output(data): - if output is not None: - with open(output.name, "w") as file: - json.dump(data, file) - else: - self.logger.debug(json.dumps(data, indent=4)) - - if isinstance(result, types.GeneratorType): - for r in result: - write_output(r if r is not None else "Something went wrong") - else: - self.logger.debug(result) - - def get_conda_env_size(self) -> int: + def get_conda_env_size(self): """ Get the size of the Conda environment for the model. @@ -756,10 +854,7 @@ def get_conda_env_size(self) -> int: If there is an error calculating the size. """ try: - loc = SetupService.get_conda_env_location( - self.model_id, - self.logger - ) + loc = SetupService.get_conda_env_location(self.model_id, self.logger) return self.calculate_directory_size(loc) except Exception as e: self.logger.error( @@ -786,29 +881,140 @@ def calculate_directory_size(self, path: str) -> int: ["du", "-sm", path], logger=self.logger, capture_output=True, - shell=False + shell=False, ) - size = int(size_output.split()[0]) + size = float(size_output.split()[0]) return size except Exception as e: - self.logger.error( - f"Error calculating directory size for {path}: {e}" - ) + self.logger.error(f"Error calculating directory size for {path}: {e}") return 0 + def calculate_image_size(self, tag="latest"): + """ + Calculate the size of a Docker image. + + Parameters + ---------- + tag : str, optional + The tag of the Docker image (default is 'latest'). + + Returns + ------- + str + The size of the Docker image. + """ + image_name = f"{DOCKERHUB_ORG}/{self.model_id}:{tag}" + client = docker.from_env() + try: + image = client.images.get(image_name) + size_bytes = image.attrs["Size"] + size_mb = size_bytes / (1024**2) + return f"{size_mb:.2f} MB", Checks.SIZE_CACL_SUCCESS.value + except docker.errors.ImageNotFound: + return f"Image '{image_name}' not found.", Checks.SIZE_CACL_FAILED.value + except Exception as e: + return f"An error occurred: {e}", Checks.SIZE_CACL_FAILED.value + @throw_ersilia_exception() - def get_directories_sizes(self) -> tuple: + def get_directories_sizes(self) -> str: """ - Get the sizes of the model directory and the Conda environment directory. + Get the sizes of the model directory. Returns ------- - tuple - A tuple containing the sizes of the model directory and the Conda environment directory in megabytes. + str + A string of containing size of the model directory """ dir_size = self.calculate_directory_size(self.dir) + dir_size = f"{dir_size:.2f}" + return dir_size + + @throw_ersilia_exception() + def get_env_sizes(self) -> str: + """ + Get the sizes of the Conda environment directory. 
+ + Returns + ------- + str + A string of containing size of the model environment + """ env_size = self.get_conda_env_size() - return dir_size, env_size + env_size = f"{env_size:.2f}" + return env_size + + def _extract_size(self, data, key="validation_and_size_check_results"): + sizes, keys = {}, ("environment_size_mb", "image_size_mb") + validation_results = next((item.get(key) for item in data if key in item), None) + if validation_results: + if keys[0] in validation_results: + env_size = validation_results[keys[0]] + self.logger.info(f"Environment Size: {env_size}") + sizes["Environment Size"] = float(env_size) + if keys[1] in validation_results: + img_size = validation_results[keys[1]] + self.logger.info(f"Image Size: {img_size}") + sizes["Image Size"] = float(img_size) + return sizes + + def _extract_execution_times(self, data, key="computational_performance_summary"): + self.logger.info("Performance Extraction is started") + performance = { + "Computational Performance 1": None, + "Computational Performance 10": None, + "Computational Performance 100": None, + } + + summary = next((item.get(key) for item in data if key in item), None) + if summary: + preds = summary.get("computational_performance_tracking_details") + performance["Computational Performance 1"] = float(preds.get("pred_1")) + performance["Computational Performance 10"] = float(preds.get("pred_10")) + performance["Computational Performance 100"] = float(preds.get("pred_100")) + + return performance + + def update_metadata(self, json_data): + """ + Processes JSON/YAML metadata to extract size and performance info and then updates them. + + Parameters + ---------- + json_data : dict + Report data from the command output. + + Returns + ------- + dict + Updated metadata containing computed performance and size information. + """ + sizes = self._extract_size(json_data) + exec_times = self._extract_execution_times(json_data) + metadata = self._get_metadata() + metadata.update(sizes) + metadata.update(exec_times) + + self._save_file( + metadata, + ) + + def _save_file(self, metadata): + path = METADATA_JSON_FILE if self.resolver.is_bentoml() else METADATA_YAML_FILE + path = os.path.join(self.dir, path) + with open(path, "w") as file: + if path.endswith(".json"): + json.dump(metadata, file, indent=4, ensure_ascii=False) + elif path.endswith((".yml", ".yaml")): + yaml.dump( + metadata, + file, + default_flow_style=False, + sort_keys=False, + allow_unicode=True, + ) + else: + raise ValueError(f"Unsupported file format: {path}") + class CheckService: """ @@ -820,10 +1026,12 @@ class CheckService: Logger for logging messages. model_id : str Identifier of the model. - dest_dir : str - Destination directory for storing model-related files. dir : str Directory where the model repository is located. + from_github : bool + Flag indicating whether to fetch the repository from GitHub. + from_dockerhub : bool + Flag indicating whether to fetch the repository from DockerHub. ios : IOService Instance of IOService for handling input/output operations. @@ -831,8 +1039,14 @@ class CheckService: -------- .. 
code-block:: python - check_service = CheckService(logger=logger, model_id="model_id", dest_dir="/path/to/dest", - dir="/path/to/dir", ios=ios) + check_service = CheckService( + logger=logger, + model_id="model_id", + dir="/path/to/dir", + from_github=True, + from_dockerhub=False, + ios=ios, + ) check_service.check_files() """ @@ -861,43 +1075,47 @@ class CheckService: "Text", } - INPUT_SHAPE = { - "Single", - "Pair", - "List", - "Pair of Lists", - "List of Lists" - } + INPUT_SHAPE = {"Single", "Pair", "List", "Pair of Lists", "List of Lists"} - OUTPUT_SHAPE = { - "Single", - "List", - "Flexible List", - "Matrix", - "Serializable Object" - } + OUTPUT_SHAPE = {"Single", "List", "Flexible List", "Matrix", "Serializable Object"} - def __init__(self, logger, model_id: str, dest_dir: str, dir: str, ios: IOService): + def __init__( + self, + logger: Any, + model_id: str, + dir: str, + from_github: bool, + from_dockerhub: bool, + ios: IOService, + ): self.logger = logger self.model_id = model_id - self._dest_dir = dest_dir self.dir = dir + self.from_github = from_github + self.from_dockerhub = from_dockerhub self._run_check = ios._run_check self._generate_table = ios._generate_table - self._print_output = ios.print_output self.get_file_requirements = ios.get_file_requirements self.console = ios.console self.check_results = ios.check_results - self.information_check = False - self.single_input = False - self.example_input = False - self.consistent_output = False + self.resolver = TemplateResolver(model_id=model_id, repo_path=self.dir) + + def _get_metadata(self): + path = METADATA_JSON_FILE if self.resolver.is_bentoml() else METADATA_YAML_FILE + path = os.path.join(self.dir, path) + + with open(path, "r") as file: + if path.endswith(".json"): + data = json.load(file) + elif path.endswith((".yml", ".yaml")): + data = yaml.safe_load(file) + else: + raise ValueError(f"Unsupported file format: {path}") + return data def _check_file_existence(self, path): if not os.path.exists(os.path.join(self.dir, path)): - raise FileNotFoundError( - f"File '{path}' does not exist." 
- ) + raise FileNotFoundError(f"File '{path}' does not exist.") def check_files(self): """ @@ -906,153 +1124,199 @@ def check_files(self): requirements = self.get_file_requirements() for file in requirements: self.logger.debug(f"Checking file: {file}") - self._run_check( - self._check_file_existence, - None, - f"File: {file}", - file - ) + self._run_check(self._check_file_existence, None, f"{file}", file) def _check_model_id(self, data): self.logger.debug("Checking model ID...") - if data["card"]["Identifier"] != self.model_id: + if data["Identifier"] != self.model_id: raise texc.WrongCardIdentifierError(self.model_id) - def _check_model_slug(self, data): self.logger.debug("Checking model slug...") - if not data["card"]["Slug"]: + if not data["Slug"]: raise texc.EmptyField("slug") - def _check_model_description(self, data): self.logger.debug("Checking model description...") - if not data["card"]["Description"]: + if not data["Description"]: raise texc.EmptyField("Description") + def _check_model_tag(self, data): + self.logger.debug("Checking model tag...") + if not data["Tag"]: + raise texc.EmptyField("Tag") + + def _check_model_source_code(self, data): + self.logger.debug("Checking model source code...") + if not data["Source Code"]: + raise texc.EmptyField("Source Code") + + def _check_model_source_title(self, data): + self.logger.debug("Checking model title...") + if not data["Title"]: + raise texc.EmptyField("Title") + + def _check_model_status(self, data): + self.logger.debug("Checking model status...") + if not data["Status"]: + raise texc.EmptyField("Status") + + def _check_model_contributor(self, data): + self.logger.debug("Checking model contributor...") + if not data["Contributor"]: + raise texc.EmptyField("Contributor") + + def _check_model_interpret(self, data): + self.logger.debug("Checking model interpretation...") + if not data["Interpretation"]: + raise texc.EmptyField("Interpretation") + + def _check_model_dockerhub_url(self, data): + key = "DockerHub" + self.logger.debug(f"Checking {key} URL field..") + if key in data: + if not data[key]: + raise texc.EmptyField(key) + else: + raise texc.EmptyKey(key) + + def _check_model_s3_url(self, data): + key = "S3" + self.logger.debug(f"Checking {key} URL field..") + if key in data: + if not data[key]: + raise texc.EmptyField(key) + else: + raise texc.EmptyKey(key) + + def _check_model_arch(self, data): + key = "Docker Architecture" + self.logger.debug(f"Checking {key} field..") + if key in data: + if not data[key]: + raise texc.EmptyField(key) + else: + raise texc.EmptyKey(key) + + def _check_model_publication(self, data): + key = "Publication" + self.logger.debug(f"Checking {key} field..") + if key in data: + if not data[key]: + raise texc.EmptyField(key) + else: + raise texc.EmptyKey(key) + def _check_model_task(self, data): self.logger.debug("Checking model task...") - raw_tasks = data.get("card", {}).get("Task", "") + raw_tasks = data.get("Task") if isinstance(raw_tasks, str): - tasks = [ - task.strip() - for task - in raw_tasks.split(",") - if task.strip() - ] + tasks = [task.strip() for task in raw_tasks.split(",") if task.strip()] elif isinstance(raw_tasks, list): tasks = [ task.strip() - for task - in raw_tasks + for task in raw_tasks if isinstance(task, str) and task.strip() ] else: - raise texc.InvalidEntry( - "Task", - message="Task 
field must be a string or list." - ) + raise texc.InvalidEntry("Task") if not tasks: - raise texc.InvalidEntry( - "Task", - message="Task field is missing or empty." - ) + raise texc.InvalidEntry("Task") invalid_tasks = [task for task in tasks if task not in self.MODEL_TASKS] if invalid_tasks: - raise texc.InvalidEntry( - "Task", message=f"Invalid tasks: {', '.join(invalid_tasks)}" - ) + raise texc.InvalidEntry("Task") self.logger.debug("All tasks are valid.") def _check_model_output(self, data): self.logger.debug("Checking model output...") - raw_outputs = data.get("card", {}).get("Output", "") or data.get("metadata", {}).get("Output", "") + raw_outputs = data.get("Output") if isinstance(raw_outputs, str): outputs = [ - output.strip() - for output - in raw_outputs.split(",") - if output.strip() + output.strip() for output in raw_outputs.split(",") if output.strip() ] elif isinstance(raw_outputs, list): outputs = [ output.strip() - for output - in raw_outputs + for output in raw_outputs if isinstance(output, str) and output.strip() ] else: - raise texc.InvalidEntry( - "Output", - message="Output field must be a string or list." - ) + raise texc.InvalidEntry("Output") if not outputs: - raise texc.InvalidEntry( - "Output", - message="Output field is missing or empty." - ) + raise texc.InvalidEntry("Output") invalid_outputs = [ - output - for output - in outputs - if output not in self.MODEL_OUTPUT + output for output in outputs if output not in self.MODEL_OUTPUT ] if invalid_outputs: - raise texc.InvalidEntry( - "Output", - message=f"Invalid outputs: {' '.join(invalid_outputs)}" - ) + raise texc.InvalidEntry("Output") self.logger.debug("All outputs are valid.") def _check_model_input(self, data): self.logger.debug("Checking model input") - valid_inputs = [{"Compound"}, {"Protein"}, {"Text"}] + valid_inputs = ["Compound", "Protein", "Text"] - model_input = data.get("card", {}).get("Input") or data.get("metadata", {}).get("Input") + model_input = data.get("Input") + if isinstance(model_input, str): + model_input = [ + input.strip() for input in model_input.split(",") if input.strip() + ] + elif isinstance(model_input, list): + model_input = [ + input.strip() + for input in model_input + if isinstance(input, str) and input.strip() + ] + else: + raise texc.InvalidEntry("Input") + + if not model_input: + raise texc.InvalidEntry("Input") - if not model_input or set(model_input) not in valid_inputs: + invalid_inputs = [input for input in model_input if input not in valid_inputs] + if invalid_inputs: raise texc.InvalidEntry("Input") + self.logger.debug("All inputs are valid.") + def _check_model_input_shape(self, data): self.logger.debug("Checking model input shape") - model_input_shape = ( - data.get("card", {}).get("Input Shape") or - data.get("metadata", {}).get("InputShape") - ) + model_input_shape = data.get("Input Shape") if model_input_shape not in self.INPUT_SHAPE: raise texc.InvalidEntry("Input Shape") def _check_model_output_type(self, data): self.logger.debug("Checking model output type...") - valid_output_types = [{"String"}, {"Float"}, {"Integer"}] + valid_output_types = ["String", "Float", "Integer"] + model_output_type = data.get("Output Type") model_output_type = ( - data.get("card", {}).get("Output Type") or - data.get("metadata", {}).get("OutputType") + model_output_type[0] + if isinstance(model_output_type, list) + else model_output_type ) - - if not model_output_type or set(model_output_type) not in valid_output_types: + if not model_output_type or model_output_type not in 
valid_output_types: raise texc.InvalidEntry("Output Type") def _check_model_output_shape(self, data): self.logger.debug("Checking model output shape...") - model_output_shape = ( - data.get("card", {}).get("Output Shape") or - data.get("metadata", {}).get("OutputShape") - ) + model_output_shape = data.get("Output Shape") if model_output_shape not in self.OUTPUT_SHAPE: raise texc.InvalidEntry("Output Shape") @throw_ersilia_exception() - def check_information(self, output): + def check_information(self): """ Perform various checks on the model information. @@ -1062,16 +1326,12 @@ def check_information(self, output): The output file to write to. """ self.logger.debug(f"Beginning checks for {self.model_id} model information") - file = os.path.join( - self._dest_dir, - self.model_id, - INFORMATION_FILE - ) - with open(file, "r") as f: - data = json.load(f) + data = self._get_metadata() self._run_check(self._check_model_id, data, "Model ID") self._run_check(self._check_model_slug, data, "Model Slug") + self._run_check(self._check_model_status, data, "Model Status") + self._run_check(self._check_model_source_title, data, "Model Title") self._run_check(self._check_model_description, data, "Model Description") self._run_check(self._check_model_task, data, "Model Task") self._run_check(self._check_model_input, data, "Model Input") @@ -1079,55 +1339,228 @@ def check_information(self, output): self._run_check(self._check_model_output, data, "Model Output") self._run_check(self._check_model_output_type, data, "Model Output Type") self._run_check(self._check_model_output_shape, data, "Model Output Shape") - - if output is not None: - self.information_check = True - - @throw_ersilia_exception() - def check_single_input(self, output, run_model, run_example): - """ - Check if the model can run with a single input to check if it has a value - in the produced output csv. - - Parameters - ---------- - output : file-like object - The output file to write to. - run_model : callable - Function to run the model. - run_example : callable - Function to generate example input. 
- """ - input = run_example( + self._run_check(self._check_model_interpret, data, "Model Interpretation") + self._run_check(self._check_model_tag, data, "Model Tag") + self._run_check(self._check_model_publication, data, "Model Publication") + self._run_check(self._check_model_source_code, data, "Model Source Code") + self._run_check(self._check_model_contributor, data, "Model Contributor") + self._run_check(self._check_model_dockerhub_url, data, "Model Dockerhub URL") + self._run_check(self._check_model_s3_url, data, "Model S3 URL") + self._run_check(self._check_model_arch, data, "Model Docker Architecture") + + def get_inputs(self, run_example, types): + samples = run_example( n_samples=Options.NUM_SAMPLES.value, file_name=None, simple=True, - try_predefined=False + try_predefined=False, ) - result = run_model( - input=input, - output=output, - batch=100 + samples = [sample["input"] for sample in samples] + if types == "str": + return samples[0] + if types == "list": + return json.dumps(samples) + if types == "csv": + run_example( + n_samples=Options.NUM_SAMPLES.value, + file_name=Options.INPUT_CSV.value, + simple=True, + try_predefined=False, + ) + return Options.INPUT_CSV.value + + def check_model_output_content(self, run_example, run_model): + status = [] + self.logger.debug("Checking model output...") + for inp_type in Options.INPUT_TYPES.value: + for i, output_file in enumerate(Options.OUTPUT_FILES.value): + inp_data = self.get_inputs(run_example, inp_type) + self.logger.debug(f"Input data: {inp_data}") + run_model(inputs=inp_data, output=output_file, batch=100) + _status = self.validate_file_content(output_file, inp_type) + status.append(_status) + return status + + def _is_invalid_value(self, value): + try: + if value is None: + return True + if isinstance(value, str): + if value.strip().lower() in {"", "nan", "null", "none"}: + return True + if isinstance(value, float) and (np.isnan(value) or np.isinf(value)): + return True + + if isinstance(value, (list, tuple)): + return any(self._is_invalid_value(item) for item in value) + if isinstance(value, dict): + return any(self._is_invalid_value(item) for item in value.values()) + if isinstance(value, (np.ndarray)): + return np.any(np.isnan(value)) or np.any(np.isinf(value)) + except Exception: + return True + return False + + def _check_csv(self, file_path, input_type="list"): + self.logger.debug(f"Checking CSV file: {file_path}") + error_details = [] + + with open(file_path, "r") as f: + reader = csv.reader(f) + rows = list(reader)[1:] + + for row_index, row in enumerate(rows, start=2): + for col_index, cell in enumerate(row, start=1): + self.logger.info(f"CSV content check for input type {input_type}") + try: + parsed_cell = ( + eval(cell) + if isinstance(cell, str) and cell.startswith(("[", "{")) + else cell + ) + except Exception as e: + parsed_cell = cell + + if self._is_invalid_value(parsed_cell): + self.logger.error(f"Invalid cell found: {cell}") + error_details.append( + f"Row {row_index}, Column {col_index}: {repr(cell)}" + ) + + if error_details: + return ( + f"{input_type.upper()}-CSV", + f"Invalid values found in CSV content: {error_details} issues detected.", + str(STATUS_CONFIGS.FAILED), + ) + + return ( + f"{input_type.upper()}-CSV", + "Valid Content", + str(STATUS_CONFIGS.PASSED), ) - def read_csv(file_path): - absolute_path = os.path.abspath(file_path) - if not os.path.exists(absolute_path): - raise FileNotFoundError(f"File not found: {absolute_path}") - with open(absolute_path, mode='r') as csv_file: - reader = 
csv.DictReader(csv_file) - return [row for row in reader] + def validate_file_content(self, file_path, input_type): + def check_json(file_path): + self.logger.debug(f"Checking JSON file: {file_path}") + error_details = [] - try: - csv_content = read_csv(output) - if csv_content: - self.single_input = True - except Exception as e: - self.logger.error(f"Error reading CSV content: {e}") - self._print_output(result, output) + try: + with open(file_path, "r") as f: + content = json.load(f) + + self.logger.debug(f"Content: {content}") + + def _validate_item(item, path="result"): + if self._is_invalid_value(item): + self.logger.error(f"JSON content invalid value: {item}") + error_details.append(f"{repr(item)}") + elif isinstance(item, dict): + for key, value in item.items(): + _validate_item(value, f"{path} -> {key}") + elif isinstance(item, list): + for index, value in enumerate(item): + _validate_item(value, f"{path}[{index}]") + + _validate_item(content) + + if error_details: + return ( + f"{input_type.upper()}-JSON", + f"Invalid values found in JSON content: {error_details}.", + str(STATUS_CONFIGS.FAILED), + ) + + return ( + f"{input_type.upper()}-JSON", + "Valid Content", + str(STATUS_CONFIGS.PASSED), + ) + except json.JSONDecodeError as e: + return ( + f"{input_type.upper()}-JSON", + f"Invalid JSON content: {e}", + str(STATUS_CONFIGS.FAILED), + ) + except Exception as e: + return ( + f"{input_type.upper()}-JSON", + f"Unexpected error during JSON check: {e}", + str(STATUS_CONFIGS.FAILED), + ) + + def check_h5(file_path): + self.logger.debug(f"Checking HDF5 file: {file_path}") + error_details = [] + + try: + loader = Hdf5DataLoader() + loader.load(file_path) + content = next( + ( + x + for x in [ + loader.values, + loader.keys, + loader.inputs, + loader.features, + ] + if x is not None + ), + None, + ) + + self.logger.debug(f"Content: {content}") + if content is None or (hasattr(content, "size") and content.size == 0): + return ( + f"{input_type.upper()}-HDF5", + "Empty content", + str(STATUS_CONFIGS.FAILED), + ) + + content_array = np.array(content) + if np.isnan(content_array).any(): + nan_indices = np.argwhere(np.isnan(content_array)) + for index in nan_indices: + self.logger.error(f"H5 content invalid value at index: {index}") + error_details.append(f"NaN detected at index: {tuple(index)}") + + if error_details: + return ( + f"{input_type.upper()}-HDF5", + f"Invalid values found in HDF5 content: {error_details}", + str(STATUS_CONFIGS.FAILED), + ) + + return ( + f"{input_type.upper()}-HDF5", + "Valid Content", + str(STATUS_CONFIGS.PASSED), + ) + except Exception as e: + return ( + f"{input_type.upper()}-HDF5", + f"Invalid HDF5 content: {e}", + str(STATUS_CONFIGS.FAILED), + ) + + if not Path(file_path).exists(): + raise FileNotFoundError(f"File {file_path} does not exist.") + + file_extension = Path(file_path).suffix.lower() + if file_extension == ".json": + return check_json(file_path) + elif file_extension == ".csv": + return self._check_csv(file_path, input_type=input_type) + elif file_extension == ".h5": + return check_h5(file_path) + else: + raise ValueError( + f"Unsupported file type: {file_extension}. Supported types are JSON, CSV, and HDF5." + ) @throw_ersilia_exception() - def check_single_input(self, output, run_model, run_example): + def check_example_input(self, run_model, run_example): """ Check if the model can run with example input without error. 
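[Editor's note] The CSV, JSON, and HDF5 validators above all reduce to the same invalid-value rule implemented by _is_invalid_value: None, placeholder strings, non-finite floats, or any container holding one of those. A simplified standalone sketch of that rule (illustrative only; the actual method also handles numpy arrays):

.. code-block:: python

    import math

    def is_invalid(value) -> bool:
        # None and placeholder strings count as missing output.
        if value is None:
            return True
        if isinstance(value, str):
            return value.strip().lower() in {"", "nan", "null", "none"}
        # Non-finite floats (NaN, +/-inf) invalidate a prediction.
        if isinstance(value, float):
            return math.isnan(value) or math.isinf(value)
        # A single bad element invalidates the whole container.
        if isinstance(value, (list, tuple)):
            return any(is_invalid(v) for v in value)
        if isinstance(value, dict):
            return any(is_invalid(v) for v in value.values())
        return False

    assert is_invalid(float("nan")) is True
    assert is_invalid([0.2, None]) is True
    assert is_invalid(0.0) is False
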
@@ -1140,25 +1573,32 @@ def check_example_input(self, output, run_model, run_example): run_example : callable Function to generate example input. """ - input_samples = run_example( + self.logger.debug("Checking model with example input...") + output = Options.OUTPUT_CSV.value + input = run_example( n_samples=Options.NUM_SAMPLES.value, file_name=None, simple=True, - try_predefined=False + try_predefined=True, ) + input = json.dumps([input["input"] for input in input]) - self.logger.debug("Testing model on input of 5 smiles given by 'example' command") - - result = run_model( - input=input_samples, - output=output, - batch=100 + self.logger.debug( + "Testing model on input of 5 smiles given by 'example' command" ) - if input_samples: - self.example_input = True + run_model(inputs=input, output=output, batch=100) + + csv_content, _completed_status = input, [] + if csv_content: + _completed_status.append( + (Checks.PREDEFINED_EXAMPLE.value, str(STATUS_CONFIGS.PASSED)) + ) else: - self._print_output(result, output) + _completed_status.append( + (Checks.PREDEFINED_EXAMPLE.value, str(STATUS_CONFIGS.FAILED)) + ) + return _completed_status @throw_ersilia_exception() def check_consistent_output(self, run_example, run_model): @@ -1172,22 +1612,23 @@ def check_consistent_output(self, run_example, run_model): run_model : callable Function to run the model. """ + self.logger.debug("Confirming model produces consistent output...") + def compute_rmse(y_true, y_pred): - return sum((yt - yp) ** 2 for yt, yp in zip(y_true, y_pred)) ** 0.5 / len(y_true) + return sum((yt - yp) ** 2 for yt, yp in zip(y_true, y_pred)) ** 0.5 / len( + y_true + ) def _compare_output_strings(output1, output2): - if output1 is None and output2 is None: - return 100 - else: - return fuzz.ratio(output1, output2) + return fuzz.ratio(output1, output2) def validate_output(output1, output2): + if self._is_invalid_value(output1) or self._is_invalid_value(output2): + raise texc.InconsistentOutputs(self.model_id) + if not isinstance(output1, type(output2)): raise texc.InconsistentOutputTypes(self.model_id) - if output1 is None: - return - if isinstance(output1, (float, int)): rmse = compute_rmse([output1], [output2]) if rmse > 0.1: @@ -1203,6 +1644,7 @@ def validate_output(output1, output2): raise texc.InconsistentOutputs(self.model_id) rho, _ = spearmanr(output1, output2) + if rho < 0.5: raise texc.InconsistentOutputs(self.model_id) @@ -1214,7 +1656,7 @@ def read_csv(file_path): absolute_path = os.path.abspath(file_path) if not os.path.exists(absolute_path): raise FileNotFoundError(f"File not found: {absolute_path}") - with open(absolute_path, mode='r') as csv_file: + with open(absolute_path, mode="r") as csv_file: reader = csv.DictReader(csv_file) return [row for row in reader] @@ -1223,34 +1665,62 @@ def read_csv(file_path): self.logger.debug("Confirming model produces consistent output...") - input_samples = run_example( + input = run_example( n_samples=Options.NUM_SAMPLES.value, file_name=None, simple=True, - try_predefined=False - ) - run_model( - input=input_samples, - output=output1_path, - batch=100 + try_predefined=False, ) - run_model( - input=input_samples, - output=output2_path, - batch=100 - ) - - data1 = read_csv(output1_path) - data2 = read_csv(output1_path) - - for res1, res2 in zip(data1, data2): - for key in res1: - if key in res2: - validate_output(res1[key], res2[key]) - else: - raise KeyError(f"Key '{key}' not found in second result.") + input = json.dumps([input["input"] for input in input]) + + run_model(inputs=input, 
+        run_model(inputs=input_data, output=output2_path, batch=100)
+
+        check_status_one = self._check_csv(output1_path)
+        check_status_two = self._check_csv(output2_path)
+        _completed_status = []
+        if check_status_one[-1] == str(STATUS_CONFIGS.FAILED) or check_status_two[
+            -1
+        ] == str(STATUS_CONFIGS.FAILED):
+            self.logger.error("Model output has content problem")
+            _completed_status.append(
+                (
+                    Checks.MODEL_CONSISTENCY.value,
+                    check_status_one[1],
+                    str(STATUS_CONFIGS.FAILED),
+                )
+            )
+            return _completed_status
+        else:
+            data1 = read_csv(output1_path)
+            data2 = read_csv(output2_path)
+            try:
+                for res1, res2 in zip(data1, data2):
+                    for key in res1:
+                        if key in res2:
+                            validate_output(res1[key], res2[key])
+                        else:
+                            raise KeyError(f"Key '{key}' not found in second result.")
+                self.logger.info("Model output is consistent")
+                _completed_status.append(
+                    (
+                        Checks.MODEL_CONSISTENCY.value,
+                        Checks.CONSISTENCY.value,
+                        str(STATUS_CONFIGS.PASSED),
+                    )
+                )
+            except Exception:
+                self.logger.info("Model output is inconsistent")
+                _completed_status.append(
+                    (
+                        Checks.MODEL_CONSISTENCY.value,
+                        Checks.INCONSISTENCY.value,
+                        str(STATUS_CONFIGS.FAILED),
+                    )
+                )
+                return _completed_status
+            self.logger.debug(f"Completed status: {_completed_status}")
+            return _completed_status

-        self.consistent_output = True

class RunnerService:
    """
@@ -1268,19 +1738,27 @@ class RunnerService:
    Instance of CheckService for performing various checks on the model.
    setup_service : SetupService
        Instance of SetupService for setting up the environment and fetching the model repository.
-    model_path : str
-        Path to the model.
    level : str
        Level of checks to perform.
    dir : str
        Directory where the model repository is located.
-    remote : bool
-        Flag indicating whether to fetch the repository from a remote source.
-    inspect : bool
-        Flag indicating whether to perform inspection checks.
-    remove : bool
-        Flag indicating whether to remove the model directory after tests.
-    inspecter : InspectService
+    model_path : Callable
+        Callable to get the model path.
+    from_github : bool
+        Flag indicating whether to fetch the repository from GitHub.
+    from_s3 : bool
+        Flag indicating whether to fetch the repository from S3.
+    from_dockerhub : bool
+        Flag indicating whether to fetch the repository from DockerHub.
+    version : str
+        Version of the model.
+    shallow : bool
+        Flag indicating whether to perform shallow checks.
+    deep : bool
+        Flag indicating whether to perform deep checks.
+    as_json : bool
+        Flag indicating whether to output results as JSON.
+    inspector : InspectService
        Instance of InspectService for inspecting models and their configurations.
""" @@ -1291,13 +1769,17 @@ def __init__( ios_service: IOService, checkup_service: CheckService, setup_service: SetupService, - model_path: str, level: str, dir: str, - remote: bool, - inspect: bool, - remove: bool, - inspecter: InspectService + model_path: Callable, + from_github: bool, + from_s3: bool, + from_dockerhub: bool, + version: str, + shallow: bool, + deep: bool, + as_json: bool, + inspector: InspectService, ): self.model_id = model_id self.logger = logger @@ -1305,19 +1787,22 @@ def __init__( self.ios_service = ios_service self.console = ios_service.console self.checkup_service = checkup_service - self._model_path = model_path + self.model_path = model_path(self.model_id) self.level = level self.dir = dir - self.remote = remote - self.inspect = inspect - self.remove = remove - self.inspecter = inspecter - self.example = ExampleGenerator( - model_id=self.model_id - ) + self.from_github = from_github + self.from_s3 = from_s3 + self.from_dockerhub = from_dockerhub + self.version = version + self.shallow = shallow + self.deep = deep + self.as_json = as_json + self.report_file = Path.cwd() / f"{self.model_id}-test.json" + self.inspector = inspector + self.example = ExampleGenerator(model_id=self.model_id) self.run_using_bash = False - def run_model(self, input, output: str, batch: int): + def run_model(self, inputs: list, output: str, batch: int): """ Run the model with the given input and output parameters. @@ -1335,34 +1820,52 @@ def run_model(self, input, output: str, batch: int): str The output of the command. """ - if isinstance(input, list): - input = input[0] - self.logger.info("Running model") out = SetupService.run_command( - f"ersilia -v serve {self.model_id} && ersilia -v run -i {input[0]} -o {output} -b {str(batch)}", + f"ersilia -v serve {self.model_id} && ersilia -v run -i '{inputs}' -o {output} -b {str(batch)}", logger=self.logger, ) + self.logger.info(out) return out def fetch(self): """ Fetch the model repository from the specified directory. """ - SetupService.run_command( - " ".join(["ersilia", - "-v", - "fetch", self.model_id, - "--from_dir", self.dir - ]), - logger=self.logger, - ) - def run_exampe( + def _fetch(dir, model_id, logger): + loc = ( + ["--from_dir", self.dir] + if self.from_github or self.from_s3 + else ["--from_dockerhub"] + + (["--version", self.version] if self.version else []) + ) + self.logger.info(f"Fetching model from: {loc}") + out = SetupService.run_command( + " ".join(["ersilia", "-v", "fetch", model_id, *loc]), + logger=logger, + ) + logger.info(f"Fetch out: {out}") + + if os.path.exists(self.model_path): + SetupService.run_command( + " ".join( + [ + "ersilia", + "-v", + "delete", + self.model_id, + ] + ), + logger=self.logger, + ) + _fetch(self.dir, self.model_id, self.logger) + + def run_example( self, n_samples: int, file_name: str = None, simple: bool = True, - try_predefined: bool = False + try_predefined: bool = False, ): """ Generate example input samples for the model. @@ -1387,8 +1890,9 @@ def run_exampe( n_samples=n_samples, file_name=file_name, simple=simple, - try_predefined=try_predefined + try_predefined=try_predefined, ) + @throw_ersilia_exception() def run_bash(self): """ @@ -1399,30 +1903,54 @@ def run_bash(self): RuntimeError If there is an error during the subprocess execution or output comparison. 
""" + def compute_rmse(y_true, y_pred): - return sum((yt - yp) ** 2 for yt, yp in zip(y_true, y_pred)) ** 0.5 / len(y_true) + return sum((yt - yp) ** 2 for yt, yp in zip(y_true, y_pred)) ** 0.5 / len( + y_true + ) def compare_outputs(bsh_data, ers_data): - columns = set(bsh_data[0].keys()) & set(data[0].keys()) - self.logger.debug(f"Common columns: {columns}") - + _completed_status, _rmse = [], [] + columns = set(bsh_data[0].keys()) & set(ers_data[0].keys()) for column in columns: bv = [row[column] for row in bsh_data] ev = [row[column] for row in ers_data] if all(isinstance(val, (int, float)) for val in bv + ev): rmse = compute_rmse(bv, ev) - self.logger.debug(f"RMSE for {column}: {rmse}") + _rmse.append(rmse) if rmse > 0.1: + rmse_perc = round(rmse * 100, 2) + _completed_status.append( + ( + f"RMSE-{column}", + f"RMSE > 10%{rmse_perc}%", + str(STATUS_CONFIGS.FAILED), + ) + ) raise texc.InconsistentOutputs(self.model_id) + elif all(isinstance(val, str) for val in bv + ev): if not all( self._compare_string_similarity(a, b, 95) - for a, b - in zip(bv, ev) + for a, b in zip(bv, ev) ): + _completed_status.append( + ("String Similarity", "< 95%", str(STATUS_CONFIGS.FAILED)) + ) raise texc.InconsistentOutputs(self.model_id) + _completed_status.append( + ("String Similarity", "> 95%", str(STATUS_CONFIGS.PASSED)) + ) + + rmse = sum(_rmse) / len(_rmse) if _rmse else 0 + rmse_perc = round(rmse * 100, 2) + _completed_status.append( + ("RMSE-MEAN", f"RMSE < 10% | {rmse_perc}%", str(STATUS_CONFIGS.PASSED)) + ) + + return _completed_status def read_csv(path, flag=False): try: @@ -1431,7 +1959,7 @@ def read_csv(path, flag=False): if not lines: self.logger.error(f"File at {path} is empty.") - return [] + return [], "File is empty" headers = lines[0].strip().split(",") if flag: @@ -1444,25 +1972,37 @@ def read_csv(path, flag=False): values = line.strip().split(",") values = values[2:] if flag else values - def infer_type(value): + def is_invalid_value(value): try: - return int(value) + int(value) + return False except ValueError: - try: - return float(value) - except ValueError: - return value + pass - _values = [infer_type(x) for x in values] + try: + float(value) + return False + except ValueError: + pass + + if isinstance(value, str): + return False + + return True + + try: + _values = [ + x if not is_invalid_value(x) else None for x in values + ] + except ValueError as e: + return [], f"Invalid value detected in CSV file: {e}" data.append(dict(zip(headers, _values))) - return data - except Exception as e: - raise RuntimeError( - f"Failed to read CSV from {path}." - ) from e + return data, None + except Exception as e: + raise RuntimeError(f"Failed to read CSV from {path}.") from e def run_subprocess(command, env_vars=None): try: @@ -1473,42 +2013,28 @@ def run_subprocess(command, env_vars=None): check=True, env=env_vars, ) - self.logger.debug( - f"Subprocess output: {result.stdout}" - ) + self.logger.debug(f"Subprocess output: {result.stdout}") return result.stdout except subprocess.CalledProcessError as e: - raise RuntimeError( - "Subprocess execution failed." 
- ) from e + raise RuntimeError("Subprocess execution failed.") from e with tempfile.TemporaryDirectory() as temp_dir: - - model_path = os.path.join(self.dir) + model_path = os.path.join(self.dir) + output_path = os.path.join(temp_dir, "ersilia_output.csv") + output_log_path = os.path.join(temp_dir, "output.txt") + error_log_path = os.path.join(temp_dir, "error.txt") + input_file_path = os.path.join(temp_dir, "example_file.csv") temp_script_path = os.path.join(temp_dir, "script.sh") bash_output_path = os.path.join(temp_dir, "bash_output.csv") - output_path = os.path.join(temp_dir, "ersilia_output.csv") - output_log_path = os.path.join(temp_dir, "output.txt") - error_log_path = os.path.join(temp_dir, "error.txt") - input = self.run_exampe( + self.run_example( n_samples=Options.NUM_SAMPLES.value, - file_name=None, + file_name=input_file_path, simple=True, - try_predefined=False + try_predefined=False, ) - ex_file = os.path.join(temp_dir, "example_file.csv") - - with open(ex_file, "w") as f: - f.write("smiles\n" + "\n".join(map(str, input))) - - run_sh_path = os.path.join( - model_path, - "model", - "framework", - RUN_FILE - ) + run_sh_path = os.path.join(model_path, "model", "framework", RUN_FILE) if not os.path.exists(run_sh_path): self.logger.warning( f"{RUN_FILE} not found at {run_sh_path}. Skipping bash run." @@ -1516,10 +2042,10 @@ def run_subprocess(command, env_vars=None): return bash_script = f""" - source {self.conda_prefix(self.is_base())}/etc/profile.d/conda.sh + source {self._conda_prefix(self._is_base())}/etc/profile.d/conda.sh conda activate {self.model_id} cd {os.path.dirname(run_sh_path)} - bash run.sh . {ex_file} {bash_output_path} > {output_log_path} 2> {error_log_path} + bash run.sh . {input_file_path} {bash_output_path} > {output_log_path} 2> {error_log_path} conda deactivate """ @@ -1527,61 +2053,47 @@ def run_subprocess(command, env_vars=None): script_file.write(bash_script) self.logger.debug(f"Running bash script: {temp_script_path}") - run_subprocess(["bash", temp_script_path]) - - bsh_data = read_csv(bash_output_path) - self.logger.info(f"Bash Data:{bsh_data}") + out = run_subprocess(["bash", temp_script_path]) + self.logger.info(f"Bash script subprocess output: {out}") + bsh_data, _ = read_csv(bash_output_path) self.logger.debug("Serving the model after run.sh") run_subprocess( - ["ersilia", "-v", - "serve", self.model_id, + [ + "ersilia", + "-v", + "serve", + self.model_id, ] ) - self.logger.debug( - "Running model for bash data consistency checking" - ) + self.logger.debug("Running model for bash data consistency checking") run_subprocess( - ["ersilia", "-v", - "run", - "-i", ex_file, - "-o", output_path - ] + ["ersilia", "-v", "run", "-i", input_file_path, "-o", output_path] ) - data = read_csv(output_path, flag=True) - - compare_outputs(bsh_data, data) - - self.run_using_bash = True + ers_data, _ = read_csv(output_path, flag=True) + check_status = self.checkup_service._check_csv( + output_path, input_type="csv" + ) + if check_status[-1] == str(STATUS_CONFIGS.FAILED): + return [ + ( + ( + Checks.RUN_BASH.value, + check_status[1], + str(STATUS_CONFIGS.FAILED), + ) + ) + ] + status = compare_outputs(bsh_data, ers_data) + return status @staticmethod - def default_env(): - """ - Get the default environment. - - Returns - ------- - str - The default environment. - """ + def _default_env(): if "CONDA_DEFAULT_ENV" in os.environ: return os.environ["CONDA_DEFAULT_ENV"] return None @staticmethod - def conda_prefix(is_base): - """ - Get the conda prefix. 
- - Parameters - ---------- - is_base : bool - Whether it is the base environment. - - Returns - ------- - str - The conda prefix. - """ + def _conda_prefix(is_base): o = run_command_check_output("which conda").rstrip() if o: o = os.path.abspath(os.path.join(o, "..", "..")) @@ -1593,268 +2105,254 @@ def conda_prefix(is_base): o = run_command_check_output("echo $CONDA_PREFIX_1").rstrip() return o - def is_base(self): - """ - Check if the current environment is the base environment. - - Returns - ------- - bool - True if it is the base environment, False otherwise. - """ - default_env = self.default_env() + def _is_base(self): + default_env = self._default_env() self.logger.debug(f"Default environment: {default_env}") return default_env == "base" - def _compare_string_similarity( - self, - str1, - str2, - threshold - ): + def _compare_string_similarity(self, str1, str2, threshold): similarity = fuzz.ratio(str1, str2) return similarity >= threshold - - def make_output(self, elapsed_time: float): - """ - Generate the final output table with the test results. - - Parameters - ---------- - elapsed_time : float - Time elapsed during the test run. - """ - results = TestResult.generate_results( - self.checkup_service, - elapsed_time, - self.run_using_bash - ) - data = [(key, str(value)) for key, value in results.items()] - self.ios_service._generate_table( - **TABLE_CONFIGS[TableType.FINAL_RUN_SUMMARY].__dict__, - rows=data - ) - - def run(self, output_file: str = Options.OUTPUT_CSV.value): + def run(self): """ Run the model tests and checks. - Parameters - ---------- - output_file : str, optional - Path to the output file for storing the test results. - Raises ------ ImportError If required packages are missing. """ - if not output_file: - output_file = os.path.join( - self._model_path(self.model_id), - Options.OUTPUT_CSV.value + results = [] + try: + if self.from_dockerhub: + self.setup_service.from_github = True + + self.setup_service.get_model() + basic_results = self._perform_basic_checks() + results.extend(basic_results) + if self.shallow: + shallow_results = self._perform_shallow_checks() + results.extend(shallow_results) + + if self.deep: + shallow_results = self._perform_shallow_checks() + deep_results = self._perform_deep_checks() + results.extend(shallow_results) + results.append(deep_results) + + except Exception as error: + tb = traceback.format_exc() + exp = { + "exception": str(error), + "traceback": tb, + } + results.append(exp) + echo(f"An error occurred: {error}\nTraceback:\n{tb}") + + finally: + self.ios_service.update_metadata(results) + if self.as_json: + self.ios_service.collect_and_save_json(results, self.report_file) + echo("Run process completed.") + + def _perform_basic_checks(self): + results = [] + + self.checkup_service.check_information() + results.append( + self._generate_table_from_check( + TableType.MODEL_INFORMATION_CHECKS, self.ios_service.check_results ) + ) - start_time = time.time() + self.ios_service.check_results.clear() - try: - self._perform_checks(output_file) - self._log_directory_sizes() - self._perform_inspect() - if self.level == Options.LEVEL_DEEP.value: - self._perform_deep_checks(output_file) + self.checkup_service.check_files() + results.append( + self._generate_table_from_check( + TableType.MODEL_FILE_CHECKS, self.ios_service.check_results + ) + ) - elapsed_time = time.time() - start_time - self.make_output(elapsed_time) - self._clear_folders() + results.append(self._log_directory_sizes()) + results.append(self._docker_yml_check()) - except 
Exception as e: - click.echo( - f"An error occurred: {e}" - ) - finally: - click.echo( - "Run process finished successfully." - ) + return results - def transform_key(self, value): - """ - Transform the test result key to a string representation. + @show_loader(text="Performing shallow checks", color="cyan") + def _perform_shallow_checks(self): + self.fetch() + model_output = self.checkup_service.check_model_output_content( + self.run_example, self.run_model + ) + results, _validations = [], [] - Parameters - ---------- - value : any - The test result value. + if self.from_github or self.from_s3: + _validations.extend(self._log_env_sizes()) - Returns - ------- - str - The string representation of the test result. - """ - if value is True: - return str(STATUS_CONFIGS.PASSED) - elif value is False: - return str(STATUS_CONFIGS.FAILED) - return value - - def _perform_inspect(self): - if self.inspect: - out = self.inspecter.run() - out = { - " ".join(word.capitalize() - for word in k.split("_")): self.transform_key(v) - for k, v in out.items() - } + if self.from_dockerhub: + message, docker_size = self.ios_service.calculate_image_size( + tag=self.version if self.version else "latest" + ) + _validations.append((Checks.IMAGE_SIZE.value, message, docker_size)) - data = [(key, value) for key, value in out.items()] + _validations.extend(self._run_single_and_example_input_checks()) - self.ios_service._generate_table( - **TABLE_CONFIGS[TableType.INSPECT_SUMMARY].__dict__, - rows=data, - large_table=True, - merge=True + results.append( + self._generate_table_from_check( + TableType.SHALLOW_CHECK_SUMMARY, _validations, large=True ) + ) - def _perform_checks(self, output_file): - self.checkup_service.check_information(output_file) - self._generate_table( - **TABLE_CONFIGS[TableType.MODEL_INFORMATION_CHECKS].__dict__, - rows=self.ios_service.check_results + bash_results = self.run_bash() + results.append( + self._generate_table_from_check(TableType.CONSISTENCY_BASH, bash_results) ) - self.ios_service.check_results.clear() - self.checkup_service.check_files() - self._generate_table( - **TABLE_CONFIGS[TableType.MODEL_FILE_CHECKS].__dict__, - rows=self.ios_service.check_results + results.append( + self._generate_table_from_check(TableType.MODEL_OUTPUT, model_output) ) - def _log_directory_sizes(self): - dir_size, env_size = self.ios_service.get_directories_sizes() - self._generate_table( - **TABLE_CONFIGS[TableType.MODEL_DIRECTORY_SIZES].__dict__, - rows=[(f"{dir_size:.2f}MB", f"{env_size:.2f}MB")] + return results + + @show_loader(text="Performing deep checks", color="cyan") + def _perform_deep_checks(self): + performance_data = self.inspector.run(["computational_performance_tracking"]) + self.logger.info(f"Performance data: {performance_data}") + return self._generate_table_from_check( + TableType.COMPUTATIONAL_PERFORMANCE, performance_data, large=True ) - def _perform_deep_checks(self, output_file): - self.checkup_service.check_single_input( - output_file, - self.run_model, - self.run_exampe + def _docker_yml_check(self): + docker_check_data = self.inspector.run(["docker_check"]) + return self._generate_table_from_check( + TableType.DEPENDECY_CHECK, docker_check_data, large=True ) - self._generate_table( - **TABLE_CONFIGS[TableType.RUNNER_CHECKUP_STATUS].__dict__, - rows = [ - ["Fetch", str(STATUS_CONFIGS.PASSED)], - ["Serve", str(STATUS_CONFIGS.PASSED)], - ["Run", str(STATUS_CONFIGS.PASSED)] - ] + + def _log_env_sizes(self): + env_size = self.ios_service.get_env_sizes() + return [(Checks.ENV_SIZE.value, 
Checks.SIZE_CACL_SUCCESS.value, env_size)] + + def _log_directory_sizes(self): + directory_size = self.ios_service.get_directories_sizes() + return self._generate_table_from_check( + TableType.MODEL_DIRECTORY_SIZES, [(Checks.DIR_SIZE.value, directory_size)] ) - self.checkup_service.check_example_input( - output_file, - self.run_model, - self.run_exampe + + def _run_single_and_example_input_checks(self): + results = [] + results.extend( + self.checkup_service.check_example_input(self.run_model, self.run_example) ) - self.checkup_service.check_consistent_output( - self.run_exampe, - self.run_model + results.extend( + self.checkup_service.check_consistent_output( + self.run_example, self.run_model + ) ) - self.run_bash() + return results - def _generate_table(self, title, headers, rows): - self.ios_service._generate_table( - title=title, - headers=headers, - rows=rows + def _generate_table_from_check(self, table_type, rows, large=False): + config = TABLE_CONFIGS[table_type].__dict__ + return self.ios_service._generate_table( + title=config["title"], + headers=config["headers"], + rows=rows, + large_table=large, ) - def _clear_folders(self): - if self.remove: - SetupService.run_command( - f"rm -rf {self.dir}", - logger=self.logger, - ) + class ModelTester(ErsiliaBase): """ Class to handle model testing. Initializes the model tester services and runs the tests. + Parameters ---------- - model_id : str + model : str The ID of the model. level : str The level of testing. - dir : str + from_dir : str The directory for the model. - inspect : bool - Whether to inspect the model. - remote : bool - Whether to fetch the model from a remote source. - remove : bool - Whether to remove the model after testing. + from_github : bool + Flag indicating whether to fetch the repository from GitHub. + from_dockerhub : bool + Flag indicating whether to fetch the repository from DockerHub. + from_s3 : bool + Flag indicating whether to fetch the repository from S3. + version : str + Version of the model. + shallow : bool + Flag indicating whether to perform shallow checks. + deep : bool + Flag indicating whether to perform deep checks. + as_json : bool + Flag indicating whether to output results as JSON. 
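+
+    Examples
+    --------
+    A minimal sketch of programmatic use (the CLI normally constructs this
+    object; ``eosxxxx`` is a placeholder model ID and the flag combination
+    below is illustrative):
+
+    .. code-block:: python
+
+        mt = ModelTester(
+            model="eosxxxx",
+            level=None,
+            from_dir=None,
+            from_github=True,
+            from_dockerhub=False,
+            from_s3=False,
+            version=None,
+            shallow=True,
+            deep=False,
+            as_json=True,
+        )
+        mt.run()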
""" + def __init__( - self, - model_id, - level, - dir, - inspect, - remote, - remove - ): - ErsiliaBase.__init__( - self, - config_json=None, - credentials_json=None - ) - self.model_id = model_id + self, + model, + level, + from_dir, + from_github, + from_dockerhub, + from_s3, + version, + shallow, + deep, + as_json, + ): + ErsiliaBase.__init__(self, config_json=None, credentials_json=None) + self.model_id = model self.level = level - self.dir = dir or self.model_id - self.inspect = inspect - self.remote = remote - self.remove = remove + self.from_dir = from_dir + self.model_dir = os.path.join(EOS_TMP, self.model_id) + self.dir = from_dir or self.model_dir + self.from_github = from_github + self.from_dockerhub = from_dockerhub + self.from_s3 = from_s3 + self.version = version + self.shallow = shallow + self.deep = deep + self.as_json = as_json self._check_pedendency() self.setup_service = SetupService( self.model_id, self.dir, + self.from_github, + self.from_s3, self.logger, - self.remote - ) - self.ios = IOService( - self.logger, - self._dest_dir, - self._model_path, - self._get_bundle_location, - self._get_bentoml_location, - self.model_id, - self.dir ) + self.ios = IOService(self.logger, self.model_id, self.dir) self.checks = CheckService( self.logger, self.model_id, - self._dest_dir, self.dir, + self.from_github, + self.from_s3, self.ios, ) - self.inspecter = InspectService( - dir=self.dir if not self.remote else None, - model=self.model_id, - remote=self.remote - ) + self.inspector = InspectService(dir=self.dir, model=self.model_id, remote=True) self.runner = RunnerService( self.model_id, self.logger, self.ios, self.checks, self.setup_service, - self._model_path, self.level, self.dir, - self.remote, - self.inspect, - self.remove, - self.inspecter + self._model_path, + self.from_github, + self.from_s3, + self.from_dockerhub, + self.version, + self.shallow, + self.deep, + self.as_json, + self.inspector, ) + def _check_pedendency(self): if MISSING_PACKAGES: raise ImportError( @@ -1862,23 +2360,8 @@ def _check_pedendency(self): "Please install test extras with 'pip install ersilia[test]'." ) - def setup(self): - """ - Set up the model tester. - """ - self.logger.debug(f"Running conda setup for {self.model_id}") - self.setup_service.fetch_repo() # for remote option - self.logger.debug(f"Fetching model {self.model_id} from local dir: {self.dir}") - self.runner.fetch() - self.setup_service.check_conda_env() - - def run(self, output_file=None): + def run(self): """ Run the model tester. - - Parameters - ---------- - output_file : str, optional - The output file. """ - self.runner.run(output_file) + self.runner.run() diff --git a/ersilia/utils/spinner.py b/ersilia/utils/spinner.py new file mode 100644 index 000000000..4776e6918 --- /dev/null +++ b/ersilia/utils/spinner.py @@ -0,0 +1,104 @@ +import itertools +import sys +import threading +import time + +try: + import emoji +except ImportError: + emoji = None +import os + +if not os.environ.get("PYTHONIOENCODING"): + os.environ["PYTHONIOENCODING"] = "UTF-8" + sys.stdout.reconfigure(encoding="utf-8") + + +class Spinner: + """ + A colorful and animated loader for terminal applications. 
+ """ + + def __init__(self, text="Loading...", spinner=None, color="cyan"): + self.text = text + self.spinner = spinner or ["⠋", "⠙", "⠸", "⠴", "⠦", "⠇"] + self.spinner = itertools.cycle(self.spinner) + self.color = self._get_color_code(color) + self.text_color = "\033[37m" + self.reset_color = "\033[0m" + self.running = False + self.thread = None + self.lock = threading.Lock() + self.is_paused = False + + def _get_color_code(self, color): + colors = { + "black": "\033[30m", + "red": "\033[31m", + "green": "\033[32m", + "yellow": "\033[33m", + "blue": "\033[34m", + "magenta": "\033[35m", + "cyan": "\033[36m", + "white": "\033[37m", + } + return colors.get(color.lower(), "\033[37m") + + def _start(self): + self.running = True + self.thread = threading.Thread(target=self._animate, daemon=True) + self.thread.start() + + def _stop(self, success=True): + with self.lock: + self.running = False + if self.thread is not None: + self.thread.join() + sys.stdout.write("\r") + sys.stdout.flush() + status_icon = emoji.emojize("✅") if success else emoji.emojize("❌") + sys.stdout.write( + f"{self.color}{status_icon}{self.reset_color} {self.text_color}{self.text} Done!{self.reset_color}\n" + ) + sys.stdout.flush() + + def _pause(self): + """Pause the spinner for blocking operations like prompts.""" + with self.lock: + self.is_paused = True + # Clear the spinner line before pausing + sys.stdout.write("\r") + sys.stdout.flush() + + def _resume(self): + """Resume the spinner after a pause.""" + with self.lock: + self.is_paused = False + + def _animate(self): + while self.running: + if not self.is_paused: + with self.lock: + sys.stdout.write( + f"\r{self.color}{next(self.spinner)}{self.reset_color} {self.text_color}{self.text}{self.reset_color} " + ) + sys.stdout.flush() + time.sleep(0.1) + + +def show_loader(text="Loading...", color="cyan"): + def decorator(func): + def wrapper(*args, **kwargs): + loader = Spinner(text=text, color=color) + loader._start() + try: + result = func(*args, **kwargs) + loader._stop(success=True) + return result + except Exception as e: + loader._stop(success=False) + raise e + + return wrapper + + return decorator