
restructure config
sixianyi0721 committed Jan 15, 2025
1 parent cf1a568 commit 697ce9c
Showing 8 changed files with 221 additions and 65 deletions.
2 changes: 1 addition & 1 deletion llama_stack/distribution/routers/routing_tables.py
@@ -237,7 +237,7 @@ async def register_model(
             provider_id = list(self.impls_by_provider_id.keys())[0]
         else:
             raise ValueError(
-                "No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
+                f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
             )
         if metadata is None:
             metadata = {}
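Note: the fix above adds the missing f-string prefix; without it, Python leaves the braces uninterpolated, so the error message would literally contain "{self.impls_by_provider_id.keys()}". For illustration:

    providers = {"ollama": None, "together": None}
    print("Available providers: {providers.keys()}")   # braces printed literally
    print(f"Available providers: {providers.keys()}")  # keys actually interpolated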
30 changes: 26 additions & 4 deletions llama_stack/providers/tests/agents/conftest.py
@@ -10,6 +10,10 @@
 from ..inference.fixtures import INFERENCE_FIXTURES
 from ..memory.fixtures import MEMORY_FIXTURES
 from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
+from ..test_config_helper import (
+    get_provider_fixtures_from_config,
+    try_load_config_file_cached,
+)
 from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
 from .fixtures import AGENTS_FIXTURES

@@ -82,16 +86,33 @@ def pytest_configure(config):


 def pytest_generate_tests(metafunc):
-    shield_id = metafunc.config.getoption("--safety-shield")
+    test_config = try_load_config_file_cached(metafunc.config.getoption("config"))
+    (
+        config_override_inference_models,
+        config_override_safety_shield,
+        custom_provider_fixtures,
+    ) = (None, None, None)
+    if test_config is not None:
+        config_override_inference_models = test_config.agent.fixtures.inference_models
+        config_override_safety_shield = test_config.agent.fixtures.safety_shield
+        custom_provider_fixtures = get_provider_fixtures_from_config(
+            test_config.agent.fixtures.provider_fixtures, DEFAULT_PROVIDER_COMBINATIONS
+        )
+
+    shield_id = config_override_safety_shield or metafunc.config.getoption(
+        "--safety-shield"
+    )
+    inference_model = config_override_inference_models or [
+        metafunc.config.getoption("--inference-model")
+    ]
     if "safety_shield" in metafunc.fixturenames:
         metafunc.parametrize(
             "safety_shield",
             [pytest.param(shield_id, id="")],
             indirect=True,
         )
     if "inference_model" in metafunc.fixturenames:
-        inference_model = metafunc.config.getoption("--inference-model")
-        models = set({inference_model})
+        models = set(inference_model)
         if safety_model := safety_model_from_shield(shield_id):
             models.add(safety_model)

@@ -109,7 +130,8 @@ def pytest_generate_tests(metafunc):
"tool_runtime": TOOL_RUNTIME_FIXTURES,
}
combinations = (
get_provider_fixture_overrides(metafunc.config, available_fixtures)
custom_provider_fixtures
or get_provider_fixture_overrides(metafunc.config, available_fixtures)
or DEFAULT_PROVIDER_COMBINATIONS
)
metafunc.parametrize("agents_stack", combinations, indirect=True)
83 changes: 59 additions & 24 deletions llama_stack/providers/tests/ci_test_config.yaml
@@ -1,24 +1,59 @@
-tests:
-- path: inference/test_vision_inference.py
-  functions:
-  - test_vision_chat_completion_streaming
-  - test_vision_chat_completion_non_streaming
-
-- path: inference/test_text_inference.py
-  functions:
-  - test_structured_output
-  - test_chat_completion_streaming
-  - test_chat_completion_non_streaming
-  - test_chat_completion_with_tool_calling
-  - test_chat_completion_with_tool_calling_streaming
-
-inference_fixtures:
-- ollama
-- fireworks
-- together
-- tgi
-- vllm_remote
-
-test_models:
-  text: meta-llama/Llama-3.1-8B-Instruct
-  vision: meta-llama/Llama-3.2-11B-Vision-Instruct
+inference:
+  tests:
+  - inference/test_vision_inference.py::test_vision_chat_completion_streaming
+  - inference/test_vision_inference.py::test_vision_chat_completion_non_streaming
+  - inference/test_text_inference.py::test_structured_output
+  - inference/test_text_inference.py::test_chat_completion_streaming
+  - inference/test_text_inference.py::test_chat_completion_non_streaming
+  - inference/test_text_inference.py::test_chat_completion_with_tool_calling
+  - inference/test_text_inference.py::test_chat_completion_with_tool_calling_streaming
+
+  fixtures:
+    provider_fixtures:
+    - inference: ollama
+    - default_fixture_param_id: fireworks
+    - inference: together
+    # - inference: tgi
+    # - inference: vllm_remote
+    inference_models:
+    - meta-llama/Llama-3.1-8B-Instruct
+    - meta-llama/Llama-3.2-11B-Vision-Instruct
+
+    safety_shield: ~
+    embedding_model: ~
+
+
+agent:
+  tests:
+  - agents/test_agents.py::test_agent_turns_with_safety
+  - agents/test_agents.py::test_rag_agent
+
+  fixtures:
+    provider_fixtures:
+    - default_fixture_param_id: ollama
+    - default_fixture_param_id: together
+    - default_fixture_param_id: fireworks
+
+    safety_shield: ~
+    embedding_model: ~
+
+    inference_models:
+    - meta-llama/Llama-3.2-1B-Instruct
+
+
+memory:
+  tests:
+  - memory/test_memory.py::test_query_documents
+
+  fixtures:
+    provider_fixtures:
+    - default_fixture_param_id: ollama
+    - inference: sentence_transformers
+      memory: faiss
+    - default_fixture_param_id: chroma
+
+    inference_models:
+    - meta-llama/Llama-3.2-1B-Instruct
+
+    safety_shield: ~
+    embedding_model: ~
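The rewritten config groups settings per API (inference, agent, memory), replaces the old path/functions pairs with pytest-style path::function ids, and nests models and shields under a per-API fixtures block. The parser for this schema lives in the new test_config_helper.py, whose diff is not shown on this page; a minimal sketch of pydantic models that would accept it (class names are guesses, not from the diff):

    from typing import Dict, List, Optional

    from pydantic import BaseModel


    class Fixtures(BaseModel):
        provider_fixtures: List[Dict[str, str]] = []
        inference_models: Optional[List[str]] = None
        safety_shield: Optional[str] = None
        embedding_model: Optional[str] = None


    class APITestConfig(BaseModel):
        tests: List[str]
        fixtures: Fixtures


    class TestConfig(BaseModel):
        inference: APITestConfig
        agent: APITestConfig
        memory: APITestConfig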
42 changes: 18 additions & 24 deletions llama_stack/providers/tests/conftest.py
@@ -5,12 +5,12 @@
 # the root directory of this source tree.

 import os
+from collections import defaultdict
 from pathlib import Path
 from typing import Any, Dict, List, Optional

 import pytest

-import yaml
 from dotenv import load_dotenv
 from pydantic import BaseModel
 from termcolor import colored
@@ -20,6 +20,8 @@

 from .env import get_env_or_fail

+from .test_config_helper import try_load_config_file_cached
+

 class ProviderFixture(BaseModel):
     providers: List[Provider]
@@ -180,34 +182,26 @@ def pytest_itemcollected(item):


 def pytest_collection_modifyitems(session, config, items):
-    if config.getoption("--config") is None:
+    test_config = try_load_config_file_cached(config.getoption("--config"))
+    if test_config is None:
         return
-    file_name = config.getoption("--config")
-    config_file_path = Path(__file__).parent / file_name
-    if not config_file_path.exists():
-        raise ValueError(
-            f"Test config {file_name} was specified but not found. Please make sure it exists in the llama_stack/providers/tests directory."
-        )

-    required_tests = dict()
-    inference_providers = set()
-    with open(config_file_path, "r") as config_file:
-        test_config = yaml.safe_load(config_file)
-        for test in test_config["tests"]:
-            required_tests[Path(__file__).parent / test["path"]] = set(
-                test["functions"]
-            )
-        inference_providers = set(test_config["inference_fixtures"])
+    required_tests = defaultdict(set)
+    test_configs = [test_config.inference, test_config.memory, test_config.agent]
+    for test_config in test_configs:
+        for test in test_config.tests:
+            arr = test.split("::")
+            if len(arr) != 2:
+                raise ValueError(f"Invalid format for test name {test}")
+            test_path, func_name = arr
+            required_tests[Path(__file__).parent / test_path].add(func_name)

     new_items, deselected_items = [], []
     for item in items:
         if item.fspath in required_tests:
-            func_name = getattr(item, "originalname", item.name)
-            if func_name in required_tests[item.fspath]:
-                inference = item.callspec.params.get("inference_stack")
-                if inference in inference_providers:
-                    new_items.append(item)
-                    continue
+            func_name = getattr(item, "originalname", item.name)
+            if func_name in required_tests[item.fspath]:
+                new_items.append(item)
+                continue
         deselected_items.append(item)

     items[:] = new_items
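The filtering now goes through try_load_config_file_cached, which replaces the inline YAML loading and, judging by its name, memoizes the parsed config so the per-suite pytest_generate_tests hooks do not re-read the file. A minimal sketch of such a loader, assuming the TestConfig model sketched above and reusing the old error message (the real implementation is in test_config_helper.py, not shown here):

    from functools import lru_cache
    from pathlib import Path

    import yaml


    @lru_cache(maxsize=None)
    def try_load_config_file_cached(config_file):
        if config_file is None:  # no --config flag given
            return None
        config_path = Path(__file__).parent / config_file
        if not config_path.exists():
            raise ValueError(
                f"Test config {config_file} was specified but not found. Please make sure"
                " it exists in the llama_stack/providers/tests directory."
            )
        with open(config_path, "r") as f:
            return TestConfig(**yaml.safe_load(f))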
36 changes: 25 additions & 11 deletions llama_stack/providers/tests/inference/conftest.py
@@ -7,7 +7,7 @@
import pytest

 from ..conftest import get_provider_fixture_overrides
-
+from ..test_config_helper import try_load_config_file_cached
 from .fixtures import INFERENCE_FIXTURES


@@ -43,29 +43,43 @@ def pytest_configure(config):


 def pytest_generate_tests(metafunc):
+    test_config = try_load_config_file_cached(metafunc.config.getoption("config"))
     if "inference_model" in metafunc.fixturenames:
-        model = metafunc.config.getoption("--inference-model")
-        if model:
-            params = [pytest.param(model, id="")]
+        cls_name = metafunc.cls.__name__
+        if test_config is not None:
+            params = []
+            for model in test_config.inference.fixtures.inference_models:
+                if ("Vision" in cls_name and "Vision" in model) or (
+                    "Vision" not in cls_name and "Vision" not in model
+                ):
+                    params.append(pytest.param(model, id=model))
         else:
-            cls_name = metafunc.cls.__name__
-            if "Vision" in cls_name:
-                params = VISION_MODEL_PARAMS
+            model = metafunc.config.getoption("--inference-model")
+            if model:
+                params = [pytest.param(model, id="")]
             else:
-                params = MODEL_PARAMS
+                if "Vision" in cls_name:
+                    params = VISION_MODEL_PARAMS
+                else:
+                    params = MODEL_PARAMS

         metafunc.parametrize(
             "inference_model",
             params,
             indirect=True,
         )
     if "inference_stack" in metafunc.fixturenames:
-        fixtures = INFERENCE_FIXTURES
-        if filtered_stacks := get_provider_fixture_overrides(
+        if test_config is not None:
+            fixtures = [
+                (f.get("inference") or f.get("default_fixture_param_id"))
+                for f in test_config.inference.fixtures.provider_fixtures
+            ]
+        elif filtered_stacks := get_provider_fixture_overrides(
             metafunc.config,
             {
                 "inference": INFERENCE_FIXTURES,
             },
         ):
             fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
+        else:
+            fixtures = INFERENCE_FIXTURES
         metafunc.parametrize("inference_stack", fixtures, indirect=True)
1 change: 1 addition & 0 deletions llama_stack/providers/tests/inference/fixtures.py
@@ -301,6 +301,7 @@ async def inference_stack(request, inference_model):
         inference_fixture.provider_data,
         models=[
             ModelInput(
+                provider_id=inference_fixture.providers[0].provider_id,
                 model_id=inference_model,
                 model_type=model_type,
                 metadata=metadata,
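Registering the fixture model with an explicit provider_id avoids the multi-provider ambiguity handled in routing_tables.py above: when several providers are configured, register_model would otherwise raise the (now correctly formatted) ValueError.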
16 changes: 15 additions & 1 deletion llama_stack/providers/tests/memory/conftest.py
@@ -9,6 +9,10 @@
 from ..conftest import get_provider_fixture_overrides

 from ..inference.fixtures import INFERENCE_FIXTURES
+from ..test_config_helper import (
+    get_provider_fixtures_from_config,
+    try_load_config_file_cached,
+)
 from .fixtures import MEMORY_FIXTURES


@@ -65,6 +69,15 @@ def pytest_configure(config):


 def pytest_generate_tests(metafunc):
+    test_config = try_load_config_file_cached(metafunc.config.getoption("config"))
+    provider_fixtures_config = (
+        test_config.memory.fixtures.provider_fixtures
+        if test_config is not None
+        else None
+    )
+    custom_fixtures = get_provider_fixtures_from_config(
+        provider_fixtures_config, DEFAULT_PROVIDER_COMBINATIONS
+    )
     if "embedding_model" in metafunc.fixturenames:
         model = metafunc.config.getoption("--embedding-model")
         if model:
@@ -80,7 +93,8 @@ def pytest_generate_tests(metafunc):
"memory": MEMORY_FIXTURES,
}
combinations = (
get_provider_fixture_overrides(metafunc.config, available_fixtures)
custom_fixtures
or get_provider_fixture_overrides(metafunc.config, available_fixtures)
or DEFAULT_PROVIDER_COMBINATIONS
)
metafunc.parametrize("memory_stack", combinations, indirect=True)
