From d944e3c9c7c9fbea5e4b8ff67ba8b283ba7c68ff Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 31 Jul 2024 11:16:40 -0700 Subject: [PATCH] Await socket operations + some other minor cleanup --- vllm/entrypoints/openai/cli_args.py | 2 +- vllm/entrypoints/openai/rpc/client.py | 8 ++++---- vllm/entrypoints/openai/rpc/server.py | 17 +++++++---------- vllm/utils.py | 7 ++++--- 4 files changed, 16 insertions(+), 18 deletions(-) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index e637e20e16f5a..1facedac72ca8 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -138,7 +138,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--disable-frontend-multiprocessing", action="store_true", help="If specified, will run the OpenAI frontend server in the same " - "proecss as the model servinge engine.") + "process as the model serving engine.") parser = AsyncEngineArgs.add_cli_args(parser) diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index 1e8a98d6418f7..ea50338c1f2e0 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -40,12 +40,12 @@ async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE, socket.connect(self.path) # Ping RPC Server with request. - socket.send(pickle.dumps(request, pickle.HIGHEST_PROTOCOL)) + await socket.send(pickle.dumps(request, pickle.HIGHEST_PROTOCOL)) # Await acknowledgement from RPCServer. response = pickle.loads(await socket.recv()) - if (not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR): + if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR: socket.close() raise ValueError(error_message) @@ -80,7 +80,7 @@ async def get_model_config(self) -> ModelConfig: socket.connect(self.path) # Ping RPCServer with GET_MODEL_CONFIG request. - socket.send(pickle.dumps(RPCUtilityRequest.GET_MODEL_CONFIG)) + await socket.send(pickle.dumps(RPCUtilityRequest.GET_MODEL_CONFIG)) # Await the MODEL_CONFIG from the Server. model_config = pickle.loads(await socket.recv()) @@ -126,7 +126,7 @@ async def generate( socket.connect(self.path) # Send RPCGenerateRequest to the RPCServer. - socket.send_multipart([ + await socket.send_multipart([ pickle.dumps( RPCGenerateRequest( inputs=inputs, diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index 6385eaa1b226d..17439d1bef961 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -18,9 +18,6 @@ class RPCServer: - # TODO: check if opening all these sockets is an antipattern. - # Alternative, use a smaller number of sockets with conditioning on the - # data that is passed through the socket. def __init__(self, async_engine_args: AsyncEngineArgs, usage_context: UsageContext, port: int): # Initialize engine first. @@ -41,7 +38,7 @@ def cleanup(self): async def _send_success_message(self, identity): """Send message to client indicating an action was successful.""" - self.socket.send_multipart([ + await self.socket.send_multipart([ identity, pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL), ]) @@ -50,20 +47,20 @@ async def get_model_config(self, identity): """Send the ModelConfig """ model_config = await self.engine.get_model_config() - self.socket.send_multipart( + await self.socket.send_multipart( [identity, pickle.dumps(model_config, pickle.HIGHEST_PROTOCOL)]) async def do_log_stats(self, identity): await self.engine.do_log_stats() - self.socket.send_multipart([ + await self.socket.send_multipart([ identity, pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL), ]) async def is_server_ready(self, identity): - self.socket.send_multipart([ + await self.socket.send_multipart([ identity, pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL), ]) @@ -73,7 +70,7 @@ async def abort(self, identity, request: RPCAbortRequest): await self.engine.abort(request.request_id) # Send confirmation to the client. - self.socket.send_multipart([ + await self.socket.send_multipart([ identity, pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL), ]) @@ -86,14 +83,14 @@ async def generate(self, identity, generate_request: RPCGenerateRequest): request_id=generate_request.request_id) async for request_output in results_generator: - self.socket.send_multipart([ + await self.socket.send_multipart([ identity, pickle.dumps(request_output, pickle.HIGHEST_PROTOCOL) ]) except Exception as e: ### Notify client of all failures - self.socket.send_multipart( + await self.socket.send_multipart( [identity, pickle.dumps(e, pickle.HIGHEST_PROTOCOL)]) def _make_handler_coro(self, identity, diff --git a/vllm/utils.py b/vllm/utils.py index 59ebab1eb3809..b18c3f3e81e65 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -302,7 +302,7 @@ def merge_async_iterators( queue: asyncio.Queue[Union[Tuple[int, T], ProducerFinished, Exception]] = asyncio.Queue() - finished = [False] * len(iterators) + producers = len(iterators) async def producer(i: int, iterator: AsyncIterator[T]): try: @@ -310,7 +310,6 @@ async def producer(i: int, iterator: AsyncIterator[T]): await queue.put((i, item)) except Exception as e: await queue.put(e) - finished[i] = True # Signal to the consumer that we've finished await queue.put(ProducerFinished()) @@ -320,13 +319,15 @@ async def producer(i: int, iterator: AsyncIterator[T]): ] async def consumer(): + remaining = producers try: - while not all(finished) or not queue.empty(): + while remaining or not queue.empty(): # we think there is a race condition here item = await queue.get() if isinstance(item, ProducerFinished): # Signal that a producer finished- not a real item + remaining -= 1 continue if isinstance(item, Exception):