Skip to content

Commit

Permalink
address comment
Browse files Browse the repository at this point in the history
  • Loading branch information
sixianyi0721 committed Jan 16, 2025
1 parent 11b2fdd commit 33de00c
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 94 deletions.
17 changes: 11 additions & 6 deletions llama_stack/providers/tests/agents/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,19 @@

import pytest

from ..conftest import get_provider_fixture_overrides
from ..inference.fixtures import INFERENCE_FIXTURES
from ..memory.fixtures import MEMORY_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
from ..test_config_helper import (
from ..conftest import (
get_provider_fixture_overrides,
get_provider_fixtures_from_config,
try_load_config_file_cached,
)
from ..inference.fixtures import INFERENCE_FIXTURES
from ..memory.fixtures import MEMORY_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield

# from ..test_config_helper import (
# get_provider_fixtures_from_config,
# try_load_config_file_cached,
# )
from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
from .fixtures import AGENTS_FIXTURES

Expand Down Expand Up @@ -86,7 +91,7 @@ def pytest_configure(config):


def pytest_generate_tests(metafunc):
test_config = try_load_config_file_cached(metafunc.config.getoption("config"))
test_config = try_load_config_file_cached(metafunc.config)
(
config_override_inference_models,
config_override_safety_shield,
Expand Down
70 changes: 66 additions & 4 deletions llama_stack/providers/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,90 @@

import os
from collections import defaultdict

from pathlib import Path
from typing import Any, Dict, List, Optional

import pytest
import yaml

from dotenv import load_dotenv
from pydantic import BaseModel
from pydantic import BaseModel, Field
from termcolor import colored

from llama_stack.distribution.datatypes import Provider
from llama_stack.providers.datatypes import RemoteProviderConfig

from .env import get_env_or_fail

from .test_config_helper import try_load_config_file_cached


class ProviderFixture(BaseModel):
providers: List[Provider]
provider_data: Optional[Dict[str, Any]] = None


class Fixtures(BaseModel):
# provider fixtures can be either a mark or a dictionary of api -> providers
provider_fixtures: List[Dict[str, str]] = Field(default_factory=list)
inference_models: List[str] = Field(default_factory=list)
safety_shield: Optional[str] = Field(default_factory=None)
embedding_model: Optional[str] = Field(default_factory=None)


class APITestConfig(BaseModel):
fixtures: Fixtures

# test name format should be <relative_path.py>::<test_name>
tests: List[str] = Field(default_factory=list)


class TestConfig(BaseModel):
inference: APITestConfig
agent: Optional[APITestConfig] = Field(default=None)
memory: Optional[APITestConfig] = Field(default=None)


CONFIG_CACHE = None


def try_load_config_file_cached(config):
config_file = config.getoption("--config")
if config_file is None:
return None
if CONFIG_CACHE is not None:
return CONFIG_CACHE

config_file_path = Path(__file__).parent / config_file
if not config_file_path.exists():
raise ValueError(
f"Test config {config_file} was specified but not found. Please make sure it exists in the llama_stack/providers/tests directory."
)
with open(config_file_path, "r") as config_file:
config = yaml.safe_load(config_file)
return TestConfig(**config)


def get_provider_fixtures_from_config(
provider_fixtures_config, default_fixture_combination
):
custom_fixtures = []
selected_default_param_id = set()
for fixture_config in provider_fixtures_config:
if "default_fixture_param_id" in fixture_config:
selected_default_param_id.add(fixture_config["default_fixture_param_id"])
else:
custom_fixtures.append(
pytest.param(fixture_config, id=fixture_config.get("inference") or "")
)

if len(selected_default_param_id) > 0:
for default_fixture in default_fixture_combination:
if default_fixture.id in selected_default_param_id:
custom_fixtures.append(default_fixture)

return custom_fixtures


def remote_stack_fixture() -> ProviderFixture:
if url := os.getenv("REMOTE_STACK_URL", None):
config = RemoteProviderConfig.from_url(url)
Expand Down Expand Up @@ -182,7 +244,7 @@ def pytest_itemcollected(item):


def pytest_collection_modifyitems(session, config, items):
test_config = try_load_config_file_cached(config.getoption("--config"))
test_config = try_load_config_file_cached(config)
if test_config is None:
return

Expand Down
5 changes: 2 additions & 3 deletions llama_stack/providers/tests/inference/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@

import pytest

from ..conftest import get_provider_fixture_overrides
from ..test_config_helper import try_load_config_file_cached
from ..conftest import get_provider_fixture_overrides, try_load_config_file_cached
from .fixtures import INFERENCE_FIXTURES


Expand Down Expand Up @@ -43,7 +42,7 @@ def pytest_configure(config):


def pytest_generate_tests(metafunc):
test_config = try_load_config_file_cached(metafunc.config.getoption("config"))
test_config = try_load_config_file_cached(metafunc.config)
if "inference_model" in metafunc.fixturenames:
cls_name = metafunc.cls.__name__
if test_config is not None:
Expand Down
10 changes: 5 additions & 5 deletions llama_stack/providers/tests/memory/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@

import pytest

from ..conftest import get_provider_fixture_overrides

from ..inference.fixtures import INFERENCE_FIXTURES
from ..test_config_helper import (
from ..conftest import (
get_provider_fixture_overrides,
get_provider_fixtures_from_config,
try_load_config_file_cached,
)

from ..inference.fixtures import INFERENCE_FIXTURES
from .fixtures import MEMORY_FIXTURES


Expand Down Expand Up @@ -69,7 +69,7 @@ def pytest_configure(config):


def pytest_generate_tests(metafunc):
test_config = try_load_config_file_cached(metafunc.config.getoption("config"))
test_config = try_load_config_file_cached(metafunc.config)
provider_fixtures_config = (
test_config.memory.fixtures.provider_fixtures
if test_config is not None and test_config.memory is not None
Expand Down
76 changes: 0 additions & 76 deletions llama_stack/providers/tests/test_config_helper.py

This file was deleted.

0 comments on commit 33de00c

Please sign in to comment.