This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Commit

Use random port for backend (#390)
Picks an open port to use and boots both the client and server with it

---------

Signed-off-by: Joe Runde <[email protected]>
joerunde authored Jul 31, 2024
1 parent 1f33286 commit 5362952
Showing 6 changed files with 33 additions and 24 deletions.
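For readers skimming the diff below: the core trick behind "picks an open port" is to let the operating system choose one by binding to port 0. A minimal, self-contained sketch of that idea in plain stdlib Python (illustrative only, not the vLLM helper itself):

import socket

def pick_open_port() -> int:
    # Binding to port 0 makes the OS assign a currently free ephemeral port,
    # which we read back and then release by closing the socket.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]

if __name__ == "__main__":
    print(pick_open_port())  # e.g. 49731; the port is free again once the socket closes

Because the probe socket is closed before anything rebinds the number, another process could in principle grab it in the gap; the pattern trades that small race for not having to hard-code a port.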
13 changes: 8 additions & 5 deletions vllm/entrypoints/openai/api_server.py
@@ -45,7 +45,7 @@
     OpenAIServingTokenization)
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import FlexibleArgumentParser
+from vllm.utils import FlexibleArgumentParser, get_open_port
 from vllm.version import __version__ as VLLM_VERSION
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds
@@ -107,15 +107,18 @@ async def build_backend(args) -> AsyncIterator[VLLMBackend]:
     else:
         # remote backend
         ## First need to start the backend process
+        port = get_open_port(envs.VLLM_RPC_PORT)
         rpc_server_process = Process(target=run_rpc_server,
-                                     args=(engine_args,
-                                           UsageContext.OPENAI_API_SERVER))
+                                     args=(engine_args,
+                                           UsageContext.OPENAI_API_SERVER,
+                                           port))
         rpc_server_process.start()
 
         ## Then build the client for the backend process
         # TODO: figure out a way around passing the tokenizer
-        backend = RPCClient(
-            tokenizer=AutoTokenizer.from_pretrained(args.model))
+        backend = RPCClient(tokenizer=AutoTokenizer.from_pretrained(
+            args.model),
+                            port=port)
         await backend.wait_for_server()
 
     try:
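The api_server.py hunk above hands the chosen port to the backend through the Process args tuple and to RPCClient through a keyword argument, then blocks on wait_for_server(). A rough standalone analogue of that handoff, using a plain TCP ping instead of vLLM's RPC machinery (pick_open_port, run_backend and the READY handshake are stand-ins, not vLLM APIs):

import socket
import time
from multiprocessing import Process


def pick_open_port() -> int:
    # Let the OS assign a free port, then release it for the backend to rebind.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def run_backend(port: int) -> None:
    # Stand-in for run_rpc_server: bind the given port and answer one ping.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv:
        srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        srv.bind(("127.0.0.1", port))
        srv.listen(1)
        conn, _ = srv.accept()
        with conn:
            conn.sendall(b"READY")


if __name__ == "__main__":
    port = pick_open_port()
    backend = Process(target=run_backend, args=(port,))  # port travels in args
    backend.start()

    # Stand-in for backend.wait_for_server(): retry until the backend answers.
    while True:
        try:
            with socket.create_connection(("127.0.0.1", port), timeout=0.5) as c:
                if c.recv(16):  # any reply means the backend is up
                    break
        except OSError:
            time.sleep(0.05)

    backend.join()
    print("backend ready on port", port)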
1 change: 0 additions & 1 deletion vllm/entrypoints/openai/rpc/__init__.py
@@ -7,7 +7,6 @@
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 
-VLLM_RPC_PATH = "tcp://localhost:5570"
 VLLM_RPC_SUCCESS_STR = "SUCCESS"
 
 
11 changes: 6 additions & 5 deletions vllm/entrypoints/openai/rpc/client.py
@@ -5,7 +5,7 @@
 import zmq.asyncio
 
 from vllm.config import DecodingConfig, ModelConfig
-from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE, VLLM_RPC_PATH,
+from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
                                          VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
                                          RPCGenerateRequest, RPCUtilityRequest)
 from vllm.inputs import PromptInputs
@@ -18,13 +18,14 @@
 class RPCClient:
 
     # TODO: check if opening all these sockets is an antipattern?
-    def __init__(self, tokenizer):
+    def __init__(self, tokenizer, port: int):
         # ZMQ context.
         self.context = zmq.asyncio.Context()
 
         # TODO: do the tokenizer properly.
         self.tokenizer = tokenizer
         self.decoding_config = DecodingConfig()
+        self.path = f"tcp://localhost:{port}"
 
     def close(self):
         """Destroy the ZeroMQ Context."""
@@ -36,7 +37,7 @@ async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
 
         # Connect to socket.
         socket = self.context.socket(zmq.constants.DEALER)
-        socket.connect(VLLM_RPC_PATH)
+        socket.connect(self.path)
 
         # Ping RPC Server with request.
         socket.send(pickle.dumps(request, pickle.HIGHEST_PROTOCOL))
@@ -76,7 +77,7 @@ async def get_model_config(self) -> ModelConfig:
 
         # Connect to socket.
         socket = self.context.socket(zmq.constants.DEALER)
-        socket.connect(VLLM_RPC_PATH)
+        socket.connect(self.path)
 
         # Ping RPCServer with GET_MODEL_CONFIG request.
         socket.send(pickle.dumps(RPCUtilityRequest.GET_MODEL_CONFIG))
@@ -122,7 +123,7 @@ async def generate(
         # Note that we use DEALER to enable asynchronous communication
         # to enable streaming.
         socket = self.context.socket(zmq.constants.DEALER)
-        socket.connect(VLLM_RPC_PATH)
+        socket.connect(self.path)
 
         # Send RPCGenerateRequest to the RPCServer.
         socket.send_multipart([
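For readers unfamiliar with the ZeroMQ socket types in client.py: the client now builds its endpoint string from the port it was given and connects a DEALER socket to it for each request. A stripped-down sketch of that request/reply round trip (plain bytes instead of vLLM's pickled request objects; it expects a ROUTER peer such as the one sketched after the server.py diff below):

import asyncio

import zmq
import zmq.asyncio


async def ask(port: int, payload: bytes) -> bytes:
    ctx = zmq.asyncio.Context()
    try:
        # DEALER can send without first waiting for a reply (unlike REQ),
        # which is what makes streaming multiple responses possible.
        sock = ctx.socket(zmq.DEALER)
        sock.connect(f"tcp://localhost:{port}")
        await sock.send(payload)
        reply = await sock.recv()
        sock.close()
        return reply
    finally:
        ctx.destroy(linger=0)


if __name__ == "__main__":
    # Assumes the ROUTER sketch below is already listening on port 5570.
    print(asyncio.run(ask(5570, b"ping")))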
20 changes: 9 additions & 11 deletions vllm/entrypoints/openai/rpc/server.py
@@ -8,9 +8,8 @@
 from typing_extensions import Never
 
 from vllm import AsyncEngineArgs, AsyncLLMEngine
-from vllm.entrypoints.openai.rpc import (VLLM_RPC_PATH, VLLM_RPC_SUCCESS_STR,
-                                         RPCAbortRequest, RPCGenerateRequest,
-                                         RPCUtilityRequest)
+from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
+                                         RPCGenerateRequest, RPCUtilityRequest)
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
 
@@ -23,7 +22,7 @@ class RPCServer:
     # Alternative, use a smaller number of sockets with conditioning on the
     # data that is passed through the socket.
     def __init__(self, async_engine_args: AsyncEngineArgs,
-                 usage_context: UsageContext):
+                 usage_context: UsageContext, port: int):
         # Initialize engine first.
         self.engine = AsyncLLMEngine.from_engine_args(async_engine_args,
                                                       usage_context)
@@ -33,7 +32,7 @@ def __init__(self, async_engine_args: AsyncEngineArgs,
 
         # Init socket for readiness state.
         self.socket = self.context.socket(zmq.constants.ROUTER)
-        self.socket.bind(VLLM_RPC_PATH)
+        self.socket.bind(f"tcp://localhost:{port}")
 
     def cleanup(self):
         """Cleanup all resources."""
@@ -51,10 +50,9 @@ async def get_model_config(self, identity):
         """Send the ModelConfig """
         model_config = await self.engine.get_model_config()
 
-        self.socket.send_multipart([
-            identity,
-            pickle.dumps(model_config, pickle.HIGHEST_PROTOCOL)
-        ])
+        self.socket.send_multipart(
+            [identity,
+             pickle.dumps(model_config, pickle.HIGHEST_PROTOCOL)])
 
     async def do_log_stats(self, identity):
         await self.engine.do_log_stats()
@@ -166,6 +164,6 @@ def signal_handler() -> None:
 
 
 def run_rpc_server(async_engine_args: AsyncEngineArgs,
-                   usage_context: UsageContext):
-    server = RPCServer(async_engine_args, usage_context)
+                   usage_context: UsageContext, port: int):
+    server = RPCServer(async_engine_args, usage_context, port)
     asyncio.run(run_server(server))
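The server side binds a ROUTER socket, which prefixes every incoming message with the sender's identity frame; a reply must carry that identity back so ZeroMQ can route it to the right DEALER, which is why the real server's send_multipart calls start with identity. A minimal echo-style sketch that pairs with the client sketch above (again plain bytes rather than pickled vLLM requests):

import asyncio

import zmq
import zmq.asyncio


async def serve(port: int) -> None:
    ctx = zmq.asyncio.Context()
    sock = ctx.socket(zmq.ROUTER)
    sock.bind(f"tcp://*:{port}")  # bind to localhost instead to stay loopback-only
    try:
        while True:
            # A DEALER's single-frame message arrives here as [identity, payload].
            identity, payload = await sock.recv_multipart()
            # Echo back, reusing the identity frame so the reply is routable.
            await sock.send_multipart([identity, b"pong:" + payload])
    finally:
        sock.close()
        ctx.destroy(linger=0)


if __name__ == "__main__":
    asyncio.run(serve(5570))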
6 changes: 6 additions & 0 deletions vllm/envs.py
@@ -4,6 +4,7 @@
 if TYPE_CHECKING:
     VLLM_HOST_IP: str = ""
     VLLM_PORT: Optional[int] = None
+    VLLM_RPC_PORT: int = 5570
     VLLM_USE_MODELSCOPE: bool = False
     VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60
     VLLM_INSTANCE_ID: Optional[str] = None
@@ -142,6 +143,11 @@ def get_default_config_root():
     lambda: int(os.getenv('VLLM_PORT', '0'))
     if 'VLLM_PORT' in os.environ else None,
 
+    # Used when the frontend api server is running in multi-processing mode,
+    # to communicate with the backend engine process over ZMQ.
+    'VLLM_RPC_PORT':
+    lambda: int(os.getenv('VLLM_RPC_PORT', '5570')),
+
     # If true, will load models from ModelScope instead of Hugging Face Hub.
     # note that the value is true or false, not numbers
     "VLLM_USE_MODELSCOPE":
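The new envs.py entry follows the file's existing pattern: a dictionary of lazily evaluated getters, each reading an environment variable with a string default. In isolation the pattern looks like this (the variable name is made up for illustration, not vLLM's configuration surface):

import os

# Sketch of the lazily evaluated env-var-with-default pattern.
getters = {
    "EXAMPLE_RPC_PORT": lambda: int(os.getenv("EXAMPLE_RPC_PORT", "5570")),
}

port = getters["EXAMPLE_RPC_PORT"]()  # 5570 unless the variable is set
print(port)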
6 changes: 4 additions & 2 deletions vllm/utils.py
@@ -384,8 +384,10 @@ def get_distributed_init_method(ip: str, port: int) -> str:
     return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
 
 
-def get_open_port() -> int:
-    port = envs.VLLM_PORT
+def get_open_port(port: Optional[int] = None) -> int:
+    if port is None:
+        # Default behavior here is to return a port for multi-gpu communication
+        port = envs.VLLM_PORT
     if port is not None:
         while True:
             try:
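The utils.py hunk is truncated here, so the rest of get_open_port's body is not shown; the sketch below is only a guess at the fallback behaviour implied by the signature change and the visible "while True / try" lines (prefer the requested port, walk forward if it is taken, otherwise let the OS assign one), not the function's actual implementation:

import socket
from typing import Optional


def get_open_port_sketch(port: Optional[int] = None) -> int:
    """Illustrative guess only, not vllm.utils.get_open_port."""
    if port is not None:
        while True:
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.bind(("", port))  # succeeds only if the port is free
                    return port
            except OSError:
                port += 1  # assumed fallback: try the next port number
    # No preference given: let the OS pick any free port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


print(get_open_port_sketch(5570))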
