made vllm dependency optional (#17) and cleaned vllm files
benlebrun committed Jun 25, 2024
1 parent ada035b commit 9248307
Showing 5 changed files with 129 additions and 328 deletions.
4 changes: 4 additions & 0 deletions Makefile
@@ -22,7 +22,11 @@ update :
 .PHONY : env
 env : $(NAME).egg-info/
 $(NAME).egg-info/ : setup.py
+ifeq ($(shell uname -s),Darwin)
+	@$(INSTALL) -e ".[test]" && pre-commit install
+else
 	@$(INSTALL) -e ".[test,vllm]" && pre-commit install
+endif
 
 ## format : format code style.
 .PHONY : format
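The new conditional installs the package without the vllm extra on Darwin (macOS) and keeps the full ".[test,vllm]" install everywhere else. For reference, a minimal sketch of how such an optional extra can be declared in setup.py; the package lists below are illustrative assumptions, not taken from the repository:

# setup.py (sketch) -- package lists are illustrative; only the optional
# "vllm" extra mirrors the ".[test]" / ".[test,vllm]" install targets above.
from setuptools import setup, find_packages

setup(
    name='genparse',
    packages=find_packages(),
    install_requires=['torch'],            # core dependencies (illustrative)
    extras_require={
        'test': ['pytest', 'pre-commit'],  # pulled in by ".[test]"
        'vllm': ['vllm'],                  # pulled in only by ".[test,vllm]"
    },
)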
227 changes: 118 additions & 109 deletions genparse/vllm_compatibility.py
@@ -1,17 +1,17 @@
 import torch
 
-import vllm
-from typing import Optional, List, Union
+from typing import Optional
 import time
-from vllm.engine.output_processor.util import create_output_by_sequence_group
-from vllm.engine.arg_utils import EngineArgs
-from vllm.utils import Counter
-from vllm.usage.usage_lib import UsageContext
-from vllm.outputs import EmbeddingRequestOutput, RequestOutput, RequestOutputFactory
-from vllm.sequence import ExecuteModelRequest
-from typing import Sequence as GenericSequence
-from typing import Set, Type, TypeVar
-from vllm.core.scheduler import ScheduledSequenceGroup, Scheduler, SchedulerOutputs
+
+try:
+    import vllm
+    from vllm.engine.arg_utils import EngineArgs
+    from vllm.utils import Counter
+    from vllm.usage.usage_lib import UsageContext
+    from vllm.outputs import RequestOutputFactory
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    VLLM_AVAILABLE = False
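With the imports guarded this way, everything vllm-specific in the rest of the module (second hunk below) is only defined when VLLM_AVAILABLE is true. A minimal caller-side sketch of how code elsewhere could check the flag before touching the vllm-backed wrapper; the model name is an illustrative assumption:

# Caller-side sketch (model name is illustrative). VLLM_AVAILABLE and
# vllmpplLLM are the module-level names defined in genparse/vllm_compatibility.py.
from genparse.vllm_compatibility import VLLM_AVAILABLE, vllmpplLLM

if VLLM_AVAILABLE:
    llm = vllmpplLLM('gpt2')  # real vllm-backed wrapper
else:
    llm = None                # fall back to a non-vllm code path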


class LogitsSampler(torch.nn.Module):
@@ -49,104 +49,113 @@ def forward(
         return grouped_logprobs, grouped_seq_ids
 
 
-class pplLMEngine(vllm.LLMEngine):
-    def _process_model_outputs(
-        self,
-        output,
-        scheduled_seq_groups,
-        ignored_seq_groups,
-        seq_group_metadata_list,
-    ):
-        """
-        Apply the model output to the sequences in the scheduled seq groups.
-        Returns RequestOutputs that can be returned to the client.
-        """
-
-        now = time.time()
-
-        # Update the scheduled sequence groups with the model outputs.
-        for scheduled_seq_group, outputs, seq_group_meta in zip(
-            scheduled_seq_groups, output, seq_group_metadata_list
-        ):
-            seq_group = scheduled_seq_group.seq_group
-            seq_group.update_num_computed_tokens(scheduled_seq_group.token_chunk_size)
-
-            self.output_processor.process_prompt_logprob(seq_group, outputs)
-            if seq_group_meta.do_sample:
-                self.output_processor.process_outputs(seq_group, outputs)
-
-        # Free the finished sequence groups.
-        self.scheduler.free_finished_seq_groups()
-
-        # Create the outputs.
-        request_outputs = []
-        for scheduled_seq_group in scheduled_seq_groups:
-            seq_group = scheduled_seq_group.seq_group
-            seq_group.maybe_set_first_token_time(now)
-            request_output = RequestOutputFactory.create(seq_group)
-            request_outputs.append(request_output)
-        for seq_group in ignored_seq_groups:
-            request_output = RequestOutputFactory.create(seq_group)
-            request_outputs.append(request_output)
-        return request_outputs
-
-
-class vllmpplLLM(vllm.LLM):
-    """
-    Wrapper around VLLM to make it compatible with hfppl.
-    1. vllm sampler replaced with LogitsSampler
-    2. added next_token_logprobs, p_next methods
-    """
-
-    def __init__(
-        self,
-        model: str,
-        tokenizer: Optional[str] = None,
-        tokenizer_mode: str = 'auto',
-        skip_tokenizer_init: bool = False,
-        trust_remote_code: bool = False,
-        tensor_parallel_size: int = 1,
-        dtype: str = 'auto',
-        quantization: Optional[str] = None,
-        revision: Optional[str] = None,
-        tokenizer_revision: Optional[str] = None,
-        seed: int = 0,
-        gpu_memory_utilization: float = 0.9,
-        swap_space: int = 4,
-        enforce_eager: bool = False,
-        max_context_len_to_capture: Optional[int] = None,
-        max_seq_len_to_capture: int = 8192,
-        disable_custom_all_reduce: bool = False,
-        **kwargs,
-    ) -> None:
-        if 'disable_log_stats' not in kwargs:
-            kwargs['disable_log_stats'] = True
-        engine_args = EngineArgs(
-            model=model,
-            tokenizer=tokenizer,
-            tokenizer_mode=tokenizer_mode,
-            skip_tokenizer_init=skip_tokenizer_init,
-            trust_remote_code=trust_remote_code,
-            tensor_parallel_size=tensor_parallel_size,
-            dtype=dtype,
-            quantization=quantization,
-            revision=revision,
-            tokenizer_revision=tokenizer_revision,
-            seed=seed,
-            gpu_memory_utilization=gpu_memory_utilization,
-            swap_space=swap_space,
-            enforce_eager=enforce_eager,
-            max_context_len_to_capture=max_context_len_to_capture,
-            max_seq_len_to_capture=max_seq_len_to_capture,
-            disable_custom_all_reduce=disable_custom_all_reduce,
-            **kwargs,
-        )
-        self.llm_engine = pplLMEngine.from_engine_args(
-            engine_args, usage_context=UsageContext.LLM_CLASS
-        )
-        # sampler of the model
-        self.llm_engine.model_executor.driver_worker.model_runner.model.sampler = (
-            LogitsSampler()
-        )
-        self.request_counter = Counter()
-        self.eos_token_id = self.llm_engine._get_eos_token_id(lora_request=None)
+if VLLM_AVAILABLE:
+
+    class pplLMEngine(vllm.LLMEngine):
+        def _process_model_outputs(
+            self,
+            output,
+            scheduled_seq_groups,
+            ignored_seq_groups,
+            seq_group_metadata_list,
+        ):
+            """
+            Apply the model output to the sequences in the scheduled seq groups.
+            Returns RequestOutputs that can be returned to the client.
+            """
+
+            now = time.time()
+
+            # Update the scheduled sequence groups with the model outputs.
+            for scheduled_seq_group, outputs, seq_group_meta in zip(
+                scheduled_seq_groups, output, seq_group_metadata_list
+            ):
+                seq_group = scheduled_seq_group.seq_group
+                seq_group.update_num_computed_tokens(scheduled_seq_group.token_chunk_size)
+
+                self.output_processor.process_prompt_logprob(seq_group, outputs)
+                if seq_group_meta.do_sample:
+                    self.output_processor.process_outputs(seq_group, outputs)
+
+            # Free the finished sequence groups.
+            self.scheduler.free_finished_seq_groups()
+
+            # Create the outputs.
+            request_outputs = []
+            for scheduled_seq_group in scheduled_seq_groups:
+                seq_group = scheduled_seq_group.seq_group
+                seq_group.maybe_set_first_token_time(now)
+                request_output = RequestOutputFactory.create(seq_group)
+                request_outputs.append(request_output)
+            for seq_group in ignored_seq_groups:
+                request_output = RequestOutputFactory.create(seq_group)
+                request_outputs.append(request_output)
+            return request_outputs
+
+    class vllmpplLLM(vllm.LLM):
+        """
+        Wrapper around VLLM to make it compatible with hfppl.
+        1. vllm sampler replaced with LogitsSampler
+        2. added next_token_logprobs, p_next methods
+        """
+
+        def __init__(
+            self,
+            model: str,
+            tokenizer: Optional[str] = None,
+            tokenizer_mode: str = 'auto',
+            skip_tokenizer_init: bool = False,
+            trust_remote_code: bool = False,
+            tensor_parallel_size: int = 1,
+            dtype: str = 'auto',
+            quantization: Optional[str] = None,
+            revision: Optional[str] = None,
+            tokenizer_revision: Optional[str] = None,
+            seed: int = 0,
+            gpu_memory_utilization: float = 0.9,
+            swap_space: int = 4,
+            enforce_eager: bool = False,
+            max_context_len_to_capture: Optional[int] = None,
+            max_seq_len_to_capture: int = 8192,
+            disable_custom_all_reduce: bool = False,
+            **kwargs,
+        ) -> None:
+            if 'disable_log_stats' not in kwargs:
+                kwargs['disable_log_stats'] = True
+            engine_args = EngineArgs(
+                model=model,
+                tokenizer=tokenizer,
+                tokenizer_mode=tokenizer_mode,
+                skip_tokenizer_init=skip_tokenizer_init,
+                trust_remote_code=trust_remote_code,
+                tensor_parallel_size=tensor_parallel_size,
+                dtype=dtype,
+                quantization=quantization,
+                revision=revision,
+                tokenizer_revision=tokenizer_revision,
+                seed=seed,
+                gpu_memory_utilization=gpu_memory_utilization,
+                swap_space=swap_space,
+                enforce_eager=enforce_eager,
+                max_context_len_to_capture=max_context_len_to_capture,
+                max_seq_len_to_capture=max_seq_len_to_capture,
+                disable_custom_all_reduce=disable_custom_all_reduce,
+                **kwargs,
+            )
+            self.llm_engine = pplLMEngine.from_engine_args(
+                engine_args, usage_context=UsageContext.LLM_CLASS
+            )
+            # sampler of the model
+            self.llm_engine.model_executor.driver_worker.model_runner.model.sampler = (
+                LogitsSampler()
+            )
+            self.request_counter = Counter()
+            self.eos_token_id = self.llm_engine._get_eos_token_id(lora_request=None)
+
+else:
+
+    class vllmpplLLM:
+        def __init__(self, *args, **kwargs):
+            raise ImportError(
+                'No vllm functionality available since vllm is not installed. '
+                'You probably installed genparse without the optional vllm dependency.'
+            )
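Because the stub only raises at construction time, importing the module stays cheap when vllm is absent. A short usage sketch of that failure mode; the model name is an illustrative assumption:

# Usage sketch: constructing the wrapper without vllm installed raises
# ImportError with the message defined in the stub above ('gpt2' is illustrative).
from genparse.vllm_compatibility import vllmpplLLM

try:
    llm = vllmpplLLM('gpt2')
except ImportError as err:
    print(f'vllm backend unavailable: {err}')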