Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
sixianyi0721 committed Jan 15, 2025
1 parent 1ecc3a8 commit 01cfefe
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 5 deletions.
11 changes: 8 additions & 3 deletions llama_stack/providers/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
from llama_stack.providers.datatypes import RemoteProviderConfig

from .env import get_env_or_fail
from .report import Report

from .test_config_helper import try_load_config_file_cached
from .report import Report


class ProviderFixture(BaseModel):
Expand Down Expand Up @@ -64,8 +64,8 @@ def pytest_configure(config):
key, value = env_var.split("=", 1)
os.environ[key] = value

if config.getoption("--config") is not None:
config.pluginmanager.register(Report(config))
if config.getoption("--output") is not None:
config.pluginmanager.register(Report(config.getoption("--output")))


def pytest_addoption(parser):
Expand All @@ -82,6 +82,11 @@ def pytest_addoption(parser):
action="store",
help="Set test config file (supported format: YAML), e.g. --config=test_config.yml",
)
parser.addoption(
"--output",
action="store",
help="Set output file for test report, e.g. --output=pytest_report.md",
)
"""Add custom command line options"""
parser.addoption(
"--env", action="append", help="Set environment variables, e.g. --env KEY=value"
Expand Down
20 changes: 18 additions & 2 deletions llama_stack/providers/tests/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,18 @@

class Report:

def __init__(self, _config):
def __init__(self, output_path):

valid_file_format = (
output_path.split(".")[1] in ["md", "markdown"]
if len(output_path.split(".")) == 2
else False
)
if not valid_file_format:
raise ValueError(
f"Invalid output file {output_path}. Markdown file is required"
)
self.output_path = output_path
self.test_data = defaultdict(dict)
self.inference_tests = defaultdict(dict)

Expand Down Expand Up @@ -108,6 +119,11 @@ def pytest_sessionfinish(self, session):

rows = []
for model in all_registered_models():
if (
"Instruct" not in model.core_model_id.value
and "Guard" not in model.core_model_id.value
):
continue
row = f"| {model.core_model_id.value} |"
for k in SUPPORTED_MODELS.keys():
if model.core_model_id.value in SUPPORTED_MODELS[k]:
Expand Down Expand Up @@ -149,7 +165,7 @@ def pytest_sessionfinish(self, session):
report.extend(test_table)
report.append("\n")

output_file = Path("pytest_report.md")
output_file = Path(self.output_path)
output_file.write_text("\n".join(report))
print(f"\n Report generated: {output_file.absolute()}")

Expand Down

0 comments on commit 01cfefe

Please sign in to comment.