diff --git a/llama_stack/providers/tests/agents/conftest.py b/llama_stack/providers/tests/agents/conftest.py index dcb6e0f3aa..c22d308948 100644 --- a/llama_stack/providers/tests/agents/conftest.py +++ b/llama_stack/providers/tests/agents/conftest.py @@ -6,14 +6,19 @@ import pytest -from ..conftest import get_provider_fixture_overrides -from ..inference.fixtures import INFERENCE_FIXTURES -from ..memory.fixtures import MEMORY_FIXTURES -from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield -from ..test_config_helper import ( +from ..conftest import ( + get_provider_fixture_overrides, get_provider_fixtures_from_config, try_load_config_file_cached, ) +from ..inference.fixtures import INFERENCE_FIXTURES +from ..memory.fixtures import MEMORY_FIXTURES +from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield + +# from ..test_config_helper import ( +# get_provider_fixtures_from_config, +# try_load_config_file_cached, +# ) from ..tools.fixtures import TOOL_RUNTIME_FIXTURES from .fixtures import AGENTS_FIXTURES @@ -86,7 +91,7 @@ def pytest_configure(config): def pytest_generate_tests(metafunc): - test_config = try_load_config_file_cached(metafunc.config.getoption("config")) + test_config = try_load_config_file_cached(metafunc.config) ( config_override_inference_models, config_override_safety_shield, diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py index c75dc67a59..cf87865be7 100644 --- a/llama_stack/providers/tests/conftest.py +++ b/llama_stack/providers/tests/conftest.py @@ -6,13 +6,15 @@ import os from collections import defaultdict + from pathlib import Path from typing import Any, Dict, List, Optional import pytest +import yaml from dotenv import load_dotenv -from pydantic import BaseModel +from pydantic import BaseModel, Field from termcolor import colored from llama_stack.distribution.datatypes import Provider @@ -20,14 +22,74 @@ from .env import get_env_or_fail -from .test_config_helper import try_load_config_file_cached - class ProviderFixture(BaseModel): providers: List[Provider] provider_data: Optional[Dict[str, Any]] = None +class Fixtures(BaseModel): + # provider fixtures can be either a mark or a dictionary of api -> providers + provider_fixtures: List[Dict[str, str]] = Field(default_factory=list) + inference_models: List[str] = Field(default_factory=list) + safety_shield: Optional[str] = Field(default_factory=None) + embedding_model: Optional[str] = Field(default_factory=None) + + +class APITestConfig(BaseModel): + fixtures: Fixtures + + # test name format should be :: + tests: List[str] = Field(default_factory=list) + + +class TestConfig(BaseModel): + inference: APITestConfig + agent: Optional[APITestConfig] = Field(default=None) + memory: Optional[APITestConfig] = Field(default=None) + + +CONFIG_CACHE = None + + +def try_load_config_file_cached(config): + config_file = config.getoption("--config") + if config_file is None: + return None + if CONFIG_CACHE is not None: + return CONFIG_CACHE + + config_file_path = Path(__file__).parent / config_file + if not config_file_path.exists(): + raise ValueError( + f"Test config {config_file} was specified but not found. Please make sure it exists in the llama_stack/providers/tests directory." + ) + with open(config_file_path, "r") as config_file: + config = yaml.safe_load(config_file) + return TestConfig(**config) + + +def get_provider_fixtures_from_config( + provider_fixtures_config, default_fixture_combination +): + custom_fixtures = [] + selected_default_param_id = set() + for fixture_config in provider_fixtures_config: + if "default_fixture_param_id" in fixture_config: + selected_default_param_id.add(fixture_config["default_fixture_param_id"]) + else: + custom_fixtures.append( + pytest.param(fixture_config, id=fixture_config.get("inference") or "") + ) + + if len(selected_default_param_id) > 0: + for default_fixture in default_fixture_combination: + if default_fixture.id in selected_default_param_id: + custom_fixtures.append(default_fixture) + + return custom_fixtures + + def remote_stack_fixture() -> ProviderFixture: if url := os.getenv("REMOTE_STACK_URL", None): config = RemoteProviderConfig.from_url(url) @@ -182,7 +244,7 @@ def pytest_itemcollected(item): def pytest_collection_modifyitems(session, config, items): - test_config = try_load_config_file_cached(config.getoption("--config")) + test_config = try_load_config_file_cached(config) if test_config is None: return diff --git a/llama_stack/providers/tests/inference/conftest.py b/llama_stack/providers/tests/inference/conftest.py index fca4f7544c..1343459e98 100644 --- a/llama_stack/providers/tests/inference/conftest.py +++ b/llama_stack/providers/tests/inference/conftest.py @@ -6,8 +6,7 @@ import pytest -from ..conftest import get_provider_fixture_overrides -from ..test_config_helper import try_load_config_file_cached +from ..conftest import get_provider_fixture_overrides, try_load_config_file_cached from .fixtures import INFERENCE_FIXTURES @@ -43,7 +42,7 @@ def pytest_configure(config): def pytest_generate_tests(metafunc): - test_config = try_load_config_file_cached(metafunc.config.getoption("config")) + test_config = try_load_config_file_cached(metafunc.config) if "inference_model" in metafunc.fixturenames: cls_name = metafunc.cls.__name__ if test_config is not None: diff --git a/llama_stack/providers/tests/memory/conftest.py b/llama_stack/providers/tests/memory/conftest.py index 72f9bfc551..99fdb7715d 100644 --- a/llama_stack/providers/tests/memory/conftest.py +++ b/llama_stack/providers/tests/memory/conftest.py @@ -6,13 +6,13 @@ import pytest -from ..conftest import get_provider_fixture_overrides - -from ..inference.fixtures import INFERENCE_FIXTURES -from ..test_config_helper import ( +from ..conftest import ( + get_provider_fixture_overrides, get_provider_fixtures_from_config, try_load_config_file_cached, ) + +from ..inference.fixtures import INFERENCE_FIXTURES from .fixtures import MEMORY_FIXTURES @@ -69,7 +69,7 @@ def pytest_configure(config): def pytest_generate_tests(metafunc): - test_config = try_load_config_file_cached(metafunc.config.getoption("config")) + test_config = try_load_config_file_cached(metafunc.config) provider_fixtures_config = ( test_config.memory.fixtures.provider_fixtures if test_config is not None and test_config.memory is not None diff --git a/llama_stack/providers/tests/test_config_helper.py b/llama_stack/providers/tests/test_config_helper.py deleted file mode 100644 index c86822f0e4..0000000000 --- a/llama_stack/providers/tests/test_config_helper.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from dataclasses import dataclass -from pathlib import Path -from typing import Dict, List, Optional - -import pytest -import yaml -from pydantic import BaseModel, Field - - -@dataclass -class APITestConfig(BaseModel): - - class Fixtures(BaseModel): - # provider fixtures can be either a mark or a dictionary of api -> providers - provider_fixtures: List[Dict[str, str]] = Field(default_factory=list) - inference_models: List[str] = Field(default_factory=list) - safety_shield: Optional[str] = Field(default_factory=None) - embedding_model: Optional[str] = Field(default_factory=None) - - fixtures: Fixtures - tests: List[str] = Field(default_factory=list) - - # test name format should be :: - - -class TestConfig(BaseModel): - - inference: APITestConfig - agent: Optional[APITestConfig] = Field(default=None) - memory: Optional[APITestConfig] = Field(default=None) - - -CONFIG_CACHE = None - - -def try_load_config_file_cached(config_file): - if config_file is None: - return None - if CONFIG_CACHE is not None: - return CONFIG_CACHE - - config_file_path = Path(__file__).parent / config_file - if not config_file_path.exists(): - raise ValueError( - f"Test config {config_file} was specified but not found. Please make sure it exists in the llama_stack/providers/tests directory." - ) - with open(config_file_path, "r") as config_file: - config = yaml.safe_load(config_file) - return TestConfig(**config) - - -def get_provider_fixtures_from_config( - provider_fixtures_config, default_fixture_combination -): - custom_fixtures = [] - selected_default_param_id = set() - for fixture_config in provider_fixtures_config: - if "default_fixture_param_id" in fixture_config: - selected_default_param_id.add(fixture_config["default_fixture_param_id"]) - else: - custom_fixtures.append( - pytest.param(fixture_config, id=fixture_config.get("inference") or "") - ) - - if len(selected_default_param_id) > 0: - for default_fixture in default_fixture_combination: - if default_fixture.id in selected_default_param_id: - custom_fixtures.append(default_fixture) - - return custom_fixtures