Await socket operations + some other minor cleanup (#391)
njhill authored Jul 31, 2024
1 parent 5362952 commit 7214fb8
Showing 4 changed files with 16 additions and 18 deletions.
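The gist of the change: the RPC client and server talk over pyzmq's asyncio sockets, whose send() and send_multipart() methods return futures rather than sending synchronously. Calling them without await can silently discard a send error and gives no guarantee the message has been handed off before the next operation runs. Below is a minimal sketch of the corrected client-side pattern; it is a hypothetical standalone example, not vLLM's actual module, and it assumes the DEALER/ROUTER socket pairing typical of identity-routed RPC (the diff itself does not show the socket types):

import pickle

import zmq
import zmq.asyncio


async def request_reply(path: str, request: object) -> object:
    """Send one pickled request and await one pickled reply."""
    ctx = zmq.asyncio.Context()
    socket = ctx.socket(zmq.DEALER)
    socket.connect(path)
    try:
        # Without `await`, send() returns an unawaited future: errors are
        # swallowed and the recv() below can race the send.
        await socket.send(pickle.dumps(request, pickle.HIGHEST_PROTOCOL))
        return pickle.loads(await socket.recv())
    finally:
        socket.close()
        ctx.destroy()

The server-side counterpart appears after the rpc/server.py diff below.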
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/cli_args.py
@@ -138,7 +138,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         "--disable-frontend-multiprocessing",
         action="store_true",
         help="If specified, will run the OpenAI frontend server in the same "
-        "proecss as the model servinge engine.")
+        "process as the model serving engine.")

     parser = AsyncEngineArgs.add_cli_args(parser)

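Aside from fixing the typos in the help string, nothing changes in this file. For illustration, a hedged snippet showing how the flag surfaces after parsing (assuming FlexibleArgumentParser lives in vllm.utils, as it did around this release):

from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import FlexibleArgumentParser

# store_true flag added by make_arg_parser, per the diff above.
parser = make_arg_parser(FlexibleArgumentParser())
args = parser.parse_args(["--disable-frontend-multiprocessing"])
assert args.disable_frontend_multiprocessing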
8 changes: 4 additions & 4 deletions vllm/entrypoints/openai/rpc/client.py
@@ -40,12 +40,12 @@ async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
         socket.connect(self.path)

         # Ping RPC Server with request.
-        socket.send(pickle.dumps(request, pickle.HIGHEST_PROTOCOL))
+        await socket.send(pickle.dumps(request, pickle.HIGHEST_PROTOCOL))

         # Await acknowledgement from RPCServer.
         response = pickle.loads(await socket.recv())

-        if (not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR):
+        if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
             socket.close()
             raise ValueError(error_message)

@@ -80,7 +80,7 @@ async def get_model_config(self) -> ModelConfig:
         socket.connect(self.path)

         # Ping RPCServer with GET_MODEL_CONFIG request.
-        socket.send(pickle.dumps(RPCUtilityRequest.GET_MODEL_CONFIG))
+        await socket.send(pickle.dumps(RPCUtilityRequest.GET_MODEL_CONFIG))

         # Await the MODEL_CONFIG from the Server.
         model_config = pickle.loads(await socket.recv())
@@ -126,7 +126,7 @@ async def generate(
         socket.connect(self.path)

         # Send RPCGenerateRequest to the RPCServer.
-        socket.send_multipart([
+        await socket.send_multipart([
             pickle.dumps(
                 RPCGenerateRequest(
                     inputs=inputs,
17 changes: 7 additions & 10 deletions vllm/entrypoints/openai/rpc/server.py
@@ -18,9 +18,6 @@

 class RPCServer:

-    # TODO: check if opening all these sockets is an antipattern.
-    # Alternative, use a smaller number of sockets with conditioning on the
-    # data that is passed through the socket.
     def __init__(self, async_engine_args: AsyncEngineArgs,
                  usage_context: UsageContext, port: int):
         # Initialize engine first.
@@ -41,7 +38,7 @@ def cleanup(self):

     async def _send_success_message(self, identity):
         """Send message to client indicating an action was successful."""
-        self.socket.send_multipart([
+        await self.socket.send_multipart([
             identity,
             pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL),
         ])
@@ -50,20 +47,20 @@ async def get_model_config(self, identity):
         """Send the ModelConfig """
         model_config = await self.engine.get_model_config()

-        self.socket.send_multipart(
+        await self.socket.send_multipart(
             [identity,
              pickle.dumps(model_config, pickle.HIGHEST_PROTOCOL)])

     async def do_log_stats(self, identity):
         await self.engine.do_log_stats()

-        self.socket.send_multipart([
+        await self.socket.send_multipart([
             identity,
             pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL),
         ])

     async def is_server_ready(self, identity):
-        self.socket.send_multipart([
+        await self.socket.send_multipart([
             identity,
             pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL),
         ])
@@ -73,7 +70,7 @@ async def abort(self, identity, request: RPCAbortRequest):
         await self.engine.abort(request.request_id)

         # Send confirmation to the client.
-        self.socket.send_multipart([
+        await self.socket.send_multipart([
             identity,
             pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL),
         ])
@@ -86,14 +83,14 @@ async def generate(self, identity, generate_request: RPCGenerateRequest):
                 request_id=generate_request.request_id)

             async for request_output in results_generator:
-                self.socket.send_multipart([
+                await self.socket.send_multipart([
                     identity,
                     pickle.dumps(request_output, pickle.HIGHEST_PROTOCOL)
                 ])

         except Exception as e:
             ### Notify client of all failures
-            self.socket.send_multipart(
+            await self.socket.send_multipart(
                 [identity, pickle.dumps(e, pickle.HIGHEST_PROTOCOL)])
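On the server side the same rule applies to send_multipart(), where each reply is prefixed with the requesting client's identity frame so a ROUTER socket can route it back. A minimal sketch of that pattern (a hypothetical standalone example, not vLLM's server; VLLM_RPC_SUCCESS_STR stands in for the constant used in the diff):

import pickle

import zmq
import zmq.asyncio

VLLM_RPC_SUCCESS_STR = "SUCCESS"  # assumed value, for illustration only


async def serve(path: str) -> None:
    ctx = zmq.asyncio.Context()
    socket = ctx.socket(zmq.ROUTER)
    socket.bind(path)
    while True:
        # ROUTER sockets receive [identity, payload] frame pairs.
        identity, payload = await socket.recv_multipart()
        _request = pickle.loads(payload)
        # ... dispatch to the engine here ...
        # Awaiting ensures a failed reply raises here rather than being
        # dropped inside an unawaited future.
        await socket.send_multipart([
            identity,
            pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL),
        ])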
7 changes: 4 additions & 3 deletions vllm/utils.py
@@ -302,15 +302,14 @@ def merge_async_iterators(
     queue: asyncio.Queue[Union[Tuple[int, T], ProducerFinished,
                                Exception]] = asyncio.Queue()

-    finished = [False] * len(iterators)
+    producers = len(iterators)

     async def producer(i: int, iterator: AsyncIterator[T]):
         try:
             async for item in iterator:
                 await queue.put((i, item))
         except Exception as e:
             await queue.put(e)
-        finished[i] = True
         # Signal to the consumer that we've finished
         await queue.put(ProducerFinished())

@@ -320,13 +319,15 @@ async def producer(i: int, iterator: AsyncIterator[T]):
     ]

     async def consumer():
+        remaining = producers
         try:
-            while not all(finished) or not queue.empty():
+            while remaining or not queue.empty():
                 # we think there is a race condition here
                 item = await queue.get()

                 if isinstance(item, ProducerFinished):
                     # Signal that a producer finished- not a real item
+                    remaining -= 1
                     continue

                 if isinstance(item, Exception):
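This fix replaces the per-iterator finished flags with a count of outstanding producers: the consumer loops until it has consumed one ProducerFinished sentinel per producer, so it can no longer observe "all finished" while sentinels, or items queued before them, are still in flight. A condensed, runnable sketch of the resulting pattern, simplified from the diff (ProducerFinished stands in for vLLM's marker class):

import asyncio
from typing import AsyncIterator, Tuple, TypeVar, Union

T = TypeVar("T")


class ProducerFinished:
    """Sentinel marking that one producer is done."""


async def merge_async_iterators(
        *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
    queue: asyncio.Queue[Union[Tuple[int, T], ProducerFinished,
                               Exception]] = asyncio.Queue()

    async def producer(i: int, it: AsyncIterator[T]):
        try:
            async for item in it:
                await queue.put((i, item))
        except Exception as e:
            await queue.put(e)
        # Always signal completion, even after an error.
        await queue.put(ProducerFinished())

    tasks = [
        asyncio.create_task(producer(i, it))
        for i, it in enumerate(iterators)
    ]

    remaining = len(iterators)
    while remaining or not queue.empty():
        item = await queue.get()
        if isinstance(item, ProducerFinished):
            remaining -= 1  # one sentinel per producer
            continue
        if isinstance(item, Exception):
            raise item
        yield item
    await asyncio.gather(*tasks)


async def _demo() -> None:
    async def count(n: int) -> AsyncIterator[int]:
        for x in range(n):
            yield x

    # Items from both iterators are interleaved, tagged with their index.
    async for i, item in merge_async_iterators(count(2), count(3)):
        print(i, item)


if __name__ == "__main__":
    asyncio.run(_demo())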
