diff --git a/llama_stack/providers/impls/meta_reference/inference/config.py b/llama_stack/providers/impls/meta_reference/inference/config.py
index 901a8c7fbe..4e1161cedb 100644
--- a/llama_stack/providers/impls/meta_reference/inference/config.py
+++ b/llama_stack/providers/impls/meta_reference/inference/config.py
@@ -17,13 +17,18 @@ class MetaReferenceInferenceConfig(BaseModel):
     model: str = Field(
-        default="Llama3.1-8B-Instruct",
+        default="Llama3.2-3B-Instruct",
         description="Model descriptor from `llama model list`",
     )
     torch_seed: Optional[int] = None
     max_seq_len: int = 4096
     max_batch_size: int = 1
 
+    # when this is False, we assume that the distributed process group is setup by someone
+    # outside of this code (e.g., when run inside `torchrun`). that is useful for clients
+    # (including our testing code) who might be using llama-stack as a library.
+    create_distributed_process_group: bool = True
+
     @field_validator("model")
     @classmethod
     def validate_model(cls, model: str) -> str:
diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py
index 6696762c9a..7edc279d03 100644
--- a/llama_stack/providers/impls/meta_reference/inference/inference.py
+++ b/llama_stack/providers/impls/meta_reference/inference/inference.py
@@ -18,6 +18,7 @@
 )
 
 from .config import MetaReferenceInferenceConfig
+from .generation import Llama
 from .model_parallel import LlamaModelParallelGenerator
 
 # there's a single model parallel process running serving the model. for now,
@@ -36,8 +37,11 @@ def __init__(self, config: MetaReferenceInferenceConfig) -> None:
 
     async def initialize(self) -> None:
         print(f"Loading model `{self.model.descriptor()}`")
-        self.generator = LlamaModelParallelGenerator(self.config)
-        self.generator.start()
+        if self.config.create_distributed_process_group:
+            self.generator = LlamaModelParallelGenerator(self.config)
+            self.generator.start()
+        else:
+            self.generator = Llama.build(self.config)
 
     async def register_model(self, model: ModelDef) -> None:
         raise ValueError("Dynamic model registration is not supported")
@@ -51,7 +55,8 @@ async def list_models(self) -> List[ModelDef]:
         ]
 
     async def shutdown(self) -> None:
-        self.generator.stop()
+        if self.config.create_distributed_process_group:
+            self.generator.stop()
 
     def completion(
         self,
@@ -99,8 +104,9 @@ def chat_completion(
                 f"Model mismatch: {request.model} != {self.model.descriptor()}"
             )
 
-        if SEMAPHORE.locked():
-            raise RuntimeError("Only one concurrent request is supported")
+        if self.config.create_distributed_process_group:
+            if SEMAPHORE.locked():
+                raise RuntimeError("Only one concurrent request is supported")
 
         if request.stream:
             return self._stream_chat_completion(request)
@@ -110,7 +116,7 @@ def chat_completion(
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        async with SEMAPHORE:
+        def impl():
             messages = chat_completion_request_to_messages(request)
 
             tokens = []
@@ -154,10 +160,16 @@ async def _nonstream_chat_completion(
                 logprobs=logprobs if request.logprobs else None,
             )
 
+        if self.config.create_distributed_process_group:
+            async with SEMAPHORE:
+                return impl()
+        else:
+            return impl()
+
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
-        async with SEMAPHORE:
+        def impl():
             messages = chat_completion_request_to_messages(request)
 
             yield ChatCompletionResponseStreamChunk(
@@ -272,6 +284,14 @@ async def _stream_chat_completion(
                 )
             )
 
+        if self.config.create_distributed_process_group:
+            async with SEMAPHORE:
+                for x in impl():
+                    yield x
+        else:
+            for x in impl():
+                yield x
+
     async def embeddings(
         self,
         model: str,
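Usage note (illustrative, not part of the diff): a minimal sketch of how a client that already manages its own distributed process group, e.g. a test launched via `torchrun`, might exercise the new `create_distributed_process_group=False` path. The provider class name `MetaReferenceInferenceImpl` and the surrounding request flow are assumptions for illustration; only `MetaReferenceInferenceConfig`, the new field, and the `__init__(config)` / `initialize()` / `shutdown()` methods are visible in the hunks above.

# Launched under torchrun, which sets up the distributed process group before
# this script runs, e.g.:
#   torchrun --nproc-per-node=1 library_client_test.py
import asyncio

from llama_stack.providers.impls.meta_reference.inference.config import (
    MetaReferenceInferenceConfig,
)
from llama_stack.providers.impls.meta_reference.inference.inference import (
    MetaReferenceInferenceImpl,  # assumed class name; not shown in the hunks above
)


async def main() -> None:
    config = MetaReferenceInferenceConfig(
        model="Llama3.2-3B-Instruct",
        max_seq_len=4096,
        max_batch_size=1,
        # torchrun already created the process group, so initialize() should
        # call Llama.build() directly instead of starting LlamaModelParallelGenerator
        create_distributed_process_group=False,
    )
    impl = MetaReferenceInferenceImpl(config)
    await impl.initialize()
    # ... issue chat_completion() requests against `impl` here ...
    await impl.shutdown()


asyncio.run(main())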