diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index a05e08cd7c..78d76e9777 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -55,7 +55,7 @@ class ProviderWithSpec(Provider): # TODO: this code is not very straightforward to follow and needs one more round of refactoring -async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]: +async def resolve_impls(run_config: StackRunConfig) -> Dict[Api, Any]: """ Does two things: - flatmaps, sorts and resolves the providers in dependency order diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index eba89e3939..6154432b6e 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -37,7 +37,7 @@ from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.request_headers import set_request_provider_data -from llama_stack.distribution.resolver import resolve_impls_with_routing +from llama_stack.distribution.resolver import resolve_impls from .endpoints import get_all_api_endpoints @@ -276,7 +276,7 @@ def main( app = FastAPI() - impls = asyncio.run(resolve_impls_with_routing(config)) + impls = asyncio.run(resolve_impls(config)) if Api.telemetry in impls: setup_logger(impls[Api.telemetry]) diff --git a/llama_stack/providers/tests/resolver.py b/llama_stack/providers/tests/resolver.py index fabb245e74..de672b6dc4 100644 --- a/llama_stack/providers/tests/resolver.py +++ b/llama_stack/providers/tests/resolver.py @@ -14,7 +14,7 @@ from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.request_headers import set_request_provider_data -from llama_stack.distribution.resolver import resolve_impls_with_routing +from llama_stack.distribution.resolver import resolve_impls async def resolve_impls_for_test(api: Api, deps: List[Api] = None): @@ -36,7 +36,7 @@ async def resolve_impls_for_test(api: Api, deps: List[Api] = None): providers=chosen, ) run_config = parse_and_maybe_upgrade_config(run_config) - impls = await resolve_impls_with_routing(run_config) + impls = await resolve_impls(run_config) if "provider_data" in config_dict: provider_id = chosen[api.value][0].provider_id